MODLI commited on
Commit
4e1c2b6
·
verified ·
1 Parent(s): 4d5ff3f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -142
app.py CHANGED
@@ -4,180 +4,84 @@ os.environ['HF_HOME'] = '/tmp/cache'
4
  os.environ['TORCH_HOME'] = '/tmp/cache'
5
 
6
  from fastapi import FastAPI, File, UploadFile
7
- from fastapi.middleware.cors import CORSMiddleware
8
  from fastapi.responses import HTMLResponse
9
  from PIL import Image
10
  import torch
11
  import io
12
- import colorthief
13
- import tempfile
14
 
15
  app = FastAPI(title="Fashion Detection API")
16
 
17
- # Middleware CORS
18
- app.add_middleware(
19
- CORSMiddleware,
20
- allow_origins=["*"],
21
- allow_credentials=True,
22
- allow_methods=["*"],
23
- allow_headers=["*"],
24
- expose_headers=["*"]
25
- )
26
-
27
- # --- CHARGE LE MODÈLE MARQO FASHIONCLIP ---
28
- print("⚠️ Démarrage du chargement du modèle Marqo fashionCLIP...")
29
  model = None
30
- tokenizer = None
31
  processor = None
32
 
33
- def load_marqo_model():
34
- global model, tokenizer, processor
35
  try:
36
- from transformers import CLIPModel, CLIPTokenizer, CLIPImageProcessor
37
-
38
  model_name = "Marqo/marqo-fashionCLIP"
39
-
40
- # Charger les composants séparément
41
- model = CLIPModel.from_pretrained(
42
- model_name,
43
- cache_dir="/tmp/cache",
44
- torch_dtype=torch.float16
45
- )
46
- tokenizer = CLIPTokenizer.from_pretrained(model_name)
47
- processor = CLIPImageProcessor.from_pretrained(model_name)
48
-
49
- print("✅ Modèle Marqo fashionCLIP chargé avec succès !")
50
  except Exception as e:
51
- print(f"❌ Erreur chargement modèle Marqo: {e}")
52
 
53
  @app.on_event("startup")
54
- async def startup_event():
55
  import threading
56
- thread = threading.Thread(target=load_marqo_model)
57
- thread.daemon = True
58
- thread.start()
59
 
60
- # Catégories fashion (textes courts)
61
- categories = [
62
- "a t-shirt", "a dress", "jeans", "a shirt", "a skirt",
63
- "sneakers", "a handbag", "a jacket", "shorts", "a sweater"
64
- ]
65
 
66
  @app.get("/")
67
- def read_root():
68
- return {"message": "Fashion Detection API is running!", "status": "OK"}
69
-
70
- @app.get("/health")
71
- def health_check():
72
- return {
73
- "model_loaded": model is not None,
74
- "tokenizer_loaded": tokenizer is not None,
75
- "processor_loaded": processor is not None,
76
- "status": "ready" if all([model, tokenizer, processor]) else "loading"
77
- }
78
 
79
  @app.post("/analyze")
80
- async def analyze_image(file: UploadFile = File(...)):
81
- if model is None or tokenizer is None or processor is None:
82
- return {"error": "Model not loaded yet. Please wait or check /health endpoint."}
83
 
84
  try:
85
- # Lire l'image
86
- contents = await file.read()
87
- image = Image.open(io.BytesIO(contents)).convert("RGB")
88
-
89
- # Réduire la taille
90
- image.thumbnail((384, 384))
91
-
92
- # --- NOUVELLE APPROCHE SANS PROCESSOR BATCH ---
93
- # 1. Préparer l'image
94
- image_input = processor(images=image, return_tensors="pt")
95
-
96
- # 2. Préparer le texte - CHAQUE CATÉGORIE INDIVIDUELLEMENT
97
- text_features_list = []
98
 
 
 
99
  for category in categories:
100
- # Tokenizer chaque catégorie séparément
101
- text_inputs = tokenizer(
102
- category,
103
- return_tensors="pt",
104
- padding=True,
105
- truncation=True,
106
- max_length=77
107
  )
108
-
109
  with torch.no_grad():
110
- text_features = model.get_text_features(**text_inputs)
111
- text_features_list.append(text_features)
112
-
113
- # 3. Get image features
114
- with torch.no_grad():
115
- image_features = model.get_image_features(**image_input)
116
 
117
- # 4. Calculer les similarités
118
- similarities = []
119
- for text_features in text_features_list:
120
- # Normaliser les features
121
- image_features_norm = image_features / image_features.norm(dim=-1, keepdim=True)
122
- text_features_norm = text_features / text_features.norm(dim=-1, keepdim=True)
123
-
124
- # Calculer la similarité cosinus
125
- similarity = (image_features_norm @ text_features_norm.T).squeeze()
126
- similarities.append(similarity.item())
127
 
