MODLI commited on
Commit
e3527ed
·
verified ·
1 Parent(s): 087c8f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -21
app.py CHANGED
@@ -5,11 +5,12 @@ os.environ['TORCH_HOME'] = '/tmp/cache'
5
 
6
  from fastapi import FastAPI, File, UploadFile
7
  from fastapi.middleware.cors import CORSMiddleware
8
- from fastapi.responses import HTMLResponse # ← AJOUT IMPORT MANQUANT
9
  from PIL import Image
10
  import torch
11
  import io
12
  import colorthief
 
13
 
14
  app = FastAPI(title="Fashion Detection API")
15
 
@@ -38,24 +39,21 @@ def load_marqo_model():
38
  model = CLIPModel.from_pretrained(
39
  model_name,
40
  cache_dir="/tmp/cache",
41
- torch_dtype=torch.float16 # Réduit la mémoire
42
  )
43
  processor = CLIPProcessor.from_pretrained(model_name)
44
  print("✅ Modèle Marqo fashionCLIP chargé avec succès !")
45
  except Exception as e:
46
  print(f"❌ Erreur chargement modèle Marqo: {e}")
47
- print("Assurez-vous que les versions dans requirements.txt sont compatibles")
48
 
49
- # Charge le modèle au démarrage (mais en différé)
50
  @app.on_event("startup")
51
  async def startup_event():
52
  import threading
53
- # Charge le modèle dans un thread séparé pour ne pas bloquer le démarrage
54
  thread = threading.Thread(target=load_marqo_model)
55
  thread.daemon = True
56
  thread.start()
57
 
58
- # Catégories fashion simplifiées pour tests
59
  categories = [
60
  "a t-shirt", "a dress", "jeans", "a shirt", "a skirt", "sneakers",
61
  "a handbag", "a jacket", "shorts", "a sweater", "a coat", "high heels"
@@ -75,7 +73,6 @@ def health_check():
75
 
76
  @app.post("/analyze")
77
  async def analyze_image(file: UploadFile = File(...)):
78
- # Vérifier que le modèle est chargé
79
  if model is None or processor is None:
80
  return {"error": "Model not loaded yet. Please wait or check /health endpoint."}
81
 
@@ -84,7 +81,7 @@ async def analyze_image(file: UploadFile = File(...)):
84
  contents = await file.read()
85
  image = Image.open(io.BytesIO(contents)).convert("RGB")
86
 
87
- # Réduire la taille pour économiser la mémoire
88
  image.thumbnail((384, 384))
89
 
90
  # Analyse avec Marqo fashionCLIP
@@ -96,7 +93,6 @@ async def analyze_image(file: UploadFile = File(...)):
96
  truncation=True
97
  )
98
 
99
- # Utiliser le CPU (plus stable sur Hugging Face Spaces free)
100
  with torch.no_grad():
101
  outputs = model(**inputs)
102
 
@@ -107,13 +103,24 @@ async def analyze_image(file: UploadFile = File(...)):
107
  category_name = categories[predicted_class_idx]
108
  confidence_score = probs[0][predicted_class_idx].item()
109
 
