# TESTFASHION / app.py
# NOTE(review): the lines below are Hugging Face Space page chrome that was
# captured along with the file; kept as comments so the module parses:
#   MODLI's picture — Update app.py — 9fe1fb1 verified — raw / history blame — 5.84 kB
import json
import os
# Redirect model/cache downloads to /tmp, the only writable path on HF Spaces.
os.environ['HF_HOME'] = '/tmp/cache'
os.environ['TORCH_HOME'] = '/tmp/cache'
from fastapi import FastAPI, File, UploadFile, Response
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel # CHANGED: CLIP classes instead of Auto classes
import io
import colorthief
# --- Load the Marqo fashionCLIP model (once, at startup) ---
print("⚠️ Démarrage du chargement du modèle...")
model_name = "Marqo/marqo-fashionCLIP"
# CHANGED: load the standard CLIP model/processor pair for zero-shot classification.
model = CLIPModel.from_pretrained(model_name)
processor = CLIPProcessor.from_pretrained(model_name)
print("✅ Modèle chargé avec succès !")
# ---------------------------------------------------------
app = FastAPI(title="Fashion Detection API")
# CORS middleware so the browser-based Lovable frontend can call this API.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # For development. In production, replace with the Lovable app URL.
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
expose_headers=["*"]
)
# Candidate labels in English (CLIP-style prompts); the model was trained on
# English text, so English prompts score best.
# EDIT THIS LIST TO MATCH YOUR OWN CATEGORY NEEDS!
categories = [
"a t-shirt", "a dress", "jeans", "a shirt", "a skirt", "sneakers",
"a handbag", "a jacket", "shorts", "a sweater", "a coat", "high heels",
"a scarf", "sunglasses", "a hat", "pants", "a blouse", "boots",
"a sweatshirt", "a jumper", "an apron", "a ball gown", "a bandanna",
"a baseball cap", "a beanie", "a belt", "a beret", "Bermuda shorts",
"baby clothes", "a bib", "a bikini", "a blazer", "a bow tie",
"boxer shorts", "a bra", "a bracelet", "breeches", "a buckle",
"a button", "camouflage", "a cap", "a cape", "a cardigan", "a cloak",
"clogs", "a corset", "a crown", "cuff links", "a dress shirt",
"dungarees", "earmuffs", "earrings", "a flannel shirt", "flip-flops",
"a fur coat", "a gilet", "glasses", "gloves", "a gown", "a Hawaiian shirt",
"a helmet", "a hijab", "a hoodie", "a hospital gown", "jewelry",
"a jumpsuit", "khakis", "a kilt", "knickers", "a lab coat",
"a leather jacket", "leggings", "a leotard", "a life jacket",
"lingerie", "loafers", "a miniskirt", "mittens", "a necklace",
"a nightgown", "a nightshirt", "onesies", "pajamas", "a pantsuit",
"pantyhose", "a parka", "a polo shirt", "a poncho", "a purse",
"a raincoat", "a ring", "a robe", "a rugby shirt", "sandals",
"scrubs", "shoes", "slippers", "socks", "a spacesuit", "stockings",
"a stole", "a suit", "a sun hat", "a sundress", "suspenders",
"sweatpants", "a swimsuit", "a tank top", "a tiara", "a tie",
"a tie clip", "tights", "a toga", "a top", "a top coat", "a top hat",
"a train", "a trench coat", "trousers", "trunks", "a tube top",
"a turban", "a turtleneck", "a tutu", "a tuxedo", "an umbrella",
"a veil", "a vest", "a waistcoat", "a wedding gown", "a wetsuit",
"a windbreaker", "joggers", "palazzo pants", "cargo pants",
"dress pants", "chinos", "a crop top", "a romper", "an insulated jacket",
"a fleece", "a rain jacket", "a running jacket", "a graphic top",
"a skort", "a sports bra", "water shorts", "goggles", "boxing gloves",
"leg gaiters", "a neck gaiter", "a watch", "a swim trunk",
"a pocket watch", "insoles", "climbing shoes"
]
# Keep this route declared BEFORE the /analyze route.
@app.get("/")
def read_root():
    """Health-check endpoint: confirms the API process is up."""
    payload = {
        "message": "Fashion Detection API is running!",
        "status": "OK",
    }
    return payload
@app.post("/analyze")
async def analyze_image(file: UploadFile = File(...)):
    """Classify an uploaded clothing image and extract its dominant color.

    Returns JSON ``{"category": str, "color_hex": str, "confidence": float}``
    on success; on model failure returns HTTP 500 with ``{"error": str}``.
    """
    # 1. Read the uploaded bytes into a PIL RGB image.
    contents = await file.read()
    image = Image.open(io.BytesIO(contents)).convert("RGB")

    # 2. Zero-shot classification with Marqo fashionCLIP.
    try:
        # Prepare paired text/image inputs for CLIP.
        inputs = processor(text=categories, images=image, return_tensors="pt", padding=True)
        with torch.no_grad():
            outputs = model(**inputs)
        # Image-text similarity logits -> probabilities over the label list.
        logits_per_image = outputs.logits_per_image
        probs = logits_per_image.softmax(dim=1)
        predicted_class_idx = probs.argmax(dim=1).item()
        category_name = categories[predicted_class_idx]
        confidence_score = probs[0][predicted_class_idx].item()
    except Exception as e:
        # FIX: return an explicit HTTP 500 instead of a 200 OK body with an
        # "error" key, so clients can tell failure from success by status code.
        return Response(
            content=json.dumps({"error": f"Erreur lors de l'analyse AI: {str(e)}"}),
            media_type="application/json",
            status_code=500,
        )

    # 3. Dominant-color extraction with ColorThief (best-effort).
    try:
        # ColorThief needs a file-like object, so re-encode the image in memory.
        img_buffer = io.BytesIO()
        image.save(img_buffer, format="PNG")
        img_buffer.seek(0)
        color_thief = colorthief.ColorThief(img_buffer)
        dominant_color = color_thief.get_color(quality=1)
        # RGB tuple (e.g. (255, 0, 0)) -> hex string (e.g. #ff0000).
        hex_color = '#%02x%02x%02x' % dominant_color
    except Exception:
        hex_color = "#000000"  # Default to black if color extraction fails.

    # 4. Return the result (CORS headers also set explicitly for direct Responses).
    return Response(
        content=json.dumps({
            "category": category_name,
            "color_hex": hex_color,
            "confidence": round(confidence_score, 4),  # round to 4 decimals
        }),
        media_type="application/json",
        headers={
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Credentials": "true",
        },
    )
# This part is important for Hugging Face Spaces (server startup follows).