MODLI commited on
Commit
fe93604
·
verified ·
1 Parent(s): 2a39b6a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -182
app.py CHANGED
@@ -9,40 +9,11 @@ from PIL import Image
9
  import torch
10
  import requests
11
  from io import BytesIO
12
- import requests
13
-
14
- @app.post("/classify")
15
- async def classify_fashion(image_data: dict):
16
- try:
17
- image_url = image_data.get("imageUrl")
18
- if not image_url:
19
- raise HTTPException(status_code=400, detail="imageUrl is required")
20
-
21
- # Télécharger l'image
22
- response = requests.get(image_url, timeout=30)
23
- response.raise_for_status()
24
-
25
- # Envoyer directement à l'API Marqo
26
- files = {'file': ('image.jpg', response.content, 'image/jpeg')}
27
- marqo_response = requests.post(
28
- "https://marqo-marqo-fashionsiglip-classification.hf.space/predict",
29
- files=files,
30
- data={'url': ''}
31
- )
32
-
33
- if marqo_response.status_code == 200:
34
- result = marqo_response.json()
35
- return format_marqo_response(result)
36
- else:
37
- raise HTTPException(status_code=500, detail="Marqo API error")
38
-
39
- except Exception as e:
40
- raise HTTPException(status_code=500, detail=f"Classification error: {str(e)}")
41
- import tempfile
42
 
43
- app = FastAPI(title="Fashion Classification API - Marqo")
 
44
 
