File size: 2,827 Bytes
98ce0a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import itertools
import random
from collections import Counter
from io import BytesIO

import easyocr
import nltk
import numpy as np
import requests
import streamlit as st
import torch
from nltk.corpus import stopwords
from PIL import Image
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer

nltk.download('stopwords')

# Load the captioning model and its preprocessing components once at startup.
# These are module-level singletons shared by generate_hashtags(); the first
# run downloads the pretrained weights from the Hugging Face hub (network I/O).
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
# EasyOCR reader for English text detection; also downloads weights on first use.
reader = easyocr.Reader(['en'])

# set up Streamlit app
# NOTE: st.set_page_config must be the first Streamlit call in the script.
st.set_page_config(layout='wide', page_title='Image Hashtag Recommender')

# define function to extract image features and generate hashtags
def _rank_hashtags(words, max_words=25):
    """Build 1-3 word hashtag combinations and return the 10 most frequent.

    words: list of lowercase, alphanumeric-only word tokens.
    max_words: cap on how many words are combined. Without a cap the number
        of 3-word combinations grows as C(n, 3), which hangs the app on
        text-heavy images.
    Returns a list of at most 10 '#'-prefixed strings, most frequent first.
    """
    # Strip anything non-alphanumeric so the resulting hashtags are valid,
    # then drop words that become empty (e.g. pure punctuation from OCR).
    cleaned = ["".join(ch for ch in w if ch.isalnum()) for w in words]
    cleaned = [w for w in cleaned if w][:max_words]

    # Count each candidate hashtag; duplicate words in the input produce
    # duplicate combinations, which is what drives the frequency ranking.
    counts = Counter()
    for n in range(1, 4):
        for combination in itertools.combinations(cleaned, n):
            counts["#" + "".join(combination)] += 1
    counts.pop("#", None)  # guard against an all-empty join

    # Counter.most_common is O(n log n); the original sorted(..., key=list.count)
    # was O(n^2) and always saw count == 1 because combinations are unique tuples.
    return [tag for tag, _ in counts.most_common(10)]


def generate_hashtags(image_file):
    """Caption an uploaded image, OCR any visible text, and suggest hashtags.

    image_file: a file-like object (e.g. Streamlit UploadedFile) readable by PIL.
    Returns a list of at most 10 hashtag strings, most frequent first.
    """
    # get image and convert to RGB mode (model and OCR both expect 3 channels)
    image = Image.open(image_file).convert('RGB')

    # extract image features and generate a caption with the ViT-GPT2 model
    pixel_values = feature_extractor(images=[image], return_tensors="pt").pixel_values
    output_ids = model.generate(pixel_values)

    # decode the model output to text and extract caption words
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    caption_words = [word.lower() for word in output_text.split() if not word.startswith("#")]

    # remove stop words from caption words
    stop_words = set(stopwords.words('english'))
    caption_words = [word for word in caption_words if word not in stop_words]

    # use easyocr to extract text from the image; each result item is
    # (bounding_box, text, confidence) — keep only the text.
    text = reader.readtext(np.array(image))
    detected_text = " ".join([item[1] for item in text])

    # combine caption words and detected text, lowercasing OCR words so they
    # match the already-lowercased caption words
    all_words = caption_words + [w.lower() for w in detected_text.split()]

    return _rank_hashtags(all_words)


# ---- Streamlit UI: title, file uploader, and hashtag results ----
st.title("Image Hashtag Recommender")

image_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

# Once an image is uploaded, run the pipeline and render whatever it returns.
if image_file is not None:
    try:
        suggested = generate_hashtags(image_file)
        if not suggested:
            st.write("No hashtags found for this image.")
        else:
            st.write("Top 10 hashtags for this image:")
            for tag in suggested:
                st.write(tag)
    except Exception as e:
        # Surface any pipeline failure (bad image, model error) in the UI.
        st.write(f"Error: {e}")