128
- # 5. Convertir en probabilités
129
- similarities_tensor = torch.tensor(similarities)
130
- probs = torch.nn.functional.softmax(similarities_tensor, dim=0)
131
-
132
- # Trouver la catégorie prédite
133
- predicted_class_idx = probs.argmax().item()
134
- category_name = categories[predicted_class_idx]
135
- confidence_score = probs[predicted_class_idx].item()
136
-
137
- # Analyse couleur
138
- try:
139
- with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp:
140
- image.save(tmp, format='JPEG')
141
- tmp_path = tmp.name
142
-
143
- color_thief = colorthief.ColorThief(tmp_path)
144
- dominant_color = color_thief.get_color(quality=1)
145
- hex_color = '#%02x%02x%02x' % dominant_color
146
-
147
- os.unlink(tmp_path)
148
-
149
- except Exception as color_error:
150
- print(f"Erreur analyse couleur: {color_error}")
151
- hex_color = "#000000"
152
-
153
  return {
154
- "category": category_name,
155
- "color_hex": hex_color,
156
- "confidence": round(confidence_score, 4)
157
  }
158
-
159
  except Exception as e:
160
- return {"error": f"Erreur lors de l'analyse: {str(e)}"}
161
 
162
- # Interface de test SIMPLIFIÉE
163
- @app.get("/test-ui", response_class=HTMLResponse)
164
- async def test_ui():
165
  return """
166
- <html>
167
- <head>
168
- <title>Fashion Detection</title>
169
- <style>
170
- body { font-family: Arial, sans-serif; margin: 40px; text-align: center; }
171
- form { border: 2px dashed #ccc; padding: 30px; display: inline-block; }
172
- </style>
173
- </head>
174
- <body>
175
- <h1>🎨 Fashion Detection</h1>
176
- <form action="/analyze" method="post" enctype="multipart/form-data">
177
- <input type="file" name="file" accept="image/*" required>
178
- <br><br>
179
- <input type="submit" value="Analyze">
180
- </form>
181
- </body>
182
- </html>
183
  """
 
4
  os.environ['TORCH_HOME'] = '/tmp/cache'
5
 
6
  from fastapi import FastAPI, File, UploadFile
 
7
  from fastapi.responses import HTMLResponse
8
  from PIL import Image
9
  import torch
10
  import io
 
 
11
 
12
  app = FastAPI(title="Fashion Detection API")
13
 
14
+ # Modèle et processor (chargement différé)
 
 
 
 
 
 
 
 
 
 
 
15
  model = None
 
16
  processor = None
17
 
18
+ def load_model():
19
+ global model, processor
20
  try:
21
+ from transformers import CLIPModel, CLIPProcessor
 
22
  model_name = "Marqo/marqo-fashionCLIP"
23
+ model = CLIPModel.from_pretrained(model_name)
24
+ processor = CLIPProcessor.from_pretrained(model_name)
25
+ print("✅ Modèle chargé!")
 
 
 
 
 
 
 
 
26
  except Exception as e:
27
+ print(f"❌ Erreur: {e}")
28
 
29
  @app.on_event("startup")
30
+ async def startup():
31
  import threading
32
+ threading.Thread(target=load_model).start()
 
 
33
 
34
+ categories = ["t-shirt", "dress", "jeans", "shirt", "skirt", "jacket", "sweater"]
 
 
 
 
35
 
36
  @app.get("/")
37
+ def home():
38
+ return {"message": "API running", "status": "OK"}
 
 
 
 
 
 
 
 
 
39
 
40
  @app.post("/analyze")
41
+ async def analyze(file: UploadFile = File(...)):
42
+ if not model or not processor:
43
+ return {"error": "Model loading..."}
44
 
45
  try:
46
+ # Lire image
47
+ image = Image.open(io.BytesIO(await file.read())).convert("RGB")
48
+ image.thumbnail((256, 256))
 
 
 
 
 
 
 
 
 
 
49
 
50
+ # Méthode SIMPLE et FIABLE
51
+ results = {}
52
  for category in categories:
53
+ inputs = processor(
54
+ text=[category],
55
+ images=image,
56
+ return_tensors="pt",
57
+ padding=True,
58
+ truncation=True
 
59
  )
 
60
  with torch.no_grad():
61
+ outputs = model(**inputs)
62
+ results[category] = outputs.logits_per_image.item()
 
 
 
 
63
 
64
+ # Trouver le meilleur résultat
65
+ best_category = max(results, key=results.get)
66
+ confidence = results[best_category]
 
 
 
 
 
 
 
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  return {
69
+ "category": best_category,
70
+ "confidence": round(confidence, 4),
71
+ "color_hex": "#000000" # Couleur basique pour l'instant
72
  }
73
+
74
  except Exception as e:
75
+ return {"error": str(e)}
76
 
77
+ @app.get("/ui", response_class=HTMLResponse)
78
+ def ui():
 
79
  return """
80
+ <html><body>
81
+ <h1>Fashion Detector</h1>
82
+ <form action="/analyze" method="post" enctype="multipart/form-data">
83
+ <input type="file" name="file" accept="image/*">
84
+ <input type="submit" value="Analyze">
85
+ </form>
86
+ </body></html>
 
 
 
 
 
 
 
 
 
 
87
  """