MODLI commited on
Commit
b13493b
·
verified ·
1 Parent(s): 4e1c2b6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +220 -51
app.py CHANGED
@@ -4,84 +4,253 @@ os.environ['HF_HOME'] = '/tmp/cache'
4
  os.environ['TORCH_HOME'] = '/tmp/cache'
5
 
6
  from fastapi import FastAPI, File, UploadFile
 
7
  from fastapi.responses import HTMLResponse
8
  from PIL import Image
9
  import torch
10
  import io
 
 
 
11
 
12
- app = FastAPI(title="Fashion Detection API")
13
 
14
- # Modèle et processor (chargement différé)
 
 
 
 
 
 
 
 
 
 
 
15
  model = None
16
  processor = None
17
 
18
- def load_model():
19
  global model, processor
20
  try:
21
- from transformers import CLIPModel, CLIPProcessor
22
- model_name = "Marqo/marqo-fashionCLIP"
23
- model = CLIPModel.from_pretrained(model_name)
24
- processor = CLIPProcessor.from_pretrained(model_name)
25
- print("✅ Modèle chargé!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  except Exception as e:
27
- print(f"❌ Erreur: {e}")
 
 
 
 
 
 
 
 
 
28
 
29
  @app.on_event("startup")
30
- async def startup():
31
  import threading
32
- threading.Thread(target=load_model).start()
33
-
34
- categories = ["t-shirt", "dress", "jeans", "shirt", "skirt", "jacket", "sweater"]
35
 
36
  @app.get("/")
37
- def home():
38
- return {"message": "API running", "status": "OK"}
 
 
 
 
 
 
 
 
 
39
 
40
  @app.post("/analyze")
41
- async def analyze(file: UploadFile = File(...)):
42
- if not model or not processor:
43
- return {"error": "Model loading..."}
44
 
45
  try:
46
- # Lire image
47
- image = Image.open(io.BytesIO(await file.read())).convert("RGB")
48
- image.thumbnail((256, 256))
 
 
 
49
 
50
- # Méthode SIMPLE et FIABLE
51
- results = {}
52
- for category in categories:
53
- inputs = processor(
54
- text=[category],
55
- images=image,
56
- return_tensors="pt",
57
- padding=True,
58
- truncation=True
59
- )
60
- with torch.no_grad():
61
- outputs = model(**inputs)
62
- results[category] = outputs.logits_per_image.item()
63
 
64
- # Trouver le meilleur résultat
65
- best_category = max(results, key=results.get)
66
- confidence = results[best_category]
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  return {
69
- "category": best_category,
70
- "confidence": round(confidence, 4),
71
- "color_hex": "#000000" # Couleur basique pour l'instant
 
 
 
 
72
  }
73
-
74
  except Exception as e:
75
- return {"error": str(e)}
76
 
77
- @app.get("/ui", response_class=HTMLResponse)
78
- def ui():
 
79
  return """
80
- <html><body>
81
- <h1>Fashion Detector</h1>
82
- <form action="/analyze" method="post" enctype="multipart/form-data">
83
- <input type="file" name="file" accept="image/*">
84
- <input type="submit" value="Analyze">
85
- </form>
86
- </body></html>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  """
 
4
  os.environ['TORCH_HOME'] = '/tmp/cache'
5
 
6
  from fastapi import FastAPI, File, UploadFile
7
+ from fastapi.middleware.cors import CORSMiddleware
8
  from fastapi.responses import HTMLResponse
9
  from PIL import Image
10
  import torch
11
  import io
12
+ import colorthief
13
+ import tempfile
14
+ import numpy as np
15
 
16
+ app = FastAPI(title="Fashion Classification API")
17
 
18
+ # Middleware CORS
19
+ app.add_middleware(
20
+ CORSMiddleware,
21
+ allow_origins=["*"],
22
+ allow_credentials=True,
23
+ allow_methods=["*"],
24
+ allow_headers=["*"],
25
+ expose_headers=["*"]
26
+ )
27
+
28
+ # --- CHARGE LE MODÈLE MARQO FASHIONSIGLIP ---
29
+ print("⚠️ Démarrage du chargement du modèle Marqo-FashionSigLIP...")
30
  model = None
31
  processor = None
32
 
33
+ def load_fashion_model():
34
  global model, processor
35
  try:
36
+ from transformers import AutoModel, AutoProcessor
37
+
38
+ model_name = "Marqo/Marqo-FashionSigLIP-Classification"
39
+
40
+ # Charger le modèle SigLIP spécialisé fashion
41
+ model = AutoModel.from_pretrained(
42
+ model_name,
43
+ cache_dir="/tmp/cache",
44
+ torch_dtype=torch.float16, # Moins de mémoire
45
+ trust_remote_code=True
46
+ )
47
+
48
+ processor = AutoProcessor.from_pretrained(
49
+ model_name,
50
+ trust_remote_code=True
51
+ )
52
+
53
+ print("✅ Modèle Marqo-FashionSigLIP chargé avec succès !")
54
+ print(f"📍 Modèle device: {next(model.parameters()).device}")
55
+
56
  except Exception as e:
57
+ print(f"❌ Erreur chargement modèle: {e}")
58
+ import traceback
59
+ traceback.print_exc()
60
+
61
+ # Catégories de mode pour SigLIP (adaptées au modèle)
62
+ categories = [
63
+ "t-shirt", "dress", "jeans", "shirt", "skirt",
64
+ "sneakers", "handbag", "jacket", "shorts", "sweater",
65
+ "coat", "high heels", "blouse", "boots", "hat"
66
+ ]
67
 
68
  @app.on_event("startup")
69
+ async def startup_event():
70
  import threading
71
+ thread = threading.Thread(target=load_fashion_model)
72
+ thread.daemon = True
73
+ thread.start()
74
 
75
  @app.get("/")
76
+ def read_root():
77
+ return {"message": "Fashion Classification API is running!", "status": "OK"}
78
+
79
+ @app.get("/health")
80
+ def health_check():
81
+ return {
82
+ "model_loaded": model is not None,
83
+ "processor_loaded": processor is not None,
84
+ "status": "ready" if model and processor else "loading",
85
+ "model_name": "Marqo-FashionSigLIP-Classification"
86
+ }
87
 
88
  @app.post("/analyze")
89
+ async def analyze_image(file: UploadFile = File(...)):
90
+ if model is None or processor is None:
91
+ return {"error": "Model not loaded yet. Please check /health endpoint."}
92
 
93
  try:
94
+ # Lire et préparer l'image
95
+ contents = await file.read()
96
+ image = Image.open(io.BytesIO(contents)).convert("RGB")
97
+
98
+ # Redimensionner pour SigLIP
99
+ image = image.resize((384, 384))
100
 
101
+ # --- TRAITEMENT AVEC SIGLIP ---
102
+ # Préparer les inputs
103
+ inputs = processor(
104
+ text=categories,
105
+ images=image,
106
+ return_tensors="pt",
107
+ padding=True,
108
+ truncation=True,
109
+ max_length=64,
110
+ return_overflowing_tokens=False
111
+ )
 
 
112
 
113
+ # Déplacer sur le device du modèle
114
+ device = next(model.parameters()).device
115
+ inputs = {k: v.to(device) for k, v in inputs.items()}
116
 
117
+ # Inférence
118
+ with torch.no_grad():
119
+ outputs = model(**inputs)
120
+
121
+ # SigLIP utilise des logits différents
122
+ logits_per_image = outputs.logits_per_image
123
+
124
+ # Convertir en probabilités
125
+ probs = torch.sigmoid(logits_per_image) # SigLIP utilise sigmoid, pas softmax!
126
+ probs = probs.cpu().numpy()[0]
127
+
128
+ # Trouver la meilleure catégorie
129
+ predicted_idx = np.argmax(probs)
130
+ category_name = categories[predicted_idx]
131
+ confidence_score = float(probs[predicted_idx])
132
+
133
+ # --- ANALYSE COULEUR ---
134
+ try:
135
+ with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp:
136
+ image.save(tmp, format='JPEG')
137
+ tmp_path = tmp.name
138
+
139
+ color_thief = colorthief.ColorThief(tmp_path)
140
+ dominant_color = color_thief.get_color(quality=1)
141
+ hex_color = '#%02x%02x%02x' % dominant_color
142
+
143
+ os.unlink(tmp_path)
144
+
145
+ except Exception as color_error:
146
+ print(f"⚠️ Erreur analyse couleur: {color_error}")
147
+ hex_color = "#000000"
148
+
149
+ # --- RÉSULTATS DÉTAILLÉS ---
150
+ top_categories = []
151
+ for i, (cat, prob) in enumerate(zip(categories, probs)):
152
+ if prob > 0.1: # Seuil minimal
153
+ top_categories.append({
154
+ "category": cat,
155
+ "score": round(float(prob), 4)
156
+ })
157
+
158
+ # Trier par score décroissant
159
+ top_categories.sort(key=lambda x: x["score"], reverse=True)
160
+ top_5 = top_categories[:5]
161
+
162
  return {
163
+ "top_prediction": {
164
+ "category": category_name,
165
+ "confidence": round(confidence_score, 4),
166
+ "color_hex": hex_color
167
+ },
168
+ "top_categories": top_5,
169
+ "model": "Marqo-FashionSigLIP-Classification"
170
  }
171
+
172
  except Exception as e:
173
+ return {"error": f"Erreur lors de l'analyse: {str(e)}"}
174
 
175
+ # Interface de test
176
+ @app.get("/test-ui", response_class=HTMLResponse)
177
+ async def test_ui():
178
  return """
179
+ <html>
180
+ <head>
181
+ <title>FashionSigLIP Detection</title>
182
+ <style>
183
+ body {
184
+ font-family: Arial, sans-serif;
185
+ margin: 40px;
186
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
187
+ color: white;
188
+ }
189
+ .container {
190
+ max-width: 600px;
191
+ margin: 0 auto;
192
+ background: rgba(255, 255, 255, 0.1);
193
+ padding: 30px;
194
+ border-radius: 15px;
195
+ backdrop-filter: blur(10px);
196
+ }
197
+ form {
198
+ border: 2px dashed rgba(255, 255, 255, 0.3);
199
+ padding: 30px;
200
+ text-align: center;
201
+ margin-bottom: 20px;
202
+ }
203
+ input[type="file"] {
204
+ margin: 15px 0;
205
+ padding: 10px;
206
+ background: rgba(255, 255, 255, 0.2);
207
+ border: none;
208
+ border-radius: 5px;
209
+ color: white;
210
+ }
211
+ input[type="submit"] {
212
+ background: #ff6b6b;
213
+ color: white;
214
+ padding: 12px 25px;
215
+ border: none;
216
+ cursor: pointer;
217
+ border-radius: 25px;
218
+ font-weight: bold;
219
+ transition: background 0.3s;
220
+ }
221
+ input[type="submit"]:hover {
222
+ background: #ee5a52;
223
+ }
224
+ .result {
225
+ margin-top: 20px;
226
+ padding: 20px;
227
+ background: rgba(255, 255, 255, 0.1);
228
+ border-radius: 10px;
229
+ }
230
+ </style>
231
+ </head>
232
+ <body>
233
+ <div class="container">
234
+ <h1>👗 FashionSigLIP Detector</h1>
235
+ <p>Powered by Marqo/Marqo-FashionSigLIP-Classification</p>
236
+
237
+ <form action="/analyze" method="post" enctype="multipart/form-data">
238
+ <h3>Uploader une image de vêtement :</h3>
239
+ <input type="file" name="file" accept="image/*" required>
240
+ <br>
241
+ <input type="submit" value="Analyser la mode ���">
242
+ </form>
243
+
244
+ <div class="result">
245
+ <h3>📊 Résultats :</h3>
246
+ <p>Les résultats apparaîtront ici après analyse...</p>
247
+ </div>
248
+
249
+ <div style="margin-top: 20px; font-size: 12px; opacity: 0.7;">
250
+ <p>Modèle : Marqo-FashionSigLIP-Classification</p>
251
+ <p>Spécialisé dans la classification de vêtements</p>
252
+ </div>
253
+ </div>
254
+ </body>
255
+ </html>
256
  """