MODLI commited on
Commit
4c5319b
·
verified ·
1 Parent(s): f1f2131

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +148 -23
app.py CHANGED
@@ -2,34 +2,159 @@ import gradio as gr
2
  from transformers import pipeline
3
  from PIL import Image
4
  import numpy as np
 
 
5
 
6
- # Liste des catégories
7
- CATEGORIES = ["t-shirt", "dress", "jeans", "jacket", "skirt", "shoes"]
8
 
9
- # Charger le modèle (version simplifiée)
10
- model = pipeline("zero-shot-image-classification", model="openai/clip-vit-base-patch32")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- def classify_simple(image):
13
- """Version simplifiée sans segmentation"""
14
- if image is None:
15
- return "Please upload an image"
 
 
 
 
16
 
17
- results = model(image, candidate_labels=CATEGORIES)
 
 
 
18
 
19
- output = "Results:\n"
20
- for result in results:
21
- output += f"{result['label']}: {result['score']*100:.1f}%\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- return output
24
-
25
- # Interface minimaliste
26
- iface = gr.Interface(
27
- fn=classify_simple,
28
- inputs=gr.Image(type="pil"),
29
- outputs=gr.Textbox(),
30
- title="Simple Fashion Classifier",
31
- description="Upload a clothing item to classify it"
32
- )
 
 
 
 
 
 
 
 
 
 
 
 
33
 
 
34
  if __name__ == "__main__":
35
- iface.launch()
 
 
 
 
 
 
2
  from transformers import pipeline
3
  from PIL import Image
4
  import numpy as np
5
+ import cv2
6
+ import time
7
 
8
+ # Import des catégories
9
+ from categories import FASHION_CATEGORIES
10
 
11
+ # Initialisation des modèles
12
+ print("🔧 Loading models...")
13
+ try:
14
+ # Modèle de segmentation
15
+ seg_pipe = pipeline(
16
+ "image-segmentation",
17
+ model="mattmdjaga/segformer_b2_clothes",
18
+ device=-1 # Force l'utilisation du CPU pour plus de stabilité
19
+ )
20
+
21
+ # Modèle de classification
22
+ class_pipe = pipeline(
23
+ "zero-shot-image-classification",
24
+ model="openai/clip-vit-base-patch32",
25
+ device=-1 # Force l'utilisation du CPU
26
+ )
27
+ print("✅ Models loaded successfully!")
28
+ except Exception as e:
29
+ print(f"❌ Error loading models: {e}")
30
+ raise e
31
+
32
+ def process_image(input_image):
33
+ """Traite l'image et retourne les résultats"""
34
+ try:
35
+ if input_image is None:
36
+ return "⚠️ Please upload an image first", None, None
37
+
38
+ # Conversion en PIL Image
39
+ if isinstance(input_image, np.ndarray):
40
+ pil_image = Image.fromarray(input_image)
41
+ else:
42
+ pil_image = input_image
43
+
44
+ # Redimensionnement pour éviter les problèmes de mémoire
45
+ pil_image = pil_image.resize((224, 224))
46
+
47
+ # Étape 1: Segmentation
48
+ print("🔍 Segmenting image...")
49
+ segments = seg_pipe(pil_image)
50
+
51
+ if not segments:
52
+ return "❌ No clothing detected", None, None
53
+
54
+ # Trouver le plus grand segment
55
+ largest_segment = max(segments, key=lambda x: np.sum(x['mask']))
56
+ mask = largest_segment['mask']
57
+
58
+ # Étape 2: Extraction du vêtement
59
+ mask_np = np.array(mask).astype(np.uint8) * 255
60
+ masked_image = cv2.bitwise_and(np.array(pil_image), np.array(pil_image), mask=mask_np)
61
+ masked_pil = Image.fromarray(masked_image)
62
+
63
+ # Étape 3: Classification
64
+ print("📊 Classifying...")
65
+ predictions = class_pipe(
66
+ masked_pil,
67
+ candidate_labels=FASHION_CATEGORIES,
68
+ hypothesis_template="This is a photo of {}"
69
+ )
70
+
71
+ # Formatage des résultats
72
+ result_text = "🎯 Classification Results:\n\n"
73
+ for i, pred in enumerate(predictions[:3]):
74
+ result_text += f"{i+1}. {pred['label']}: {pred['score']*100:.1f}%\n"
75
+
76
+ return result_text, masked_pil, pil_image
77
+
78
+ except Exception as e:
79
+ return f"❌ Error: {str(e)}", None, None
80
 
81
+ # Interface Gradio améliorée
82
+ with gr.Blocks(
83
+ title="Fashion Classifier",
84
+ theme=gr.themes.Soft(),
85
+ css="""
86
+ .gradio-container {max-width: 900px !important;}
87
+ """
88
+ ) as demo:
89
 
90
+ gr.Markdown("""
91
+ # 👗 Fashion Category Classifier
92
+ Upload a picture of clothing. The AI will detect and classify it.
93
+ """)
94
 
95
+ with gr.Row():
96
+ with gr.Column():
97
+ image_input = gr.Image(
98
+ label="📤 Upload Image",
99
+ type="pil",
100
+ height=200
101
+ )
102
+
103
+ process_btn = gr.Button(
104
+ "🚀 Process Image",
105
+ variant="primary",
106
+ size="lg"
107
+ )
108
+
109
+ with gr.Column():
110
+ output_text = gr.Textbox(
111
+ label="📊 Results",
112
+ lines=5,
113
+ interactive=False
114
+ )
115
+
116
+ with gr.Row():
117
+ original_output = gr.Image(
118
+ label="Original",
119
+ type="pil",
120
+ height=200,
121
+ interactive=False
122
+ )
123
+ masked_output = gr.Image(
124
+ label="Detected Item",
125
+ type="pil",
126
+ height=200,
127
+ interactive=False
128
+ )
129
 
130
+ # Instructions
131
+ gr.Markdown("""
132
+ ### 📝 Instructions:
133
+ 1. Upload an image of clothing
134
+ 2. Click 'Process Image'
135
+ 3. See the classification results
136
+ """)
137
+
138
+ # Lier les événements
139
+ process_btn.click(
140
+ fn=process_image,
141
+ inputs=image_input,
142
+ outputs=[output_text, masked_output, original_output]
143
+ )
144
+
145
+ # Exemple de texte
146
+ gr.Markdown("""
147
+ ### 💡 Tips:
148
+ - Use clear, well-lit photos
149
+ - Focus on one clothing item at a time
150
+ - Avoid busy backgrounds for better results
151
+ """)
152
 
153
+ # Lancement de l'application
154
  if __name__ == "__main__":
155
+ demo.launch(
156
+ server_name="0.0.0.0",
157
+ server_port=7860,
158
+ share=False,
159
+ debug=True
160
+ )