MODLI commited on
Commit
b1cba22
·
verified ·
1 Parent(s): afe455b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -79
app.py CHANGED
@@ -1,132 +1,139 @@
1
- import json
2
  import os
 
3
  os.environ['HF_HOME'] = '/tmp/cache'
4
  os.environ['TORCH_HOME'] = '/tmp/cache'
5
 
6
-
7
- from fastapi import FastAPI, File, UploadFile, Response
8
  from fastapi.middleware.cors import CORSMiddleware
9
  from PIL import Image
10
  import torch
11
- from transformers import CLIPProcessor, CLIPModel # CHANGÉ : CLIP au lieu de Auto
12
  import io
13
  import colorthief
14
 
15
- # --- Charge le modèle Marqo fashionCLIP ---
16
- print("⚠️ Démarrage du chargement du modèle...")
17
- model_name = "Marqo/marqo-fashionCLIP"
18
- # CHANGÉ : On charge le modèle CLIP standard
19
- model = CLIPModel.from_pretrained(model_name)
20
- processor = CLIPProcessor.from_pretrained(model_name)
21
- print("✅ Modèle chargé avec succès !")
22
- # ---------------------------------------------------------
23
-
24
  app = FastAPI(title="Fashion Detection API")
25
 
26
- # Middleware pour autoriser les appels depuis votre application Lovable
27
  app.add_middleware(
28
  CORSMiddleware,
29
- allow_origins=["*"], # Pour le développement. Pour la production, remplacez par l'URL de Lovable.
30
  allow_credentials=True,
31
  allow_methods=["*"],
32
  allow_headers=["*"],
33
  expose_headers=["*"]
34
  )
35
 
36
- # Liste de catégories possibles en Anglais. Le modèle comprend mieux l'Anglais.
37
- # MODIFIEZ CETTE LISTE PERSONNALISEE SELON VOS BESOINS !
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  categories = [
39
  "a t-shirt", "a dress", "jeans", "a shirt", "a skirt", "sneakers",
40
- "a handbag", "a jacket", "shorts", "a sweater", "a coat", "high heels",
41
- "a scarf", "sunglasses", "a hat", "pants", "a blouse", "boots",
42
- "a sweatshirt", "a jumper", "an apron", "a ball gown", "a bandanna",
43
- "a baseball cap", "a beanie", "a belt", "a beret", "Bermuda shorts",
44
- "baby clothes", "a bib", "a bikini", "a blazer", "a bow tie",
45
- "boxer shorts", "a bra", "a bracelet", "breeches", "a buckle",
46
- "a button", "camouflage", "a cap", "a cape", "a cardigan", "a cloak",
47
- "clogs", "a corset", "a crown", "cuff links", "a dress shirt",
48
- "dungarees", "earmuffs", "earrings", "a flannel shirt", "flip-flops",
49
- "a fur coat", "a gilet", "glasses", "gloves", "a gown", "a Hawaiian shirt",
50
- "a helmet", "a hijab", "a hoodie", "a hospital gown", "jewelry",
51
- "a jumpsuit", "khakis", "a kilt", "knickers", "a lab coat",
52
- "a leather jacket", "leggings", "a leotard", "a life jacket",
53
- "lingerie", "loafers", "a miniskirt", "mittens", "a necklace",
54
- "a nightgown", "a nightshirt", "onesies", "pajamas", "a pantsuit",
55
- "pantyhose", "a parka", "a polo shirt", "a poncho", "a purse",
56
- "a raincoat", "a ring", "a robe", "a rugby shirt", "sandals",
57
- "scrubs", "shoes", "slippers", "socks", "a spacesuit", "stockings",
58
- "a stole", "a suit", "a sun hat", "a sundress", "suspenders",
59
- "sweatpants", "a swimsuit", "a tank top", "a tiara", "a tie",
60
- "a tie clip", "tights", "a toga", "a top", "a top coat", "a top hat",
61
- "a train", "a trench coat", "trousers", "trunks", "a tube top",
62
- "a turban", "a turtleneck", "a tutu", "a tuxedo", "an umbrella",
63
- "a veil", "a vest", "a waistcoat", "a wedding gown", "a wetsuit",
64
- "a windbreaker", "joggers", "palazzo pants", "cargo pants",
65
- "dress pants", "chinos", "a crop top", "a romper", "an insulated jacket",
66
- "a fleece", "a rain jacket", "a running jacket", "a graphic top",
67
- "a skort", "a sports bra", "water shorts", "goggles", "boxing gloves",
68
- "leg gaiters", "a neck gaiter", "a watch", "a swim trunk",
69
- "a pocket watch", "insoles", "climbing shoes"
70
  ]
