File size: 3,037 Bytes
19150e5
 
 
 
795f603
51000bd
 
5fdae94
19150e5
 
 
 
 
 
 
 
 
 
 
 
 
 
51000bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19150e5
51000bd
 
5b443c1
19150e5
 
 
 
 
 
 
 
 
 
a320a5a
19150e5
51000bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f44d91b
 
 
51000bd
 
 
02d8d2c
51000bd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import gradio as gr
import torch
import clip
from PIL import Image
import numpy as np
from pytrends.request import TrendReq
import matplotlib.pyplot as plt
import os

# Run CLIP on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model, preprocess = clip.load("ViT-B/32", device)

# Candidate labels the image is scored against: garment types and
# pattern/material descriptors.
categories = [
    "t-shirt",
    "jeans",
    "jacket",
    "dress",
    "shorts",
    "sweater",
    "skirt",
]
attributes = [
    "striped",
    "plain",
    "floral",
    "polka dot",
    "denim",
    "leather",
    "wool",
]

# Encode both label lists once at import time so each request only has to
# encode the image. Gradients are not needed for inference.
with torch.no_grad():
    category_embeddings, attribute_embeddings = [
        model.encode_text(clip.tokenize(labels).to(device))
        for labels in (categories, attributes)
    ]

def plot_trends(dataframe, path="trends_plot.png"):
    """Draw a line chart of a trends table and save it as an image file.

    Every column of ``dataframe`` except ``'isPartial'`` is plotted as one
    line against the frame's index.

    Args:
        dataframe: Table whose index supplies the x values and whose columns
            are the series to draw (presumably the frame returned by
            pytrends' ``interest_over_time()`` -- confirm against the caller).
        path: File the chart is written to. Defaults to ``"trends_plot.png"``,
            a relative path that is overwritten on every call; pass a unique
            path if several charts must coexist.

    Returns:
        The path the chart was saved to.
    """
    # Use an explicit Figure/Axes pair instead of pyplot's implicit
    # "current figure" so the drawing calls cannot land on another figure.
    figure, axes = plt.subplots(figsize=(12, 6))
    try:
        series_names = [name for name in dataframe.columns if name != 'isPartial']
        for name in series_names:
            axes.plot(dataframe.index, dataframe[name], label=name)

        # With nothing plotted (e.g. an empty frame) a legend has no entries
        # and matplotlib would only emit a warning.
        if series_names:
            axes.legend()

        axes.set_title("Google Trends Over Time")
        axes.set_xlabel("Time")
        axes.set_ylabel("Interest")
        axes.grid(True)
        figure.tight_layout()
        figure.savefig(path)
    finally:
        # Release the figure even if plotting or saving failed; pyplot keeps
        # every unclosed figure alive.
        plt.close(figure)
    return path
    
def predict_apparel_and_attributes(image):
    """Classify an apparel image with CLIP and chart search interest for it.

    The image is embedded with the module-level CLIP ``model`` and compared,
    by cosine similarity, against the pre-computed text embeddings of
    ``categories`` and ``attributes``. The best category and the three best
    attributes are then sent to Google Trends and the result is plotted.

    Args:
        image: The picture to classify, in a form accepted by ``preprocess``
            (the Gradio input in this file is declared with ``type="pil"``).

    Returns:
        A tuple ``(top_category, attributes_text, plot_path)``: the winning
        category label, the top attributes joined with ``", "``, and the path
        of the saved trends chart.

    Raises:
        gr.Error: If ``image`` is ``None`` (nothing was uploaded).
    """
    if image is None:
        # Fail with a readable message instead of an opaque error from
        # inside the preprocessing pipeline.
        raise gr.Error("Please upload an image before submitting.")

    image_input = preprocess(image).unsqueeze(0).to(device)

    with torch.no_grad():
        image_embedding = model.encode_image(image_input)

        # Scale every embedding to unit length so the dot products below are
        # cosine similarities. Without this the ranking is skewed towards
        # labels whose text embedding happens to have a larger norm.
        image_embedding = image_embedding / image_embedding.norm(dim=-1, keepdim=True)
        category_features = category_embeddings / category_embeddings.norm(dim=-1, keepdim=True)
        attribute_features = attribute_embeddings / attribute_embeddings.norm(dim=-1, keepdim=True)

        # Calculate similarity scores (one score per label)
        category_similarities = (image_embedding @ category_features.T).squeeze(0)
        attribute_similarities = (image_embedding @ attribute_features.T).squeeze(0)

    # Get top category and the top 3 attributes; convert tensor indices to
    # plain ints before indexing the Python lists.
    top_category = categories[category_similarities.argmax().item()]
    top_attribute_indices = attribute_similarities.argsort(descending=True)[:3].tolist()
    top_attributes = [attributes[index] for index in top_attribute_indices]
    print(f"results:{top_category, ','.join(top_attributes)}")

    # Fetch trends for the top apparel category and attributes
    pytrend = TrendReq()
    keywords = [top_category] + top_attributes
    pytrend.build_payload(kw_list=keywords, timeframe='now 1-H', geo='', gprop='')
    interest_over_time_df = pytrend.interest_over_time()

    # Plot the trends and get the path to the saved plot
    plot_path = plot_trends(interest_over_time_df)

    return top_category, ", ".join(top_attributes), plot_path

# Gradio UI: one uploaded image in; category text, attribute text and the
# trends chart out.
demo = gr.Interface(
    fn=predict_apparel_and_attributes,
    inputs=gr.Image(type="pil"),
    outputs=[  # Output types
        gr.Textbox(label="Apparel Category"),
        gr.Textbox(label="Apparel Attributes"),
        gr.Image(label="Google Trends Plot"),
    ],
    # Sample images, resolved against the current working directory.
    examples=[
        os.path.join(os.path.abspath(''), relative_path)
        for relative_path in (
            "images/jeans.jpeg",
            "images/jacket.jpg",
            "images/tshirt.png",
        )
    ],
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(debug=True)