TESTFASHION / app.py
MODLI's picture
Update app.py
b6eb828 verified
raw
history blame
6.13 kB
import os
import json
os.environ['HF_HOME'] = '/tmp/cache'
os.environ['TORCH_HOME'] = '/tmp/cache'
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from PIL import Image
import torch
import io
import colorthief
import tempfile
# FastAPI application exposing the fashion-detection endpoints.
app = FastAPI(title="Fashion Detection API")

# CORS: accept cross-origin calls from any origin, with any method and header,
# and expose every response header to the browser.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive — confirm this is intended for a public demo.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
    allow_credentials=True,
)

# --- Marqo fashionCLIP model state ---
# Both globals stay None until load_marqo_model() completes successfully in
# the background thread started by the startup hook.
print("⚠️ Démarrage du chargement du modèle Marqo fashionCLIP...")
model = processor = None
def load_marqo_model():
    """Load the Marqo fashionCLIP model and processor into the module globals.

    Meant to run once in a background thread. Both objects are first loaded
    into locals and only then published to the ``model`` / ``processor``
    globals together, so a failure while loading the processor can no longer
    leave the model set but the processor missing. On any error the exception
    is printed and the globals keep their previous values.
    """
    global model, processor
    try:
        from transformers import CLIPProcessor, CLIPModel

        model_name = "Marqo/marqo-fashionCLIP"
        # NOTE(review): confirm CLIPModel loads this checkpoint's weights
        # correctly; the upstream model card may recommend AutoModel /
        # AutoProcessor with trust_remote_code=True instead.

        # float16 only pays off on a GPU. On CPU, some PyTorch versions have no
        # Half kernels for ops such as addmm/LayerNorm, and Half is usually
        # slower than float32 anyway, so fall back to float32 there.
        use_cuda = torch.cuda.is_available()
        dtype = torch.float16 if use_cuda else torch.float32

        loaded_model = CLIPModel.from_pretrained(
            model_name,
            cache_dir="/tmp/cache",
            torch_dtype=dtype,
        )
        if use_cuda:
            # Inference code moves its inputs to the model's device, so
            # placing the model on the GPU is all that is needed.
            loaded_model = loaded_model.to("cuda")
        loaded_processor = CLIPProcessor.from_pretrained(
            model_name,
            cache_dir="/tmp/cache",
        )

        # Publish both at once so readers never see a half-initialised state.
        model, processor = loaded_model, loaded_processor
        print("✅ Modèle Marqo fashionCLIP chargé avec succès !")
    except Exception as e:
        print(f"❌ Erreur chargement modèle Marqo: {e}")
@app.on_event("startup")
async def startup_event():
    """Start loading the model in a daemon thread so the server can start serving immediately.

    The thread is never joined; as a daemon it does not keep the process
    alive at shutdown.
    """
    import threading

    threading.Thread(target=load_marqo_model, daemon=True).start()
# Candidate fashion labels scored by /analyze (kept short and uniform).
# The list order decides which label wins a tie in argmax.
categories = [
    "t-shirt",
    "dress",
    "jeans",
    "shirt",
    "skirt",
    "sneakers",
    "handbag",
    "jacket",
    "shorts",
    "sweater",
    "coat",
    "heels",
]
@app.get("/")
def read_root():
    """Liveness probe: report that the API process is up (does not check the model)."""
    return dict(message="Fashion Detection API is running!", status="OK")
@app.get("/health")
def health_check():
    """Report whether the model and processor globals have been populated.

    Returns:
        dict with boolean ``model_loaded`` / ``processor_loaded`` flags and a
        ``status`` of ``"ready"`` when both are set, otherwise ``"loading"``.
        A load that failed also leaves the globals unset, so ``"loading"``
        cannot be told apart from a failed load here.
    """
    # Explicit None checks (as in /analyze) avoid relying on the truthiness
    # of library objects.
    model_loaded = model is not None
    processor_loaded = processor is not None
    return {
        "model_loaded": model_loaded,
        "processor_loaded": processor_loaded,
        "status": "ready" if model_loaded and processor_loaded else "loading",
    }
def _classify_image(image):
    """Return ``(category, confidence)`` for *image* using the loaded CLIP model.

    Each category is scored in its own forward pass (one text prompt plus the
    image) rather than as one padded batch, which the original author did to
    avoid padding problems. The per-category logits are softmax-normalised
    across ``categories``, so ``confidence`` is relative to this label set.
    """
    # The model does not move during a request: resolve its device once.
    device = next(model.parameters()).device
    similarities = []
    with torch.no_grad():
        for category in categories:
            inputs = processor(
                text=[category],
                images=image,
                return_tensors="pt",
                padding=True,
                truncation=True,
            )
            inputs = {k: v.to(device) for k, v in inputs.items()}
            outputs = model(**inputs)
            # One image and one text, so logits_per_image holds a single score.
            similarities.append(outputs.logits_per_image.item())
    probs = torch.nn.functional.softmax(torch.tensor(similarities), dim=0)
    best_idx = probs.argmax().item()
    return categories[best_idx], probs[best_idx].item()