71
 
72
- # Ajoutez cette route AVANT votre route /analyze
73
  @app.get("/")
74
  def read_root():
75
  return {"message": "Fashion Detection API is running!", "status": "OK"}
76
 
 
 
 
 
 
 
 
 
77
  @app.post("/analyze")
78
  async def analyze_image(file: UploadFile = File(...)):
79
- # 1. Lire l'image envoyée par l'utilisateur
80
- contents = await file.read()
81
- image = Image.open(io.BytesIO(contents)).convert("RGB")
82
-
83
- # 2. ANALYSE AVEC LE MODÈLE MARQO FASHIONCLIP (CODE CORRIGÉ)
84
  try:
85
- # CHANGÉ : Préparer les inputs correctement pour CLIP
86
- inputs = processor(text=categories, images=image, return_tensors="pt", padding=True)
 
 
 
 
87
 
88
- # Passer through the model
 
 
 
 
 
 
 
 
 
89
  with torch.no_grad():
90
  outputs = model(**inputs)
91
 
92
- # Récupérer les similarités image-texte
93
  logits_per_image = outputs.logits_per_image
94
- probs = logits_per_image.softmax(dim=1) # Convertir en probabilités
95
 
96
- # Trouver la catégorie avec la probabilité la plus élevée
97
  predicted_class_idx = probs.argmax(dim=1).item()
98
  category_name = categories[predicted_class_idx]
99
  confidence_score = probs[0][predicted_class_idx].item()
100
 
101
- except Exception as e:
102
- return {"error": f"Erreur lors de l'analyse AI: {str(e)}"}
103
-
104
- # 3. ANALYSE DE LA COULEUR (avec ColorThief)
105
- try:
106
- # On sauvegarde l'image en mémoire pour ColorThief
107
  img_buffer = io.BytesIO()
108
  image.save(img_buffer, format="PNG")
109
  img_buffer.seek(0)
110
- # Extrait la couleur dominante
111
  color_thief = colorthief.ColorThief(img_buffer)
112
  dominant_color = color_thief.get_color(quality=1)
113
- # Convertit le RGB (ex: (255, 0, 0)) en code hexadécimal (ex: #ff0000)
114
  hex_color = '#%02x%02x%02x' % dominant_color
115
- except Exception as e:
116
- hex_color = "#000000" # Couleur noire par défault en cas d'erreur
117
 
118
- # 4. Renvoie le résultat à Lovable
119
- return Response(
120
- content=json.dumps({
121
  "category": category_name,
122
  "color_hex": hex_color,
123
  "confidence": round(confidence_score, 4)
124
- }),
125
- media_type="application/json",
126
- headers={
127
- "Access-Control-Allow-Origin": "*",
128
- "Access-Control-Allow-Credentials": "true"
129
  }
130
- ) # Arrondit le score de confiance à 4 décimales
131
 
132
- # Cette partie est importante pour Hugging Face Spaces
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import json
3
  os.environ['HF_HOME'] = '/tmp/cache'
4
  os.environ['TORCH_HOME'] = '/tmp/cache'
5
 
6
+ from fastapi import FastAPI, File, UploadFile
 
7
  from fastapi.middleware.cors import CORSMiddleware
8
  from PIL import Image
9
  import torch
 
10
  import io
11
  import colorthief
12
 
 
 
 
 
 
 
 
 
 
13
  app = FastAPI(title="Fashion Detection API")
14
 
15
+ # Middleware CORS
16
  app.add_middleware(
17
  CORSMiddleware,
18
+ allow_origins=["*"],
19
  allow_credentials=True,
20
  allow_methods=["*"],
21
  allow_headers=["*"],
22
  expose_headers=["*"]
23
  )
24
 
