# TESTFASHION / app.py (author: MODLI)
# Last commit: "Update app.py" (9c94de5, verified) - 8.43 kB
import os
import json
import time
os.environ['HF_HOME'] = '/tmp/cache'
os.environ['TORCH_HOME'] = '/tmp/cache'
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from PIL import Image
import torch
import io
import colorthief
import tempfile
import numpy as np
app = FastAPI(title="Fashion Classification API")

# CORS middleware: accept any origin, method and header, and expose every
# response header to browser clients.
# NOTE(review): allow_credentials=True together with a wildcard origin is very
# permissive; narrow allow_origins if this API is reachable from the internet.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
    allow_credentials=True,
)
# --- MODEL STATE ---
# Module-level globals written by load_fashion_model() and read by the endpoints.
print("⚠️ Démarrage du chargement du modèle Marqo-FashionSigLIP...")
model = processor = None  # populated once loading succeeds
model_loading = model_loaded = False
model_error = None  # str() of the loading exception, if any
def load_fashion_model():
    """Download and load the fashion SigLIP model and its processor.

    Updates the module-level state read by the HTTP endpoints:

    * ``model_loading`` is True while this function runs and False afterwards
      (on both the success and the failure path);
    * on success, ``model`` and ``processor`` are set, then ``model_loaded``
      becomes True;
    * on failure, ``model_error`` holds ``str(exception)`` and the traceback is
      printed; the exception is NOT re-raised, and the globals ``model`` and
      ``processor`` are left untouched.

    ``model_error`` is cleared at the start of every attempt.
    """
    global model, processor, model_loading, model_loaded, model_error
    model_loading = True
    model_error = None
    try:
        from transformers import AutoModel, AutoProcessor

        model_name = "Marqo/Marqo-FashionSigLIP-Classification"
        print("📦 Téléchargement du modèle... (cela peut prendre 5-10 minutes)")
        # Load in float32: the model is never moved off the CPU, where many
        # kernels lack float16 implementations, and the inputs produced by the
        # processor (presumably float32 pixel values) would otherwise not match
        # half-precision weights in the forward pass.
        loaded_model = AutoModel.from_pretrained(
            model_name,
            cache_dir="/tmp/cache",
            torch_dtype=torch.float32,
            trust_remote_code=True,
        )
        # Same cache directory as the model, so both artefacts share one cache.
        loaded_processor = AutoProcessor.from_pretrained(
            model_name,
            cache_dir="/tmp/cache",
            trust_remote_code=True,
        )
    except Exception as e:
        print(f"❌ Erreur chargement modèle: {e}")
        model_error = str(e)
        import traceback
        traceback.print_exc()
    else:
        # Publish model and processor before raising the ready flag, so any
        # reader that sees model_loaded=True also sees both objects set.
        model, processor = loaded_model, loaded_processor
        print("✅ Modèle Marqo-FashionSigLIP chargé avec succès !")
        model_loaded = True
    finally:
        model_loading = False
# Start loading the model immediately, at import time.
# NOTE(review): this call is synchronous, so importing this module (and hence
# starting any server that imports `app` from it) blocks until loading succeeds
# or fails. As a result the "loading" state reported by / and /health, and the
# 423 response of /analyze, cannot be observed by clients unless this call is
# moved to a background thread.
load_fashion_model()
# Fashion categories: the candidate text labels each uploaded image is scored against.
categories = [
    "t-shirt",
    "dress",
    "jeans",
    "shirt",
    "skirt",
    "sneakers",
    "handbag",
    "jacket",
    "shorts",
    "sweater",
    "coat",
    "high heels",
    "blouse",
    "boots",
    "hat",
]
@app.get("/")
def read_root():
    """Liveness endpoint: confirm the API is running and summarise the model state."""
    if model_loaded:
        model_status = "loaded"
    elif model_loading:
        model_status = "loading"
    else:
        model_status = "error"
    return dict(
        message="Fashion Classification API is running!",
        status="OK",
        model_status=model_status,
        model_name="Marqo-FashionSigLIP-Classification",
    )
@app.get("/health")
def health_check():
    """Report the full model-loading state (polled by the /test-ui page).

    Includes the raw flags, any loading error, a derived ``status`` string
    (``ready`` / ``loading`` / ``error``) and the current server timestamp.
    """
    if model_loaded:
        status = "ready"
    elif model_loading:
        status = "loading"
    else:
        status = "error"
    return dict(
        model_loaded=model_loaded,
        model_loading=model_loading,
        model_error=model_error,
        status=status,
        model_name="Marqo-FashionSigLIP-Classification",
        timestamp=time.time(),
    )
def _classify(image):
    """Return ``(category, confidence)`` for the best-matching entry of ``categories``.

    Runs the loaded ``processor``/``model`` on ``image`` against every category
    label. Logits go through an element-wise sigmoid (not a softmax across the
    categories), so ``confidence`` is the winning category's own score in [0, 1].
    """
    inputs = processor(
        text=categories,
        images=image,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=64,
    )
    # Match the model's device AND dtype. Floating-point inputs (pixel values,
    # presumably float32 from the processor) must be cast to the weights' dtype,
    # otherwise e.g. float16 weights fail with a dtype-mismatch error. Integer
    # tensors (token ids, masks) must keep their dtype, so they only move device.
    reference_param = next(model.parameters())
    device, dtype = reference_param.device, reference_param.dtype
    inputs = {
        name: tensor.to(device=device, dtype=dtype)
        if torch.is_floating_point(tensor)
        else tensor.to(device)
        for name, tensor in inputs.items()
    }
    with torch.no_grad():
        outputs = model(**inputs)
    # Row 0: scores of the single image against every category.
    # .float() keeps the numpy side in float32 even for a half-precision model.
    probs = torch.sigmoid(outputs.logits_per_image).float().cpu().numpy()[0]
    best = int(np.argmax(probs))
    return categories[best], float(probs[best])