def _dominant_color_hex(image):
    """Return the dominant colour of *image* as ``"#rrggbb"``, or ``"#000000"`` on failure.

    The image is JPEG-encoded into an in-memory buffer that ColorThief reads
    as a file object, so no temporary file can be left behind when colour
    extraction fails. Errors are printed, not raised, so a colour problem
    never fails the whole analysis.
    """
    try:
        buffer = io.BytesIO()
        image.save(buffer, format="JPEG")
        buffer.seek(0)
        dominant_color = colorthief.ColorThief(buffer).get_color(quality=1)
        return "#%02x%02x%02x" % dominant_color
    except Exception as color_error:
        print(f"Erreur analyse couleur: {color_error}")
        return "#000000"


def _analyze_bytes(contents):
    """Decode uploaded image bytes and build the /analyze success payload (blocking)."""
    image = Image.open(io.BytesIO(contents)).convert("RGB")
    # Shrink in place to fit within 384x384 (thumbnail keeps the aspect ratio).
    image.thumbnail((384, 384))
    category_name, confidence_score = _classify_image(image)
    return {
        "category": category_name,
        "color_hex": _dominant_color_hex(image),
        "confidence": round(confidence_score, 4),
    }


@app.post("/analyze")
async def analyze_image(file: UploadFile = File(...)):
    """Classify an uploaded garment image and report its dominant colour.

    Returns ``{"category", "color_hex", "confidence"}`` on success, or
    ``{"error": ...}`` when the model is not loaded yet or analysis fails.
    """
    from fastapi.concurrency import run_in_threadpool

    if model is None or processor is None:
        return {"error": "Model not loaded yet. Please wait or check /health endpoint."}
    try:
        contents = await file.read()
        # Image decoding and model inference are blocking CPU work; running
        # them in the threadpool keeps the event loop free to serve other
        # requests (e.g. /health) meanwhile.
        return await run_in_threadpool(_analyze_bytes, contents)
    except Exception as e:
        return {"error": f"Erreur lors de l'analyse: {str(e)}"}
# Test interface
@app.get("/test-ui", response_class=HTMLResponse)
async def test_ui():
    """Serve a minimal HTML page for manually uploading an image to /analyze.

    Without JavaScript the form posts straight to /analyze and the browser
    shows the raw JSON. With JavaScript the upload is sent with fetch() and the
    JSON reply is rendered into the "Résultat" panel, which the plain form
    submission never filled. The reply is inserted with ``textContent`` so
    server text is never interpreted as HTML.
    """
    return """
<html>
<head>
    <meta charset="utf-8">
    <title>Fashion Detection Test</title>
    <style>
        body { font-family: Arial, sans-serif; margin: 40px; }
        .container { max-width: 600px; margin: 0 auto; }
        form { border: 2px dashed #ccc; padding: 30px; text-align: center; }
        input[type="file"] { margin: 10px 0; }
        input[type="submit"] {
            background: #007bff; color: white; padding: 10px 20px;
            border: none; cursor: pointer; border-radius: 5px;
        }
        .result { margin-top: 20px; padding: 20px; background: #f0f8ff; }
        .result pre { white-space: pre-wrap; margin: 0; }
    </style>
</head>
<body>
    <div class="container">
        <h1>🎨 Fashion Detection AI</h1>
        <form id="upload-form" action="/analyze" method="post" enctype="multipart/form-data">
            <h3>Uploader une image de vêtement :</h3>
            <input type="file" name="file" accept="image/*" required>
            <br><br>
            <input type="submit" value="Analyser l'image 👗">
        </form>
        <div class="result">
            <h3>📋 Résultat de l'analyse :</h3>
            <pre id="result-text">Attendez l'upload et le traitement de l'image...</pre>
        </div>
    </div>
    <script>
        const form = document.getElementById("upload-form");
        const output = document.getElementById("result-text");
        form.addEventListener("submit", async (event) => {
            event.preventDefault();
            output.textContent = "Analyse en cours...";
            try {
                const response = await fetch(form.action, {
                    method: "POST",
                    body: new FormData(form),
                });
                const data = await response.json();
                output.textContent = JSON.stringify(data, null, 2);
            } catch (err) {
                output.textContent = "Erreur : " + err;
            }
        });
    </script>
</body>
</html>
"""