110
- # Analyse couleur
111
- img_buffer = io.BytesIO()
112
- image.save(img_buffer, format="PNG")
113
- img_buffer.seek(0)
114
- color_thief = colorthief.ColorThief(img_buffer)
115
- dominant_color = color_thief.get_color(quality=1)
116
- hex_color = '#%02x%02x%02x' % dominant_color
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  return {
119
  "category": category_name,
@@ -129,12 +136,33 @@ async def analyze_image(file: UploadFile = File(...)):
129
  async def test_ui():
130
  return """
131
  <html>
 
 
 
 
 
 
 
 
 
 
132
  <body>
133
- <h1>Test Fashion Detection</h1>
134
- <form action="/analyze" method="post" enctype="multipart/form-data">
135
- <input type="file" name="file">
136
- <input type="submit" value="Analyzer">
137
- </form>
 
 
 
 
 
 
 
 
 
 
 
138
  </body>
139
  </html>
140
  """
 
5
 
6
  from fastapi import FastAPI, File, UploadFile
7
  from fastapi.middleware.cors import CORSMiddleware
8
+ from fastapi.responses import HTMLResponse
9
  from PIL import Image
10
  import torch
11
  import io
12
  import colorthief
13
+ import tempfile
14
 
15
  app = FastAPI(title="Fashion Detection API")
16
 
 
39
  model = CLIPModel.from_pretrained(
40
  model_name,
41
  cache_dir="/tmp/cache",
42
+ torch_dtype=torch.float16
43
  )
44
  processor = CLIPProcessor.from_pretrained(model_name)
45
  print("✅ Modèle Marqo fashionCLIP chargé avec succès !")
46
  except Exception as e:
47
  print(f"❌ Erreur chargement modèle Marqo: {e}")
 
48
 
 
49
  @app.on_event("startup")
50
  async def startup_event():
51
  import threading
 
52
  thread = threading.Thread(target=load_marqo_model)
53
  thread.daemon = True
54
  thread.start()
55
 
56
+ # Catégories fashion simplifiées
57
  categories = [
58
  "a t-shirt", "a dress", "jeans", "a shirt", "a skirt", "sneakers",
59
  "a handbag", "a jacket", "shorts", "a sweater", "a coat", "high heels"
 
73
 
74
  @app.post("/analyze")
75
  async def analyze_image(file: UploadFile = File(...)):
 
76
  if model is None or processor is None:
77
  return {"error": "Model not loaded yet. Please wait or check /health endpoint."}
78
 
 
81
  contents = await file.read()
82
  image = Image.open(io.BytesIO(contents)).convert("RGB")
83
 
84
+ # Réduire la taille
85
  image.thumbnail((384, 384))
86
 
87
  # Analyse avec Marqo fashionCLIP
 
93
  truncation=True
94
  )
95
 
 
96
  with torch.no_grad():
97
  outputs = model(**inputs)
98
 
 
103
  category_name = categories[predicted_class_idx]
104
  confidence_score = probs[0][predicted_class_idx].item()
105
 
106
+ # --- CORRECTION DE L'ANALYSE COULEUR ---
107
+ try:
108
+ # Sauvegarder l'image dans un fichier temporaire pour ColorThief
109
+ with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp:
110
+ image.save(tmp, format='JPEG')
111
+ tmp_path = tmp.name
112
+
113
+ # Utiliser ColorThief avec le fichier temporaire
114
+ color_thief = colorthief.ColorThief(tmp_path)
115
+ dominant_color = color_thief.get_color(quality=1)
116
+ hex_color = '#%02x%02x%02x' % dominant_color
117
+
118
+ # Nettoyer le fichier temporaire
119
+ os.unlink(tmp_path)
120
+
121
+ except Exception as color_error:
122
+ print(f"Erreur analyse couleur: {color_error}")
123
+ hex_color = "#000000" # Couleur par défaut
124
 
125
  return {
126
  "category": category_name,
 
136
  async def test_ui():
137
  return """
138
  <html>
139
+ <head>
140
+ <title>Fashion Detection Test</title>
141
+ <style>
142
+ body { font-family: Arial, sans-serif; margin: 40px; }
143
+ .container { max-width: 600px; margin: 0 auto; }
144
+ form { border: 2px dashed #ccc; padding: 30px; text-align: center; }
145
+ input[type="file"] { margin: 10px 0; }
146
+ input[type="submit"] { background: #007bff; color: white; padding: 10px 20px; border: none; cursor: pointer; }
147
+ </style>
148
+ </head>
149
  <body>
150
+ <div class="container">
151
+ <h1>🎨 Test Fashion Detection</h1>
152
+ <form action="/analyze" method="post" enctype="multipart/form-data">
153
+ <h3>Uploader une image de vêtement :</h3>
154
+ <input type="file" name="file" accept="image/*" required>
155
+ <br>
156
+ <input type="submit" value="Analyser l'image 👗">
157
+ </form>
158
+
159
+ <div style="margin-top: 30px; padding: 20px; background: #f8f9fa;">
160
+ <h3>📝 Instructions :</h3>
161
+ <p>• Uploader une image claire d'un vêtement</p>
162
+ <p>• Formats supportés : JPG, PNG, WebP</p>
163
+ <p>• Taille recommandée : moins de 2MB</p>
164
+ </div>
165
+ </div>
166
  </body>
167
  </html>
168
  """