def _dominant_color_hex(image):
    """Return the dominant colour of ``image`` as ``#rrggbb``.

    Best effort: any failure is logged and reported as ``"#000000"`` rather
    than failing the whole request.
    """
    try:
        # ColorThief accepts a binary file object; encoding to an in-memory JPEG
        # keeps the original JPEG round-trip without creating a temporary file
        # that could be leaked when colour extraction raises.
        buffer = io.BytesIO()
        image.save(buffer, format="JPEG")
        buffer.seek(0)
        dominant_color = colorthief.ColorThief(buffer).get_color(quality=1)
        return "#%02x%02x%02x" % dominant_color
    except Exception as e:
        print(f"⚠️ Color extraction failed: {e}")
        return "#000000"


@app.post("/analyze")
async def analyze_image(file: UploadFile = File(...)):
    """Classify an uploaded clothing image and extract its dominant colour.

    Returns a JSON object with the best-matching ``category``, its sigmoid
    ``confidence`` (rounded to 4 decimals), the dominant ``color_hex`` and the
    model name.

    Raises:
        HTTPException(423): the model is still loading.
        HTTPException(400): the upload could not be decoded as an image.
        HTTPException(500): the model failed to load / is unavailable, or
            preprocessing/inference failed.
    """
    # Refuse early while the model is not usable.
    if not model_loaded:
        if model_loading:
            raise HTTPException(status_code=423, detail="Model still loading. Please wait 5-10 minutes and check /health")
        raise HTTPException(status_code=500, detail=f"Model failed to load: {model_error}")
    if model is None or processor is None:
        raise HTTPException(status_code=500, detail="Model not available")

    contents = await file.read()
    # An undecodable upload is a client error, not a server failure.
    # (PIL's UnidentifiedImageError and truncated-file errors are OSErrors.)
    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except OSError as e:
        raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e

    try:
        image = image.resize((384, 384))
        category_name, confidence_score = _classify(image)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Analysis error: {str(e)}") from e

    hex_color = _dominant_color_hex(image)
    return {
        "category": category_name,
        "confidence": round(confidence_score, 4),
        "color_hex": hex_color,
        "model": "Marqo-FashionSigLIP-Classification",
    }
# Browser test page that shows the model's loading status.
@app.get("/test-ui", response_class=HTMLResponse)
async def test_ui():
    """Serve a minimal HTML page for manually exercising ``/analyze``.

    The page polls ``/health`` on load (and every 5 s while the model is still
    loading). It enables the submit button only once ``model_loaded`` is true.
    A loading error is shown as plain text, because it is raw exception text
    that may contain ``<``/``>``. If ``/health`` itself cannot be reached or
    parsed, the page says so and retries every 5 s instead of staying stuck on
    "Vérification...".

    The HTML is a plain string literal (nothing is interpolated), so CSS/JS
    braces are written normally.
    """
    return """
<html>
<head>
<meta charset="utf-8">
<title>FashionSigLIP Detection</title>
<style>
body { font-family: Arial, sans-serif; margin: 40px; }
.container { max-width: 600px; margin: 0 auto; }
form { border: 2px dashed #ccc; padding: 30px; text-align: center; }
.status { padding: 15px; margin: 10px 0; border-radius: 5px; }
.loading { background: #fff3cd; color: #856404; }
.ready { background: #d4edda; color: #155724; }
.error { background: #f8d7da; color: #721c24; }
</style>
<script>
function checkStatus() {
    const statusDiv = document.getElementById('model-status');
    const submitBtn = document.getElementById('submit-btn');
    fetch('/health')
        .then(response => response.json())
        .then(data => {
            if (data.model_loaded) {
                statusDiv.innerHTML = '✅ <b>Modèle chargé et prêt !</b>';
                statusDiv.className = 'status ready';
                submitBtn.disabled = false;
            } else if (data.model_loading) {
                statusDiv.innerHTML = '⏳ <b>Chargement du modèle en cours...</b><br>Cela peut prendre 5-10 minutes';
                statusDiv.className = 'status loading';
                submitBtn.disabled = true;
                setTimeout(checkStatus, 5000); // Re-check dans 5 sec
            } else {
                statusDiv.innerHTML = '❌ <b>Erreur de chargement:</b><br>';
                statusDiv.appendChild(document.createTextNode(data.model_error || 'Unknown error'));
                statusDiv.className = 'status error';
                submitBtn.disabled = true;
            }
        })
        .catch(err => {
            statusDiv.textContent = '⚠️ Impossible de joindre /health : ' + err + ' (nouvelle tentative dans 5 sec)';
            statusDiv.className = 'status error';
            submitBtn.disabled = true;
            setTimeout(checkStatus, 5000);
        });
}
// Vérifier le statut au chargement de la page
window.onload = checkStatus;
</script>
</head>
<body>
<div class="container">
    <h1>👗 FashionSigLIP Detector</h1>
    <div id="model-status" class="status loading">
        Vérification du statut du modèle...
    </div>
    <form action="/analyze" method="post" enctype="multipart/form-data">
        <h3>Uploader une image de vêtement :</h3>
        <input type="file" name="file" accept="image/*" required>
        <br><br>
        <input type="submit" id="submit-btn" value="Analyser" disabled>
    </form>
    <div style="margin-top: 20px;">
        <button onclick="checkStatus()">Actualiser le statut</button>
        <button onclick="location.reload()">Rafraîchir la page</button>
    </div>
</div>
</body>
</html>
"""