MODLI commited on
Commit
5f1bae3
·
verified ·
1 Parent(s): acd685d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -208
app.py CHANGED
@@ -1,18 +1,15 @@
1
  import os
2
- import json
3
- import time
4
  os.environ['HF_HOME'] = '/tmp/cache'
5
  os.environ['TORCH_HOME'] = '/tmp/cache'
6
 
7
- from fastapi import FastAPI, File, UploadFile, HTTPException
 
8
  from fastapi.middleware.cors import CORSMiddleware
9
- from fastapi.responses import HTMLResponse
10
  from PIL import Image
11
  import torch
12
- import io
13
- import colorthief
14
- import tempfile
15
- import numpy as np
16
 
17
  app = FastAPI(title="Fashion Classification API")
18
 
@@ -26,150 +23,84 @@ app.add_middleware(
26
  expose_headers=["*"]
27
  )
28
 
29
- # --- ÉTAT DU MODÈLE ---
30
- print("⚠️ Démarrage du chargement du modèle...")
31
  model = None
32
  processor = None
33
- model_loaded = False
34
- model_error = None
35
-
36
- # Modèles disponibles (garantis de fonctionner)
37
- AVAILABLE_MODELS = {
38
- "siglip-base": {
39
- "name": "google/siglip-base-patch16-224",
40
- "type": "siglip",
41
- "description": "SigLIP base - Excellente précision"
42
- },
43
- "clip-fashion": {
44
- "name": "patrickjohncyh/fashion-clip",
45
- "type": "clip",
46
- "description": "CLIP spécialisé mode"
47
- },
48
- "openclip": {
49
- "name": "laion/CLIP-ViT-B-32-laion2B-s34B-b79K",
50
- "type": "clip",
51
- "description": "OpenCLIP performant"
52
- }
53
- }
54
-
55
- SELECTED_MODEL = "siglip-base" # ← MODÈLE GARANTI
56
 
57
  def load_model():
58
- global model, processor, model_loaded, model_error
59
-
60
  try:
61
- from transformers import AutoModel, AutoProcessor, AutoTokenizer, CLIPModel, CLIPProcessor
62
-
63
- model_info = AVAILABLE_MODELS[SELECTED_MODEL]
64
- model_name = model_info["name"]
65
-
66
- print(f"📦 Chargement du modèle: {model_name}")
67
- print(f"📝 Description: {model_info['description']}")
68
-
69
- if model_info["type"] == "siglip":
70
- # Charger SigLIP
71
- model = AutoModel.from_pretrained(
72
- model_name,
73
- cache_dir="/tmp/cache",
74
- torch_dtype=torch.float16
75
- )
76
- processor = AutoProcessor.from_pretrained(model_name)
77
-
78
- else:
79
- # Charger CLIP
80
- model = CLIPModel.from_pretrained(
81
- model_name,
82
- cache_dir="/tmp/cache",
83
- torch_dtype=torch.float16
84
- )
85
- processor = CLIPProcessor.from_pretrained(model_name)
86
-
87
- print(f"✅ Modèle {model_name} chargé avec succès !")
88
- model_loaded = True
89
-
90
  except Exception as e:
91
- model_error = f"Erreur avec {SELECTED_MODEL}: {str(e)}"
92
- print(f"❌ {model_error}")
93
- # Essayer le modèle suivant en cas d'erreur
94
- try_next_model()
95
 
96
- def try_next_model():
97
- """Essaye le modèle suivant si le premier échoue"""
98
- global SELECTED_MODEL
99
- models = list(AVAILABLE_MODELS.keys())
100
- current_index = models.index(SELECTED_MODEL)
101
-
102
- if current_index < len(models) - 1:
103
- SELECTED_MODEL = models[current_index + 1]
104
- print(f"🔄 Essai du modèle suivant: {SELECTED_MODEL}")
105
- load_model()
106
- else:
107
- print("❌ Tous les modèles ont échoué")
108
-
109
- # Démarrer le chargement
110
  load_model()
111
 
