Spaces:
Sleeping
Sleeping
# Redirect Hugging Face / Torch caches to a writable temp dir.
# Must happen BEFORE importing transformers/torch, which read these at import time.
import os
os.environ['HF_HOME'] = '/tmp/cache'
os.environ['TORCH_HOME'] = '/tmp/cache'

import json  # NOTE(review): imported but unused in this file chunk — confirm before removing
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
import torch
import requests
from io import BytesIO
from transformers import CLIPProcessor, CLIPModel
app = FastAPI(title="Fashion Classification API")

# CORS middleware: fully open (any origin / method / header).
# Acceptable for a public demo Space; tighten allow_origins for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"]
)
# --- Model setup ---
print("🔄 Chargement du modèle Fashion CLIP...")

# Module globals populated by load_model(); remain None if loading fails,
# which endpoints use to report "not ready".
model = None
processor = None

def load_model():
    """Load the Fashion-CLIP model and processor into module globals.

    Failures are logged and swallowed deliberately so the API process can
    still start; classify_fashion answers 503 until the model is available.
    """
    global model, processor
    try:
        model_name = "patrickjohncyh/fashion-clip"
        model = CLIPModel.from_pretrained(model_name)
        processor = CLIPProcessor.from_pretrained(model_name)
        print("✅ Modèle chargé avec succès!")
    except Exception as e:
        print(f"❌ Erreur de chargement: {e}")

# Load the model at import/startup time (blocks until download completes).
load_model()
# French category labels mapped to the English text prompts scored by CLIP.
# The keys are what the API returns; the values are the candidate prompts.
CATEGORIES_FR = {
    "haut": ["a t-shirt", "a shirt", "a sweater", "a blouse", "a top"],
    "pantalon": ["jeans", "pants", "trousers", "leggings"],
    "robe": ["a dress", "a gown", "a sundress"],
    "jupe": ["a skirt"],
    "short": ["shorts", "bermuda shorts"],
    "veste": ["a jacket", "a blazer", "a leather jacket"],
    "manteau": ["a coat", "a winter coat", "a parka"],
    "chaussures": ["sneakers", "high heels", "boots", "sandals"],
    "sac": ["a handbag", "a purse", "a backpack"],
    "accessoire": ["a hat", "sunglasses", "a scarf", "a belt"],
    "autre": ["clothing", "fashion item"]
}
def read_root():
    """Root endpoint: static liveness message.

    NOTE(review): no @app.get decorator is visible here — presumably lost in
    extraction; confirm the route is actually registered.
    """
    payload = {
        "message": "Fashion Classification API is running!",
        "status": "OK",
    }
    return payload
def health_check():
    """Health endpoint: reports whether the CLIP model finished loading.

    NOTE(review): no @app.get decorator is visible here — presumably lost in
    extraction; confirm the route is actually registered.
    """
    is_loaded = model is not None
    current_status = "ready" if model else "loading"
    return {"model_loaded": is_loaded, "status": current_status}
async def classify_fashion(image_data: dict):
    """Classify a fashion image fetched from a URL (Lovable integration).

    Expects a JSON body like ``{"imageUrl": "<http url>"}`` and returns the
    best-matching French category plus an approximate confidence.

    Raises:
        HTTPException 503 if the model is not loaded yet.
        HTTPException 400 if imageUrl is missing or the download fails.
        HTTPException 500 for unexpected classification failures.

    NOTE(review): no @app.post decorator is visible here — presumably lost in
    extraction; confirm the route is actually registered.
    """
    try:
        if not model or not processor:
            raise HTTPException(status_code=503, detail="Model not loaded yet")

        # Extract and validate the image URL from the request body.
        image_url = image_data.get("imageUrl")
        if not image_url:
            raise HTTPException(status_code=400, detail="imageUrl is required")

        # Download the image.
        response = requests.get(image_url, timeout=30)
        response.raise_for_status()

        # Decode and normalize to RGB at CLIP's standard 224x224 input size.
        image = Image.open(BytesIO(response.content)).convert("RGB")
        image = image.resize((224, 224))

        # Flatten the FR -> EN prompt table and keep the reverse mapping
        # so the winning English prompt can be translated back.
        all_english_categories = []
        category_mapping = {}
        for fr_cat, en_categories in CATEGORIES_FR.items():
            all_english_categories.extend(en_categories)
            for en_cat in en_categories:
                category_mapping[en_cat] = fr_cat

        # Score each prompt individually (sequential processing keeps peak
        # memory low, at the cost of one forward pass per prompt).
        results = {}
        for category in all_english_categories:
            try:
                inputs = processor(
                    text=[category],  # one prompt at a time
                    images=image,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=77,
                    return_overflowing_tokens=False
                )
                with torch.no_grad():
                    outputs = model(**inputs)
                results[category] = outputs.logits_per_image.item()
            except Exception as e:
                print(f"Erreur avec la catégorie {category}: {e}")
                results[category] = -10.0  # sentinel: one bad prompt must not abort the request

        if not results:
            raise HTTPException(status_code=500, detail="Aucun résultat obtenu")

        # Pick the highest-scoring prompt.
        best_english_category = max(results, key=results.get)
        confidence = results[best_english_category]
        # Squash the raw logit through a sigmoid for an approximate probability
        # (equivalent to the 1 / (1 + exp(-x)) form, via torch.sigmoid).
        confidence_normalized = torch.sigmoid(torch.tensor(confidence)).item()

        best_french_category = category_mapping.get(best_english_category, "autre")

        return {
            "success": True,
            "category": best_french_category,
            "confidence": round(confidence_normalized, 4),
            "colorHex": "#000000",
            "originalCategory": best_english_category,
            "method": "modli-api"
        }
    except HTTPException:
        # BUGFIX: the deliberate 503/400/500 responses raised above were
        # previously swallowed by the generic `except Exception` below and
        # re-raised as 500 "Classification error" — propagate them unchanged.
        raise
    except requests.exceptions.RequestException as e:
        raise HTTPException(status_code=400, detail=f"Invalid image URL: {str(e)}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Classification error: {str(e)}")
if __name__ == "__main__":
    # Run the API directly with uvicorn on port 7860
    # (the default port expected by Hugging Face Spaces).
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)