gkdivya commited on
Commit
19150e5
·
1 Parent(s): 776248a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -0
app.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import clip
4
+ from PIL import Image
5
+
6
+ # Load the CLIP model
7
+ device = "cuda" if torch.cuda.is_available() else "cpu"
8
+ model, preprocess = clip.load("ViT-B/32", device)
9
+
10
+ # Define apparel categories and attributes
11
+ categories = ["t-shirt", "jeans", "jacket", "dress", "shorts", "sweater", "skirt"]
12
+ attributes = ["striped", "plain", "floral", "polka dot", "denim", "leather", "wool"]
13
+
14
+ # Pre-compute embeddings for categories and attributes
15
+ with torch.no_grad():
16
+ category_embeddings = model.encode_text(clip.tokenize(categories).to(device))
17
+ attribute_embeddings = model.encode_text(clip.tokenize(attributes).to(device))
18
+
19
+ def predict_apparel_and_attributes(image):
20
+ # Process image and compute its embedding
21
+ image_input = preprocess(image).unsqueeze(0).to(device)
22
+ with torch.no_grad():
23
+ image_embedding = model.encode_image(image_input)
24
+
25
+ # Calculate similarity scores
26
+ category_similarities = (image_embedding @ category_embeddings.T).squeeze(0)
27
+ attribute_similarities = (image_embedding @ attribute_embeddings.T).squeeze(0)
28
+
29
+ # Get top category and attributes
30
+ top_category = categories[category_similarities.argmax().item()]
31
+ top_attributes = [attributes[i] for i in attribute_similarities.argsort(descending=True)[:3]] # top 3 attributes
32
+
33
+ return top_category, ", ".join(top_attributes)
34
+
35
+ # Define Gradio interface
36
+ iface = gr.Interface(
37
+ fn=predict_apparel_and_attributes,
38
+ inputs=gr.inputs.Image(label="Upload an apparel image"),
39
+ outputs=[gr.outputs.Textbox(label="Apparel Category"), gr.outputs.Textbox(label="Apparel Attributes")]
40
+ )
41
+ iface.launch()