112
- # Catégories de mode adaptées
113
- categories = [
114
- "a t-shirt", "a dress", "jeans", "a shirt", "a skirt",
115
- "sneakers", "a handbag", "a jacket", "shorts", "a sweater",
116
- "a coat", "high heels", "a blouse", "boots", "a hat"
117
- ]
 
 
 
 
 
 
 
 
118
 
119
  @app.get("/")
120
  def read_root():
121
- return {
122
- "message": "Fashion Classification API is running!",
123
- "status": "OK",
124
- "model_loaded": model_loaded,
125
- "current_model": SELECTED_MODEL,
126
- "model_name": AVAILABLE_MODELS[SELECTED_MODEL]["name"] if model_loaded else "loading"
127
- }
128
 
129
  @app.get("/health")
130
  def health_check():
131
  return {
132
- "model_loaded": model_loaded,
133
- "model_error": model_error,
134
- "current_model": SELECTED_MODEL,
135
- "model_details": AVAILABLE_MODELS[SELECTED_MODEL] if model_loaded else None,
136
- "available_models": list(AVAILABLE_MODELS.keys()),
137
- "status": "ready" if model_loaded else "error",
138
- "timestamp": time.time()
139
  }
140
 
141
- @app.post("/analyze")
142
- async def analyze_image(file: UploadFile = File(...)):
143
- if not model_loaded:
144
- raise HTTPException(status_code=423, detail="Model not loaded yet. Please check /health")
145
-
 
 
146
  try:
147
- # Lire et préparer l'image
148
- contents = await file.read()
149
- image = Image.open(io.BytesIO(contents)).convert("RGB")
150
- image = image.resize((224, 224)) # Taille standard
151
 
