# NOTE: The lines above this file in the original capture were Hugging Face
# Spaces file-viewer chrome (Space status, file size, commit-hash gutter and
# line-number gutter), not source code; they are reduced to this comment so
# the module parses.
import os
os.environ['HF_HOME'] = '/tmp/cache'
os.environ['TORCH_HOME'] = '/tmp/cache'
import json
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
import torch
import requests
from io import BytesIO
from transformers import CLIPProcessor, CLIPModel
# FastAPI application instance with fully permissive CORS so browser-based
# clients (e.g. the Lovable frontend) can call the API from any origin.
app = FastAPI(title="Fashion Classification API")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],       # accept requests from any origin
    allow_credentials=True,
    allow_methods=["*"],       # all HTTP verbs
    allow_headers=["*"],
    expose_headers=["*"],
)
# --- Model setup ---
# The model and processor live in module-level globals. They stay None when
# loading fails, and the /health endpoint reports that degraded state.
print("🔄 Chargement du modèle Fashion CLIP...")
model = None
processor = None


def load_model():
    """Load the Fashion-CLIP checkpoint into the module-level globals.

    Best-effort: on any failure the globals remain None and the error is
    printed, so the server can still start and answer /health with
    "loading" instead of crashing at import time.
    """
    global model, processor
    checkpoint = "patrickjohncyh/fashion-clip"
    try:
        model = CLIPModel.from_pretrained(checkpoint)
        processor = CLIPProcessor.from_pretrained(checkpoint)
        print("✅ Modèle chargé avec succès!")
    except Exception as e:
        print(f"❌ Erreur de chargement: {e}")


# Load the model once at startup.
load_model()
# French category labels (returned to the client) mapped to the English
# text prompts that are scored against the image by Fashion CLIP.
# The best-scoring English prompt is translated back to its French key
# in /classify; "autre" doubles as the fallback category there.
CATEGORIES_FR = {
"haut": ["a t-shirt", "a shirt", "a sweater", "a blouse", "a top"],
"pantalon": ["jeans", "pants", "trousers", "leggings"],
"robe": ["a dress", "a gown", "a sundress"],
"jupe": ["a skirt"],
"short": ["shorts", "bermuda shorts"],
"veste": ["a jacket", "a blazer", "a leather jacket"],
"manteau": ["a coat", "a winter coat", "a parka"],
"chaussures": ["sneakers", "high heels", "boots", "sandals"],
"sac": ["a handbag", "a purse", "a backpack"],
"accessoire": ["a hat", "sunglasses", "a scarf", "a belt"],
"autre": ["clothing", "fashion item"]
}
@app.get("/")
def read_root():
    """Liveness endpoint: confirms the API process is up."""
    return {
        "message": "Fashion Classification API is running!",
        "status": "OK",
    }
@app.get("/health")
def health_check():
    """Readiness endpoint: reports whether the CLIP model finished loading."""
    loaded = model is not None
    return {
        "model_loaded": loaded,
        "status": "ready" if loaded else "loading",
    }
@app.post("/classify")
async def classify_fashion(image_data: dict):
    """
    Endpoint pour Lovable - accepte une URL d'image.

    Expects a JSON body ``{"imageUrl": "<http(s) URL>"}``. Downloads the
    image, scores it against every English prompt in CATEGORIES_FR (one
    prompt per forward pass), and returns the best-matching French category
    together with a sigmoid-squashed confidence score.

    Raises:
        HTTPException 503: model not loaded yet.
        HTTPException 400: missing "imageUrl" or the download failed.
        HTTPException 500: no scores were produced / unexpected error.
    """
    try:
        if not model or not processor:
            raise HTTPException(status_code=503, detail="Model not loaded yet")

        # Extract and validate the image URL from the request body.
        image_url = image_data.get("imageUrl")
        if not image_url:
            raise HTTPException(status_code=400, detail="imageUrl is required")

        # Download the image.
        response = requests.get(image_url, timeout=30)
        response.raise_for_status()

        # Decode and normalize to the standard CLIP input size.
        image = Image.open(BytesIO(response.content)).convert("RGB")
        image = image.resize((224, 224))

        # Flatten the FR -> EN prompt table and build the reverse mapping.
        all_english_categories = []
        category_mapping = {}
        for fr_cat, en_categories in CATEGORIES_FR.items():
            all_english_categories.extend(en_categories)
            for en_cat in en_categories:
                category_mapping[en_cat] = fr_cat

        # Score each prompt individually (sequential keeps peak memory low).
        results = {}
        for category in all_english_categories:
            try:
                inputs = processor(
                    text=[category],  # one prompt per pass
                    images=image,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=77,
                    return_overflowing_tokens=False,
                )
                with torch.no_grad():
                    outputs = model(**inputs)
                results[category] = outputs.logits_per_image.item()
            except Exception as e:
                print(f"Erreur avec la catégorie {category}: {e}")
                results[category] = -10.0  # sentinel: very low score on error

        if not results:
            raise HTTPException(status_code=500, detail="Aucun résultat obtenu")

        # Pick the highest-scoring prompt.
        best_english_category = max(results, key=results.get)
        confidence = results[best_english_category]

        # Squash the raw CLIP logit into (0, 1) with a sigmoid.
        # NOTE(review): CLIP logits are temperature-scaled similarities, so
        # this is only an approximate probability, not a calibrated one.
        confidence_normalized = 1 / (1 + torch.exp(torch.tensor(-confidence))).item()

        # Translate back to the French label ("autre" as fallback).
        best_french_category = category_mapping.get(best_english_category, "autre")

        return {
            "success": True,
            "category": best_french_category,
            "confidence": round(confidence_normalized, 4),
            "colorHex": "#000000",
            "originalCategory": best_english_category,
            "method": "modli-api",
        }
    except HTTPException:
        # BUG FIX: HTTPException subclasses Exception, so without this
        # re-raise the 503/400/500 errors raised above were swallowed by the
        # generic handler below and rewrapped as 500 "Classification error".
        raise
    except requests.exceptions.RequestException as e:
        raise HTTPException(status_code=400, detail=f"Invalid image URL: {str(e)}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Classification error: {str(e)}")
if __name__ == "__main__":
    # Local/dev entry point; Hugging Face Spaces expects port 7860.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)