25
+ # --- CHARGE LE MODÈLE MARQO FASHIONCLIP ---
26
+ print("⚠️ Démarrage du chargement du modèle Marqo fashionCLIP...")
27
+ model = None
28
+ processor = None
29
+
30
+ def load_marqo_model():
31
+ global model, processor
32
+ try:
33
+ # Import différé pour éviter les problèmes de compatibilité
34
+ from transformers import CLIPProcessor, CLIPModel
35
+
36
+ model_name = "Marqo/marqo-fashionCLIP"
37
+ model = CLIPModel.from_pretrained(
38
+ model_name,
39
+ cache_dir="/tmp/cache",
40
+ torch_dtype=torch.float16 # Réduit la mémoire
41
+ )
42
+ processor = CLIPProcessor.from_pretrained(model_name)
43
+ print("✅ Modèle Marqo fashionCLIP chargé avec succès !")
44
+ except Exception as e:
45
+ print(f"❌ Erreur chargement modèle Marqo: {e}")
46
+ print("Assurez-vous que les versions dans requirements.txt sont compatibles")
47
+
48
+ # Charge le modèle au démarrage (mais en différé)
49
+ @app.on_event("startup")
50
+ async def startup_event():
51
+ import threading
52
+ # Charge le modèle dans un thread séparé pour ne pas bloquer le démarrage
53
+ thread = threading.Thread(target=load_marqo_model)
54
+ thread.daemon = True
55
+ thread.start()
56
+
57
+ # Catégories fashion simplifiées pour tests
58
  categories = [
59
  "a t-shirt", "a dress", "jeans", "a shirt", "a skirt", "sneakers",
60
+ "a handbag", "a jacket", "shorts", "a sweater", "a coat", "high heels"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  ]
62
 
 
63
  @app.get("/")
64
  def read_root():
65
  return {"message": "Fashion Detection API is running!", "status": "OK"}
66
 
67
+ @app.get("/health")
68
+ def health_check():
69
+ return {
70
+ "model_loaded": model is not None,
71
+ "processor_loaded": processor is not None,
72
+ "status": "ready" if model and processor else "loading"
73
+ }
74
+
75
  @app.post("/analyze")
76
  async def analyze_image(file: UploadFile = File(...)):
77
+ # Vérifier que le modèle est chargé
78
+ if model is None or processor is None:
79
+ return {"error": "Model not loaded yet. Please wait or check /health endpoint."}
80
+
 
81
  try:
82
+ # Lire l'image
83
+ contents = await file.read()
84
+ image = Image.open(io.BytesIO(contents)).convert("RGB")
85
+
86
+ # Réduire la taille pour économiser la mémoire
87
+ image.thumbnail((384, 384))
88
 
89
+ # Analyse avec Marqo fashionCLIP
90
+ inputs = processor(
91
+ text=categories,
92
+ images=image,
93
+ return_tensors="pt",
94
+ padding=True,
95
+ truncation=True
96
+ )
97
+
98
+ # Utiliser le CPU (plus stable sur Hugging Face Spaces free)
99
  with torch.no_grad():
100
  outputs = model(**inputs)
101
 
 
102
  logits_per_image = outputs.logits_per_image
103
+ probs = logits_per_image.softmax(dim=1)
104
 
 
105
  predicted_class_idx = probs.argmax(dim=1).item()
106
  category_name = categories[predicted_class_idx]
107
  confidence_score = probs[0][predicted_class_idx].item()
108
 
109
+ # Analyse couleur
 
 
 
 
 
110
  img_buffer = io.BytesIO()
111
  image.save(img_buffer, format="PNG")
112
  img_buffer.seek(0)
 
113
  color_thief = colorthief.ColorThief(img_buffer)
114
  dominant_color = color_thief.get_color(quality=1)
 
115
  hex_color = '#%02x%02x%02x' % dominant_color
 
 
116
 
117
+ return {
 
 
118
  "category": category_name,
119
  "color_hex": hex_color,
120
  "confidence": round(confidence_score, 4)
 
 
 
 
 
121
  }
 
122
 
123
+ except Exception as e:
124
+ return {"error": f"Erreur lors de l'analyse: {str(e)}"}
125
+
126
+ # Interface simple pour tester
127
+ @app.get("/test-ui", response_class=HTMLResponse)
128
+ async def test_ui():
129
+ return """
130
+ <html>
131
+ <body>
132
+ <h1>Test Fashion Detection</h1>
133
+ <form action="/analyze" method="post" enctype="multipart/form-data">
134
+ <input type="file" name="file">
135
+ <input type="submit" value="Analyzer">
136
+ </form>
137
+ </body>
138
+ </html>
139
+ """