152
- # Traitement selon le type de modèle
153
- if SELECTED_MODEL == "siglip-base":
154
- # SigLIP processing
155
- inputs = processor(
156
- text=categories,
157
- images=image,
158
- return_tensors="pt",
159
- padding=True,
160
- truncation=True
161
- )
162
-
163
- with torch.no_grad():
164
- outputs = model(**inputs)
165
-
166
- logits_per_image = outputs.logits_per_image
167
- probs = torch.sigmoid(logits_per_image)
168
-
169
- else:
170
- # CLIP processing
 
 
171
  inputs = processor(
172
- text=categories,
173
  images=image,
174
  return_tensors="pt",
175
  padding=True,
@@ -178,86 +109,45 @@ async def analyze_image(file: UploadFile = File(...)):
178
 
179
  with torch.no_grad():
180
  outputs = model(**inputs)
181
-
182
- logits_per_image = outputs.logits_per_image
183
- probs = torch.softmax(logits_per_image, dim=1)
184
 
185
- probs = probs.cpu().numpy()[0]
186
- predicted_idx = np.argmax(probs)
187
- category_name = categories[predicted_idx]
188
- confidence_score = float(probs[predicted_idx])
189
 
190
- # Analyse couleur simplifiée
191
- try:
192
- image_rgb = image.convert('RGB')
193
- small_img = image_rgb.resize((10, 10))
194
- colors = small_img.getcolors(100)
195
- if colors:
196
- dominant_color = max(colors, key=lambda x: x[0])[1]
197
- hex_color = '#%02x%02x%02x' % dominant_color
198
- else:
199
- hex_color = "#000000"
200
- except Exception:
201
- hex_color = "#000000"
202
-
203
  return {
204
- "category": category_name,
205
- "confidence": round(confidence_score, 4),
206
- "color_hex": hex_color,
207
- "model": AVAILABLE_MODELS[SELECTED_MODEL]["name"]
 
 
208
  }
209
-
 
 
210
  except Exception as e:
211
- raise HTTPException(status_code=500, detail=f"Analysis error: {str(e)}")
212
 
213
- @app.get("/test-ui", response_class=HTMLResponse)
214
- async def test_ui():
215
- health_status = health_check()
216
- status_class = "ready" if health_status["model_loaded"] else "error"
217
- status_text = "✅ Prêt" if health_status["model_loaded"] else "❌ Erreur"
218
-
219
- return f"""
220
- <html>
221
- <head>
222
- <title>Fashion Detection</title>
223
- <style>
224
- body {{ font-family: Arial, sans-serif; margin: 40px; }}
225
- .container {{ max-width: 600px; margin: 0 auto; }}
226
- .status {{ padding: 15px; margin: 10px 0; border-radius: 5px; }}
227
- .ready {{ background: #d4edda; color: #155724; }}
228
- .error {{ background: #f8d7da; color: #721c24; }}
229
- .model-info {{ background: #e9ecef; padding: 10px; border-radius: 5px; }}
230
- </style>
231
- </head>
232
- <body>
233
- <div class="container">
234
- <h1>👗 Fashion Detector</h1>
235
-
236
- <div class="status {status_class}">
237
- <b>Statut:</b> {status_text}
238
- </div>
239
-
240
- <div class="model-info">
241
- <b>Modèle:</b> {health_status['current_model']}<br>
242
- <b>Détails:</b> {AVAILABLE_MODELS[health_status['current_model']]['description'] if health_status['model_loaded'] else 'Chargement...'}
243
- </div>
244
-
245
- <form action="/analyze" method="post" enctype="multipart/form-data">
246
- <h3>Uploader une image de vêtement :</h3>
247
- <input type="file" name="file" accept="image/*" required>
248
- <br><br>
249
- <input type="submit" value="Analyser" {"disabled" if not health_status["model_loaded"] else ""}>
250
- </form>
251
-
252
- <div style="margin-top: 20px;">
253
- <h4>Modèles disponibles:</h4>
254
- <ul>
255
- <li><b>siglip-base</b>: SigLIP base - Excellente précision</li>
256
- <li><b>clip-fashion</b>: CLIP spécialisé mode</li>
257
- <li><b>openclip</b>: OpenCLIP performant</li>
258
- </ul>
259
- </div>
260
- </div>
261
- </body>
262
- </html>
263
- """
 
1
  import os
 
 
2
  os.environ['HF_HOME'] = '/tmp/cache'
3
  os.environ['TORCH_HOME'] = '/tmp/cache'
4
 
5
+ import json
6
+ from fastapi import FastAPI, HTTPException
7
  from fastapi.middleware.cors import CORSMiddleware
 
8
  from PIL import Image
9
  import torch
10
+ import requests
11
+ from io import BytesIO
12
+ from transformers import CLIPProcessor, CLIPModel
 
13
 
14
  app = FastAPI(title="Fashion Classification API")
15
 
 
23
  expose_headers=["*"]
24
  )
25
 
26
+ # --- Configuration du modèle ---
27
+ print("🔄 Chargement du modèle Fashion CLIP...")
28
  model = None
29
  processor = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  def load_model():
32
+ global model, processor
 
33
  try:
34
+ model_name = "patrickjohncyh/fashion-clip" # Modèle spécialisé mode
35
+ model = CLIPModel.from_pretrained(model_name)
36
+ processor = CLIPProcessor.from_pretrained(model_name)
37
+ print("✅ Modèle chargé avec succès!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  except Exception as e:
39
+ print(f"Erreur de chargement: {e}")
 
 
 
40
 
41
+ # Charger le modèle au démarrage
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  load_model()
43
 
44
+ # Catégories en français avec mapping vers anglais
45
+ CATEGORIES_FR = {
46
+ "haut": ["a t-shirt", "a shirt", "a sweater", "a blouse", "a top"],
47
+ "pantalon": ["jeans", "pants", "trousers", "leggings"],
48
+ "robe": ["a dress", "a gown", "a sundress"],
49
+ "jupe": ["a skirt"],
50
+ "short": ["shorts", "bermuda shorts"],
51
+ "veste": ["a jacket", "a blazer", "a leather jacket"],
52
+ "manteau": ["a coat", "a winter coat", "a parka"],
53
+ "chaussures": ["sneakers", "high heels", "boots", "sandals"],
54
+ "sac": ["a handbag", "a purse", "a backpack"],
55
+ "accessoire": ["a hat", "sunglasses", "a scarf", "a belt"],
56
+ "autre": ["clothing", "fashion item"]
57
+ }
58
 
59
  @app.get("/")
60
  def read_root():
61
+ return {"message": "Fashion Classification API is running!", "status": "OK"}
 
 
 
 
 
 
62
 
63
  @app.get("/health")
64
  def health_check():
65
  return {
66
+ "model_loaded": model is not None,
67
+ "status": "ready" if model else "loading"
 
 
 
 
 
68
  }
69
 
70
+ # --- NOUVELLE ROUTE POUR LOVABLE ---
71
+ @app.post("/classify")
72
+ async def classify_fashion(image_data: dict):
73
+ """
74
+ Endpoint pour Lovable - accepte une URL d'image
75
+ Format attendu: {"imageUrl": "https://example.com/image.jpg"}
76
+ """
77
  try:
78
+ if not model or not processor:
79
+ raise HTTPException(status_code=503, detail="Model not loaded yet")
 
 
80
 
81
+ # Vérifier et extraire l'URL de l'image
82
+ image_url = image_data.get("imageUrl")
83
+ if not image_url:
84
+ raise HTTPException(status_code=400, detail="imageUrl is required")
85
+
86
+ # Télécharger l'image depuis l'URL
87
+ response = requests.get(image_url)
88
+ response.raise_for_status()
89
+
90
+ # Ouvrir et préparer l'image
91
+ image = Image.open(BytesIO(response.content)).convert("RGB")
92
+ image.thumbnail((512, 512)) # Réduire la taille pour plus d'efficacité
93
+
94
+ # Préparer toutes les catégories en anglais
95
+ all_english_categories = []
96
+ for fr_cat, en_categories in CATEGORIES_FR.items():
97
+ all_english_categories.extend(en_categories)
98
+
99
+ # Traitement par lots pour éviter les problèmes de padding
100
+ results = {}
101
+ for category in all_english_categories:
102
  inputs = processor(
103
+ text=[category],
104
  images=image,
105
  return_tensors="pt",
106
  padding=True,
 
109
 
110
  with torch.no_grad():
111
  outputs = model(**inputs)
112
+ results[category] = outputs.logits_per_image.item()
 
 
113
 
114
+ # Trouver la catégorie anglaise avec le meilleur score
115
+ best_english_category = max(results, key=results.get)
116
+ confidence = results[best_english_category]
 
117
 
118
+ # Convertir en catégorie française
119
+ best_french_category = "autre"
120
+ for fr_cat, en_categories in CATEGORIES_FR.items():
121
+ if best_english_category in en_categories:
122
+ best_french_category = fr_cat
123
+ break
124
+
125
+ # Normaliser la confiance entre 0 et 1
126
+ confidence_normalized = 1 / (1 + torch.exp(torch.tensor(-confidence))).item()
127
+
128
+ # Format de réponse exact pour Lovable
 
 
129
  return {
130
+ "success": True,
131
+ "category": best_french_category,
132
+ "confidence": round(confidence_normalized, 4),
133
+ "colorHex": "#000000",
134
+ "originalCategory": best_english_category,
135
+ "method": "modli-api"
136
  }
137
+
138
+ except requests.exceptions.RequestException as e:
139
+ raise HTTPException(status_code=400, detail=f"Invalid image URL: {str(e)}")
140
  except Exception as e:
141
+ raise HTTPException(status_code=500, detail=f"Classification error: {str(e)}")
142
 
143
+ # Ancienne route pour compatibilité (si nécessaire)
144
+ @app.post("/analyze")
145
+ async def analyze_image_old():
146
+ return {"error": "Use /classify endpoint instead"}
147
+
148
+ # Route de test
149
+ @app.get("/test")
150
+ async def test_endpoint():
151
+ """Endpoint de test avec une image exemple"""
152
+ test_url = "https://images.unsplash.com/photo-1521572163474-6864f9cf17ab?w=400"
153
+ return await classify_fashion({"imageUrl": test_url})