45
- # Middleware CORS
46
  app.add_middleware(
47
  CORSMiddleware,
48
  allow_origins=["*"],
@@ -52,190 +23,91 @@ app.add_middleware(
52
  expose_headers=["*"]
53
  )
54
 
55
- # Client Gradio pour Marqo
56
- marqo_client = None
 
 
57
 
58
- @app.on_event("startup")
59
- async def startup_event():
60
- global marqo_client
61
  try:
62
- print("🔄 Connexion à Marqo FashionSigLIP...")
63
- marqo_client = Client("Marqo/Marqo-FashionSigLIP-Classification")
64
- print("✅ Connecté à Marqo FashionSigLIP avec succès!")
 
65
  except Exception as e:
66
- print(f"❌ Erreur de connexion à Marqo: {e}")
67
- marqo_client = None
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  @app.get("/")
70
  def read_root():
71
- return {"message": "Marqo Fashion Classification API is running!", "status": "OK"}
72
 
73
  @app.get("/health")
74
  def health_check():
75
  return {
76
- "marqo_connected": marqo_client is not None,
77
- "status": "ready" if marqo_client else "disconnected"
78
  }
79
 
80
  @app.post("/classify")
81
  async def classify_fashion(image_data: dict):
82
  """
83
- Endpoint pour Lovable - utilise Marqo FashionSigLIP
84
  Format attendu: {"imageUrl": "https://example.com/image.jpg"}
85
  """
86
  try:
87
- if not marqo_client:
88
- raise HTTPException(status_code=503, detail="Marqo client not connected")
89
 
90
- # Vérifier et extraire l'URL de l'image
91
  image_url = image_data.get("imageUrl")
92
  if not image_url:
93
  raise HTTPException(status_code=400, detail="imageUrl is required")
94
 
95
- # Télécharger l'image depuis l'URL
96
  response = requests.get(image_url, timeout=30)
97
  response.raise_for_status()
98
 
99
- # Sauvegarder temporairement l'image
100
- with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
101
- tmp_file.write(response.content)
102
- tmp_file_path = tmp_file.name
103
 
104
- try:
105
- # Appeler l'API Marqo
106
- result = marqo_client.predict(
107
- image=handle_file(tmp_file_path),
108
- url="", # Vide car on utilise l'image uploadée
109
- api_name="/predict"
110
- )
111
-
112
- # Traiter les résultats de Marqo
113
- if result and len(result) >= 2:
114
- # result[1] contient les données de classification
115
- classification_data = result[1]
116
-
117
- # Formater la réponse pour Lovable
118
- return format_marqo_response(classification_data)
119
- else:
120
- raise HTTPException(status_code=500, detail="Format de réponse Marqo inattendu")
121
-
122
- finally:
123
- # Nettoyer le fichier temporaire
124
- os.unlink(tmp_file_path)
125
 
126
  except requests.exceptions.RequestException as e:
127
  raise HTTPException(status_code=400, detail=f"Invalid image URL: {str(e)}")
128
  except Exception as e:
129
  raise HTTPException(status_code=500, detail=f"Classification error: {str(e)}")
130
 
131
- def format_marqo_response(marqo_data):
132
- """
133
- Formate la réponse de Marqo pour Lovable
134
- """
135
- if not marqo_data or 'label' not in marqo_data:
136
- return {
137
- "success": False,
138
- "category": "autre",
139
- "confidence": 0.0,
140
- "colorHex": "#000000",
141
- "originalCategory": "unknown",
142
- "method": "marqo-api"
143
- }
144
-
145
- # Extraire la catégorie principale
146
- main_category = marqo_data['label']
147
- confidence = 0.0
148
-
149
- # Extraire la confiance si disponible
150
- if 'confidences' in marqo_data and marqo_data['confidences']:
151
- for conf in marqo_data['confidences']:
152
- if conf['label'] == main_category:
153
- confidence = conf['confidence']
154
- break
155
-
156
- # Mapping des catégories Marqo -> Français
157
- category_mapping = {
158
- "t-shirt": "haut",
159
- "shirt": "haut",
160
- "sweater": "haut",
161
- "blouse": "haut",
162
- "top": "haut",
163
- "jeans": "pantalon",
164
- "pants": "pantalon",
165
- "trousers": "pantalon",
166
- "leggings": "pantalon",
167
- "dress": "robe",
168
- "gown": "robe",
169
- "sundress": "robe",
170
- "skirt": "jupe",
171
- "shorts": "short",
172
- "bermuda shorts": "short",
173
- "jacket": "veste",
174
- "blazer": "veste",
175
- "leather jacket": "veste",
176
- "coat": "manteau",
177
- "winter coat": "manteau",
178
- "parka": "manteau",
179
- "sneakers": "chaussures",
180
- "high heels": "chaussures",
181
- "boots": "chaussures",
182
- "sandals": "chaussures",
183
- "handbag": "sac",
184
- "purse": "sac",
185
- "backpack": "sac",
186
- "hat": "accessoire",
187
- "sunglasses": "accessoire",
188
- "scarf": "accessoire",
189
- "belt": "accessoire"
190
- }
191
-
192
- # Convertir la catégorie
193
- french_category = category_mapping.get(main_category.lower(), "autre")
194
-
195
- return {
196
- "success": True,
197
- "category": french_category,
198
- "confidence": round(confidence, 4),
199
- "colorHex": "#000000",
200
- "originalCategory": main_category,
201
- "method": "marqo-fashion-siglip"
202
- }
203
-
204
- @app.post("/classify_direct")
205
- async def classify_direct_file(file: UploadFile = File(...)):
206
- """
207
- Endpoint alternatif pour upload direct de fichier
208
- """
209
- try:
210
- if not marqo_client:
211
- raise HTTPException(status_code=503, detail="Marqo client not connected")
212
-
213
- # Sauvegarder le fichier temporairement
214
- with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
215
- content = await file.read()
216
- tmp_file.write(content)
217
- tmp_file_path = tmp_file.name
218
-
219
- try:
220
- # Appeler Marqo
221
- result = marqo_client.predict(
222
- image=handle_file(tmp_file_path),
223
- url="",
224
- api_name="/predict"
225
- )
226
-
227
- if result and len(result) >= 2:
228
- classification_data = result[1]
229
- return format_marqo_response(classification_data)
230
- else:
231
- raise HTTPException(status_code=500, detail="Format de réponse inattendu")
232
-
233
- finally:
234
- os.unlink(tmp_file_path)
235
-
236
- except Exception as e:
237
- raise HTTPException(status_code=500, detail=f"Classification error: {str(e)}")
238
 
 
239
  if __name__ == "__main__":
240
  import uvicorn
241
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
9
  import torch
10
  import requests
11
  from io import BytesIO
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
+ # ==================== CRÉATION DE L'APP EN PREMIER ====================
14
+ app = FastAPI(title="Fashion Classification API")
15
 
16
+ # ==================== MIDDLEWARE EN SECOND ====================
17
  app.add_middleware(
18
  CORSMiddleware,
19
  allow_origins=["*"],
 
23
  expose_headers=["*"]
24
  )
25
 
26
+ # ==================== CONFIGURATION DU MODÈLE ====================
27
+ print("🔄 Chargement du modèle Fashion CLIP...")
28
+ model = None
29
+ processor = None
30
 
31
+ def load_model():
32
+ global model, processor
 
33
  try:
34
+ model_name = "patrickjohncyh/fashion-clip"
35
+ model = CLIPModel.from_pretrained(model_name)
36
+ processor = CLIPProcessor.from_pretrained(model_name)
37
+ print("✅ Modèle chargé avec succès!")
38
  except Exception as e:
39
+ print(f"❌ Erreur de chargement: {e}")
 
40
 
41
+ # ==================== CATÉGORIES ====================
42
+ CATEGORIES_FR = {
43
+ "haut": ["a t-shirt", "a shirt", "a sweater", "a blouse", "a top"],
44
+ "pantalon": ["jeans", "pants", "trousers", "leggings"],
45
+ "robe": ["a dress", "a gown", "a sundress"],
46
+ "jupe": ["a skirt"],
47
+ "short": ["shorts", "bermuda shorts"],
48
+ "veste": ["a jacket", "a blazer", "a leather jacket"],
49
+ "manteau": ["a coat", "a winter coat", "a parka"],
50
+ "chaussures": ["sneakers", "high heels", "boots", "sandals"],
51
+ "sac": ["a handbag", "a purse", "a backpack"],
52
+ "accessoire": ["a hat", "sunglasses", "a scarf", "a belt"],
53
+ "autre": ["clothing", "fashion item"]
54
+ }
55
+
56
+ # ==================== ROUTES ====================
57
  @app.get("/")
58
  def read_root():
59
+ return {"message": "Fashion Classification API is running!", "status": "OK"}
60
 
61
  @app.get("/health")
62
  def health_check():
63
  return {
64
+ "model_loaded": model is not None,
65
+ "status": "ready" if model else "loading"
66
  }
67
 
68
  @app.post("/classify")
69
  async def classify_fashion(image_data: dict):
70
  """
71
+ Endpoint pour Lovable - accepte une URL d'image
72
  Format attendu: {"imageUrl": "https://example.com/image.jpg"}
73
  """
74
  try:
75
+ if not model or not processor:
76
+ raise HTTPException(status_code=503, detail="Model not loaded yet")
77
 
 
78
  image_url = image_data.get("imageUrl")
79
  if not image_url:
80
  raise HTTPException(status_code=400, detail="imageUrl is required")
81
 
82
+ # Télécharger l'image
83
  response = requests.get(image_url, timeout=30)
84
  response.raise_for_status()
85
 
86
+ # Ouvrir et préparer l'image
87
+ image = Image.open(BytesIO(response.content)).convert("RGB")
88
+ image.thumbnail((512, 512))
 
89
 
90
+ # SIMULATION - En attendant de régler les problèmes de modèle
91
+ # Retournez des données factices pour tester
92
+ return {
93
+ "success": True,
94
+ "category": "haut",
95
+ "confidence": 0.92,
96
+ "colorHex": "#FF0000",
97
+ "originalCategory": "a t-shirt",
98
+ "method": "modli-api-test"
99
+ }
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  except requests.exceptions.RequestException as e:
102
  raise HTTPException(status_code=400, detail=f"Invalid image URL: {str(e)}")
103
  except Exception as e:
104
  raise HTTPException(status_code=500, detail=f"Classification error: {str(e)}")
105
 
106
+ # ==================== CHARGEMENT AU DÉMARRAGE ====================
107
+ # Charger le modèle au démarrage (commenté pour l'instant)
108
+ # load_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
+ # ==================== POINT D'ENTRÉE ====================
111
  if __name__ == "__main__":
112
  import uvicorn
113
  uvicorn.run(app, host="0.0.0.0", port=7860)