MODLI commited on
Commit
fd7501f
·
verified ·
1 Parent(s): 025aa12

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -103
app.py CHANGED
@@ -2,106 +2,73 @@ import gradio as gr
2
  from transformers import pipeline
3
  from PIL import Image
4
  import numpy as np
5
- import cv2
6
- import time
7
 
8
- # Import des catégories
9
- from categories import FASHION_CATEGORIES
 
 
 
 
 
 
 
 
10
 
11
- # Initialisation des modèles
12
- print("🔧 Loading models...")
13
- try:
14
- # Modèle de segmentation
15
- seg_pipe = pipeline(
16
- "image-segmentation",
17
- model="mattmdjaga/segformer_b2_clothes",
18
- device=-1 # Force l'utilisation du CPU pour plus de stabilité
19
- )
20
-
21
- # Modèle de classification
22
- class_pipe = pipeline(
23
- "zero-shot-image-classification",
24
- model="openai/clip-vit-base-patch32",
25
- device=-1 # Force l'utilisation du CPU
26
- )
27
- print("✅ Models loaded successfully!")
28
- except Exception as e:
29
- print(f"❌ Error loading models: {e}")
30
- raise e
31
 
32
- def process_image(input_image):
33
- """Traite l'image et retourne les résultats"""
34
  try:
35
  if input_image is None:
36
- return "⚠️ Please upload an image first", None, None
37
 
38
- # Conversion en PIL Image
39
  if isinstance(input_image, np.ndarray):
40
- pil_image = Image.fromarray(input_image)
41
- else:
42
- pil_image = input_image
43
-
44
- # Redimensionnement pour éviter les problèmes de mémoire
45
- pil_image = pil_image.resize((224, 224))
46
-
47
- # Étape 1: Segmentation
48
- print("🔍 Segmenting image...")
49
- segments = seg_pipe(pil_image)
50
-
51
- if not segments:
52
- return "❌ No clothing detected", None, None
53
-
54
- # Trouver le plus grand segment
55
- largest_segment = max(segments, key=lambda x: np.sum(x['mask']))
56
- mask = largest_segment['mask']
57
 
58
- # Étape 2: Extraction du vêtement
59
- mask_np = np.array(mask).astype(np.uint8) * 255
60
- masked_image = cv2.bitwise_and(np.array(pil_image), np.array(pil_image), mask=mask_np)
61
- masked_pil = Image.fromarray(masked_image)
62
 
63
- # Étape 3: Classification
64
- print("📊 Classifying...")
65
  predictions = class_pipe(
66
- masked_pil,
67
  candidate_labels=FASHION_CATEGORIES,
68
  hypothesis_template="This is a photo of {}"
69
  )
70
 
71
- # Formatage des résultats
72
- result_text = "🎯 Classification Results:\n\n"
73
- for i, pred in enumerate(predictions[:3]):
74
  result_text += f"{i+1}. {pred['label']}: {pred['score']*100:.1f}%\n"
75
 
76
- return result_text, masked_pil, pil_image
77
 
78
  except Exception as e:
79
- return f"Error: {str(e)}", None, None
80
 
81
- # Interface Gradio améliorée
82
- with gr.Blocks(
83
- title="Fashion Classifier",
84
- theme=gr.themes.Soft(),
85
- css="""
86
- .gradio-container {max-width: 900px !important;}
87
- """
88
- ) as demo:
89
-
90
  gr.Markdown("""
91
  # 👗 Fashion Category Classifier
92
- Upload a picture of clothing. The AI will detect and classify it.
93
  """)
94
 
95
  with gr.Row():
96
  with gr.Column():
97
  image_input = gr.Image(
98
- label="📤 Upload Image",
99
  type="pil",
100
- height=200
101
  )
102
 
103
- process_btn = gr.Button(
104
- "🚀 Process Image",
105
  variant="primary",
106
  size="lg"
107
  )
