gkdivya commited on
Commit
51000bd
·
1 Parent(s): 02d8d2c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -7
app.py CHANGED
@@ -3,6 +3,8 @@ import torch
3
  import clip
4
  from PIL import Image
5
  import numpy as np
 
 
6
 
7
  # Load the CLIP model
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -17,9 +19,27 @@ with torch.no_grad():
17
  category_embeddings = model.encode_text(clip.tokenize(categories).to(device))
18
  attribute_embeddings = model.encode_text(clip.tokenize(attributes).to(device))
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  def predict_apparel_and_attributes(image):
21
- pil_image = Image.fromarray((image * 255).astype(np.uint8))
22
- image_input = preprocess(pil_image).unsqueeze(0).to(device)
23
 
24
  with torch.no_grad():
25
  image_embedding = model.encode_image(image_input)
@@ -32,10 +52,30 @@ def predict_apparel_and_attributes(image):
32
  top_category = categories[category_similarities.argmax().item()]
33
  top_attributes = [attributes[i] for i in attribute_similarities.argsort(descending=True)[:3]] # top 3 attributes
34
  print(f"results:{top_category, ','.join(top_attributes)}")
35
- return top_category, ", ".join(top_attributes)
36
 
37
- # Define Gradio interface
38
- demo = gr.Interface(fn=predict_apparel_and_attributes, inputs=gr.Image(label="Upload an apparel image"),outputs="text")
39
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  if __name__ == "__main__":
41
- demo.launch()
 
3
  import clip
4
  from PIL import Image
5
  import numpy as np
6
+ from pytrends.request import TrendReq
7
+ import matplotlib.pyplot as plt
8
 
9
  # Load the CLIP model
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
19
  category_embeddings = model.encode_text(clip.tokenize(categories).to(device))
20
  attribute_embeddings = model.encode_text(clip.tokenize(attributes).to(device))
21
 
22
+ def plot_trends(dataframe):
23
+ plt.figure(figsize=(12,6))
24
+ for column in dataframe.columns:
25
+ if column != 'isPartial':
26
+ plt.plot(dataframe.index, dataframe[column], label=column)
27
+ plt.legend()
28
+ plt.title("Google Trends Over Time")
29
+ plt.xlabel("Time")
30
+ plt.ylabel("Interest")
31
+ plt.grid(True)
32
+ plt.tight_layout()
33
+
34
+ # Save the plot to a temporary file and return its path
35
+ path = "trends_plot.png"
36
+ plt.savefig(path)
37
+ plt.close()
38
+ return path
39
+
40
  def predict_apparel_and_attributes(image):
41
+ #pil_image = Image.fromarray((image * 255).astype(np.uint8))
42
+ image_input = preprocess(image).unsqueeze(0).to(device)
43
 
44
  with torch.no_grad():
45
  image_embedding = model.encode_image(image_input)
 
52
  top_category = categories[category_similarities.argmax().item()]
53
  top_attributes = [attributes[i] for i in attribute_similarities.argsort(descending=True)[:3]] # top 3 attributes
54
  print(f"results:{top_category, ','.join(top_attributes)}")
 
55
 
56
+ # Fetch trends for the top apparel category and attributes
57
+ pytrend = TrendReq()
58
+ keywords = [top_category] + top_attributes
59
+ pytrend.build_payload(kw_list=keywords, timeframe='now 1-H', geo='', gprop='')
60
+ interest_over_time_df = pytrend.interest_over_time()
61
+
62
+ # Plot the trends and get the path to the saved plot
63
+ plot_path = plot_trends(interest_over_time_df)
64
+
65
+ #trends_text = interest_over_time_df.to_string()
66
+
67
+ return top_category, ", ".join(top_attributes), plot_path
68
+
69
+ demo = gr.Interface(
70
+ predict_apparel_and_attributes,
71
+ gr.Image(type="pil"),
72
+ outputs=[ gr.Textbox(label="Apparel Category"),
73
+ gr.Textbox(label="Apparel Attributes"),
74
+ gr.Image(label="Google Trends Plot")], # Output types
75
+ examples=[
76
+ os.path.join(os.path.abspath(''), "images/jeans.jpeg")
77
+ ],
78
+ )
79
+
80
  if __name__ == "__main__":
81
+ demo.launch(debug=True)