# Hugging Face Space status: Runtime error
| import gradio as gr | |
| import torch | |
| import clip | |
| from PIL import Image | |
| import numpy as np | |
| from pytrends.request import TrendReq | |
| import matplotlib.pyplot as plt | |
| import os | |
# Run CLIP on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model, preprocess = clip.load("ViT-B/32", device)

# Label vocabularies used for zero-shot classification of the uploaded image.
categories = ["t-shirt", "jeans", "jacket", "dress", "shorts", "sweater", "skirt"]
attributes = ["striped", "plain", "floral", "polka dot", "denim", "leather", "wool"]


def _embed_labels(labels):
    """Return CLIP text embeddings for ``labels`` (one row per label), without autograd."""
    with torch.no_grad():
        token_ids = clip.tokenize(labels).to(device)
        return model.encode_text(token_ids)


# The label texts never change, so their embeddings are computed once at import.
category_embeddings = _embed_labels(categories)
attribute_embeddings = _embed_labels(attributes)
def plot_trends(dataframe, path=None):
    """Plot every Google Trends series in ``dataframe`` and save it as a PNG.

    Args:
        dataframe: Table of trend values, e.g. the result of
            ``TrendReq.interest_over_time()``. Its index is used as the x-axis
            and every column except ``isPartial`` becomes one line.
        path: Where to write the image. When omitted, a unique temporary
            ``.png`` file is created so concurrent requests in a multi-user
            Gradio app cannot overwrite each other's plot (the previous
            behaviour wrote every plot to the same ``trends_plot.png``).

    Returns:
        The path of the saved image.
    """
    if path is None:
        import tempfile

        # delete=False: the file must outlive this handle so Gradio can read it.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            path = tmp.name

    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        for column in dataframe.columns:
            # 'isPartial' is a flag column, not an interest series.
            if column != 'isPartial':
                ax.plot(dataframe.index, dataframe[column], label=column)
        # Only draw a legend when something was plotted (avoids a matplotlib
        # "No artists with labels" warning on empty data).
        if ax.get_lines():
            ax.legend()
        ax.set_title("Google Trends Over Time")
        ax.set_xlabel("Time")
        ax.set_ylabel("Interest")
        ax.grid(True)
        fig.tight_layout()
        fig.savefig(path)
    finally:
        # Always release the figure, even if plotting fails, so long-running
        # servers do not accumulate open figures.
        plt.close(fig)
    return path
def predict_apparel_and_attributes(image):
    """Classify an apparel image with CLIP and plot Google Trends for the result.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        A tuple ``(category, attributes, plot_path)``:
        - ``category``: the best-matching entry of ``categories``.
        - ``attributes``: the three best-matching ``attributes``, joined with ", ".
        - ``plot_path``: path to the trends plot, or ``None`` when Google Trends
          returned no data or the request failed.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image of a clothing item.")

    image_input = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        image_embedding = model.encode_image(image_input)

    # CLIP zero-shot scores are cosine similarities: normalise the image and
    # text embeddings so labels whose embeddings have a larger norm do not
    # win just because of their magnitude.
    image_embedding = image_embedding / image_embedding.norm(dim=-1, keepdim=True)
    category_feats = category_embeddings / category_embeddings.norm(dim=-1, keepdim=True)
    attribute_feats = attribute_embeddings / attribute_embeddings.norm(dim=-1, keepdim=True)

    category_similarities = (image_embedding @ category_feats.T).squeeze(0)
    attribute_similarities = (image_embedding @ attribute_feats.T).squeeze(0)

    # Top category and top-3 attributes (convert tensor indices to plain ints).
    top_category = categories[category_similarities.argmax().item()]
    top_indices = attribute_similarities.argsort(descending=True)[:3].tolist()
    top_attributes = [attributes[i] for i in top_indices]
    print(f"results:{top_category, ','.join(top_attributes)}")

    # Google Trends is an external, rate-limited service. A failure there must
    # not hide the classification result, so the plot is optional.
    plot_path = None
    try:
        pytrend = TrendReq()
        keywords = [top_category] + top_attributes
        pytrend.build_payload(kw_list=keywords, timeframe='now 1-H', geo='', gprop='')
        interest_over_time_df = pytrend.interest_over_time()
    except Exception as err:  # service boundary: log and continue without a plot
        print(f"Google Trends request failed: {err!r}")
    else:
        if not interest_over_time_df.empty:
            plot_path = plot_trends(interest_over_time_df)

    return top_category, ", ".join(top_attributes), plot_path
# Resolve the example images relative to this script instead of the current
# working directory, so the demo works wherever it is launched from. Fall back
# to the CWD when __file__ is undefined (e.g. when run inside a notebook).
if "__file__" in globals():
    _APP_DIR = os.path.dirname(os.path.abspath(__file__))
else:
    _APP_DIR = os.path.abspath("")

demo = gr.Interface(
    predict_apparel_and_attributes,
    gr.Image(type="pil"),
    outputs=[
        gr.Textbox(label="Apparel Category"),
        gr.Textbox(label="Apparel Attributes"),
        gr.Image(label="Google Trends Plot"),
    ],
    examples=[
        os.path.join(_APP_DIR, "images/jeans.jpeg"),
        os.path.join(_APP_DIR, "images/jacket.jpg"),
        os.path.join(_APP_DIR, "images/tshirt.png"),
    ],
)

if __name__ == "__main__":
    demo.launch(debug=True)