MODLI committed on
Commit
4d5ff3f
·
verified ·
1 Parent(s): b6eb828

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -52
app.py CHANGED
@@ -27,20 +27,25 @@ app.add_middleware(
27
  # --- CHARGE LE MODÈLE MARQO FASHIONCLIP ---
28
  print("⚠️ Démarrage du chargement du modèle Marqo fashionCLIP...")
29
  model = None
 
30
  processor = None
31
 
32
  def load_marqo_model():
33
- global model, processor
34
  try:
35
- from transformers import CLIPProcessor, CLIPModel
36
 
37
  model_name = "Marqo/marqo-fashionCLIP"
 
 
38
  model = CLIPModel.from_pretrained(
39
  model_name,
40
  cache_dir="/tmp/cache",
41
  torch_dtype=torch.float16
42
  )
43
- processor = CLIPProcessor.from_pretrained(model_name)
 
 
44
  print("✅ Modèle Marqo fashionCLIP chargé avec succès !")
45
  except Exception as e:
46
  print(f"❌ Erreur chargement modèle Marqo: {e}")
@@ -52,10 +57,10 @@ async def startup_event():
52
  thread.daemon = True
53
  thread.start()
54
 
55
- # Catégories fashion (textes plus courts et uniformes)
56
  categories = [
57
- "t-shirt", "dress", "jeans", "shirt", "skirt", "sneakers",
58
- "handbag", "jacket", "shorts", "sweater", "coat", "heels"
59
  ]
60
 
61
  @app.get("/")
@@ -66,13 +71,14 @@ def read_root():
66
  def health_check():
67
  return {
68
  "model_loaded": model is not None,
 
69
  "processor_loaded": processor is not None,
70
- "status": "ready" if model and processor else "loading"
71
  }
72
 
73
  @app.post("/analyze")
74
  async def analyze_image(file: UploadFile = File(...)):
75
- if model is None or processor is None:
76
  return {"error": "Model not loaded yet. Please wait or check /health endpoint."}
77
 
78
  try:
@@ -83,32 +89,43 @@ async def analyze_image(file: UploadFile = File(...)):
83
  # Réduire la taille
84
  image.thumbnail((384, 384))
85
 
86
- # --- SOLUTION DÉFINITIVE ---
87
- # Traiter chaque catégorie SÉPARÉMENT pour éviter les problèmes de padding
88
- similarities = []
 
 
 
89
 
90
  for category in categories:
91
- # Préparer les inputs pour UNE catégorie à la fois
92
- inputs = processor(
93
- text=[category], # Une seule catégorie
94
- images=image,
95
- return_tensors="pt",
96
- padding=True, # Padding pour une seule phrase
97
- truncation=True
98
  )
99
 
100
- # Déplacer sur le device du modèle
101
- device = next(model.parameters()).device
102
- inputs = {k: v.to(device) for k, v in inputs.items()}
103
-
104
  with torch.no_grad():
105
- outputs = model(**inputs)
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
- # Récupérer le score de similarité
108
- similarity_score = outputs.logits_per_image.item()
109
- similarities.append(similarity_score)
110
 
111
- # Convertir en tensor et calculer les probabilités
112
  similarities_tensor = torch.tensor(similarities)
113
  probs = torch.nn.functional.softmax(similarities_tensor, dim=0)
114
 
@@ -142,40 +159,25 @@ async def analyze_image(file: UploadFile = File(...)):
142
  except Exception as e:
143
  return {"error": f"Erreur lors de l'analyse: {str(e)}"}
144
 
145
- # Interface de test
146
  @app.get("/test-ui", response_class=HTMLResponse)
147
  async def test_ui():
148
  return """
149
  <html>
150
  <head>
151
- <title>Fashion Detection Test</title>
152
  <style>
153
- body { font-family: Arial, sans-serif; margin: 40px; }
154
- .container { max-width: 600px; margin: 0 auto; }
155
- form { border: 2px dashed #ccc; padding: 30px; text-align: center; }
156
- input[type="file"] { margin: 10px 0; }
157
- input[type="submit"] {
158
- background: #007bff; color: white; padding: 10px 20px;
159
- border: none; cursor: pointer; border-radius: 5px;
160
- }
161
- .result { margin-top: 20px; padding: 20px; background: #f0f8ff; }
162
  </style>
163
  </head>
164
  <body>
165
- <div class="container">
166
- <h1>🎨 Fashion Detection AI</h1>
167
- <form action="/analyze" method="post" enctype="multipart/form-data">
168
- <h3>Uploader une image de vêtement :</h3>
169
- <input type="file" name="file" accept="image/*" required>
170
- <br><br>
171
- <input type="submit" value="Analyser l'image 👗">
172
- </form>
173
-
174
- <div class="result">
175
- <h3>📋 Résultat de l'analyse :</h3>
176
- <p>Attendez l'upload et le traitement de l'image...</p>
177
- </div>
178
- </div>
179
  </body>
180
  </html>
181
  """
 
27
  # --- CHARGE LE MODÈLE MARQO FASHIONCLIP ---
28
  print("⚠️ Démarrage du chargement du modèle Marqo fashionCLIP...")
29
  model = None
30
+ tokenizer = None
31
  processor = None
32
 
33
  def load_marqo_model():
34
+ global model, tokenizer, processor
35
  try:
36
+ from transformers import CLIPModel, CLIPTokenizer, CLIPImageProcessor
37
 
38
  model_name = "Marqo/marqo-fashionCLIP"
39
+
40
+ # Charger les composants séparément
41
  model = CLIPModel.from_pretrained(
42
  model_name,
43
  cache_dir="/tmp/cache",
44
  torch_dtype=torch.float16
45
  )
46
+ tokenizer = CLIPTokenizer.from_pretrained(model_name)
47
+ processor = CLIPImageProcessor.from_pretrained(model_name)
48
+
49
  print("✅ Modèle Marqo fashionCLIP chargé avec succès !")
50
  except Exception as e:
51
  print(f"❌ Erreur chargement modèle Marqo: {e}")
 
57
  thread.daemon = True
58
  thread.start()
59
 
60
+ # Catégories fashion (textes courts)
61
  categories = [
62
+ "a t-shirt", "a dress", "jeans", "a shirt", "a skirt",
63
+ "sneakers", "a handbag", "a jacket", "shorts", "a sweater"
64
  ]
65
 
66
  @app.get("/")
 
71
  def health_check():
72
  return {
73
  "model_loaded": model is not None,
74
+ "tokenizer_loaded": tokenizer is not None,
75
  "processor_loaded": processor is not None,
76
+ "status": "ready" if all([model, tokenizer, processor]) else "loading"
77
  }
78
 
79
  @app.post("/analyze")
80
  async def analyze_image(file: UploadFile = File(...)):
81
+ if model is None or tokenizer is None or processor is None:
82
  return {"error": "Model not loaded yet. Please wait or check /health endpoint."}
83
 
84
  try:
 
89
  # Réduire la taille
90
  image.thumbnail((384, 384))
91
 
92
+ # --- NOUVELLE APPROCHE SANS PROCESSOR BATCH ---
93
+ # 1. Préparer l'image
94
+ image_input = processor(images=image, return_tensors="pt")
95
+
96
+ # 2. Préparer le texte - CHAQUE CATÉGORIE INDIVIDUELLEMENT
97
+ text_features_list = []
98
 
99
  for category in categories:
100
+ # Tokenizer chaque catégorie séparément
101
+ text_inputs = tokenizer(
102
+ category,
103
+ return_tensors="pt",
104
+ padding=True,
105
+ truncation=True,
106
+ max_length=77
107
  )
108
 
 
 
 
 
109
  with torch.no_grad():
110
+ text_features = model.get_text_features(**text_inputs)
111
+ text_features_list.append(text_features)
112
+
113
+ # 3. Get image features
114
+ with torch.no_grad():
115
+ image_features = model.get_image_features(**image_input)
116
+
117
+ # 4. Calculer les similarités
118
+ similarities = []
119
+ for text_features in text_features_list:
120
+ # Normaliser les features
121
+ image_features_norm = image_features / image_features.norm(dim=-1, keepdim=True)
122
+ text_features_norm = text_features / text_features.norm(dim=-1, keepdim=True)
123
 
124
+ # Calculer la similarité cosinus
125
+ similarity = (image_features_norm @ text_features_norm.T).squeeze()
126
+ similarities.append(similarity.item())
127
 
128
+ # 5. Convertir en probabilités
129
  similarities_tensor = torch.tensor(similarities)
130
  probs = torch.nn.functional.softmax(similarities_tensor, dim=0)
131
 
 
159
  except Exception as e:
160
  return {"error": f"Erreur lors de l'analyse: {str(e)}"}
161
 
162
+ # Interface de test SIMPLIFIÉE
163
  @app.get("/test-ui", response_class=HTMLResponse)
164
  async def test_ui():
165
  return """
166
  <html>
167
  <head>
168
+ <title>Fashion Detection</title>
169
  <style>
170
+ body { font-family: Arial, sans-serif; margin: 40px; text-align: center; }
171
+ form { border: 2px dashed #ccc; padding: 30px; display: inline-block; }
 
 
 
 
 
 
 
172
  </style>
173
  </head>
174
  <body>
175
+ <h1>🎨 Fashion Detection</h1>
176
+ <form action="/analyze" method="post" enctype="multipart/form-data">
177
+ <input type="file" name="file" accept="image/*" required>
178
+ <br><br>
179
+ <input type="submit" value="Analyze">
180
+ </form>
 
 
 
 
 
 
 
 
181
  </body>
182
  </html>
183
  """