@@ -109,52 +76,36 @@ with gr.Blocks(
109
  with gr.Column():
110
  output_text = gr.Textbox(
111
  label="📊 Results",
112
- lines=5,
113
  interactive=False
114
  )
115
-
116
- with gr.Row():
117
- original_output = gr.Image(
118
- label="Original",
119
- type="pil",
120
- height=200,
121
- interactive=False
122
- )
123
- masked_output = gr.Image(
124
- label="Detected Item",
125
- type="pil",
126
- height=200,
127
- interactive=False
128
- )
129
 
130
  # Instructions
131
  gr.Markdown("""
132
- ### 📝 Instructions:
133
- 1. Upload an image of clothing
134
- 2. Click 'Process Image'
135
  3. See the classification results
136
  """)
137
 
138
- # Lier les événements
139
- process_btn.click(
140
- fn=process_image,
141
  inputs=image_input,
142
- outputs=[output_text, masked_output, original_output]
143
  )
144
 
145
- # Exemple de texte
146
- gr.Markdown("""
147
- ### 💡 Tips:
148
- - Use clear, well-lit photos
149
- - Focus on one clothing item at a time
150
- - Avoid busy backgrounds for better results
151
- """)
152
 
153
- # Lancement de l'application
154
  if __name__ == "__main__":
155
  demo.launch(
156
  server_name="0.0.0.0",
157
  server_port=7860,
158
- share=False,
159
- debug=True
160
  )
 
2
  from transformers import pipeline
3
  from PIL import Image
4
  import numpy as np
 
 
5
 
6
+ # Liste des catégories de vêtements
7
+ FASHION_CATEGORIES = [
8
+ "t-shirt", "long sleeve shirt", "short sleeve shirt",
9
+ "sleeveless shirt", "polo shirt", "sweatshirt",
10
+ "hoodie", "sweater", "cardigan", "jacket", "coat",
11
+ "blazer", "dress", "long dress", "short dress",
12
+ "skirt", "long skirt", "short skirt", "jeans",
13
+ "pants", "trousers", "shorts", "leggings",
14
+ "sports shoes", "sneakers", "boots", "heels", "sandals"
15
+ ]
16
 
17
+ # Charger le modèle de classification
18
+ print("Loading classification model...")
19
+ class_pipe = pipeline(
20
+ "zero-shot-image-classification",
21
+ model="openai/clip-vit-base-patch32"
22
+ )
23
+ print("Model loaded successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
+ def classify_image(input_image):
26
+ """Fonction simple pour classifier les images"""
27
  try:
28
  if input_image is None:
29
+ return "Please upload an image first"
30
 
31
+ # Convertir en format PIL si nécessaire
32
  if isinstance(input_image, np.ndarray):
33
+ input_image = Image.fromarray(input_image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
+ # Redimensionner pour de meilleures performances
36
+ input_image = input_image.resize((224, 224))
 
 
37
 
38
+ # Classification
 
39
  predictions = class_pipe(
40
+ input_image,
41
  candidate_labels=FASHION_CATEGORIES,
42
  hypothesis_template="This is a photo of {}"
43
  )
44
 
45
+ # Formater les résultats
46
+ result_text = "👗 Classification Results:\n\n"
47
+ for i, pred in enumerate(predictions[:5]):
48
  result_text += f"{i+1}. {pred['label']}: {pred['score']*100:.1f}%\n"
49
 
50
+ return result_text
51
 
52
  except Exception as e:
53
+ return f"Error: {str(e)}"
54
 
55
+ # Interface Gradio simple
56
+ with gr.Blocks(title="Fashion Classifier", theme=gr.themes.Soft()) as demo:
 
 
 
 
 
 
 
57
  gr.Markdown("""
58
  # 👗 Fashion Category Classifier
59
+ Upload a picture of clothing to classify it.
60
  """)
61
 
62
  with gr.Row():
63
  with gr.Column():
64
  image_input = gr.Image(
65
+ label="📤 Upload Clothing Image",
66
  type="pil",
67
+ height=300
68
  )
69
 
70
+ classify_btn = gr.Button(
71
+ "🔍 Classify Image",
72
  variant="primary",
73
  size="lg"
74
  )
 
76
  with gr.Column():
77
  output_text = gr.Textbox(
78
  label="📊 Results",
79
+ lines=8,
80
  interactive=False
81
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  # Instructions
84
  gr.Markdown("""
85
+ ### 📝 How to use:
86
+ 1. Upload an image of a clothing item
87
+ 2. Click the 'Classify Image' button
88
  3. See the classification results
89
  """)
90
 
91
+ # Lier le bouton à la fonction
92
+ classify_btn.click(
93
+ fn=classify_image,
94
  inputs=image_input,
95
+ outputs=output_text
96
  )
97
 
98
+ # Ajouter aussi le changement sur l'upload
99
+ image_input.upload(
100
+ fn=classify_image,
101
+ inputs=image_input,
102
+ outputs=output_text
103
+ )
 
104
 
105
+ # Lancer l'application
106
  if __name__ == "__main__":
107
  demo.launch(
108
  server_name="0.0.0.0",
109
  server_port=7860,
110
+ share=False
 
111
  )