import gradio as gr
from transformers import pipeline
from PIL import Image
import numpy as np
import requests
from io import BytesIO
import base64
import os
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import uvicorn
import torch
# More detailed and precise fashion categories
FASHION_CATEGORIES = [
"t-shirt", "dress", "jeans", "jacket", "skirt",
"sneakers", "handbag", "swimsuit", "lingerie", "sweater",
"coat", "shorts", "blouse", "hat", "top",
"sweatpants", "dress pants", "leggings", "boots",
"sandals", "heels", "backpack", "sunglasses", "blazer",
"cardigan", "polo shirt", "hoodie", "vest", "jumpsuit",
"romper", "crop top", "tank top", "long sleeve shirt",
"windbreaker", "parka", "trench coat", "leather jacket",
"denim jacket", "waistcoat", "suit", "tie", "scarf",
"gloves", "belt", "wallet", "watch", "jewelry"
]
print("🔧 Loading fashion model...")
# Load a fashion-specialized model, with general CLIP fallbacks
try:
    # Try a fashion-specialized model first
    fashion_pipe = pipeline(
        "image-classification",
        model="nateraw/fashion-clip",
        device=0 if torch.cuda.is_available() else -1
    )
    print("✅ Fashion-CLIP model loaded successfully!")
except Exception:
    try:
        # Fall back to a larger general-purpose zero-shot model
        fashion_pipe = pipeline(
            "zero-shot-image-classification",
            model="openai/clip-vit-large-patch14",
            device=0 if torch.cuda.is_available() else -1
        )
        print("✅ CLIP Large model loaded successfully!")
    except Exception:
        # Last resort
        fashion_pipe = pipeline(
            "zero-shot-image-classification",
            model="openai/clip-vit-base-patch32",
            device=0 if torch.cuda.is_available() else -1
        )
        print("✅ CLIP Base model loaded as fallback!")
# API configuration
API_KEYS = os.environ.get("API_KEYS", "").split(",")
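# The variable is read as a comma-separated list, e.g. (hypothetical keys)
# API_KEYS="key-alpha,key-beta" -> ["key-alpha", "key-beta"]. When unset,
# "".split(",") yields [""], and the key check in the endpoint is skipped.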
# Request model for the API
class ClassificationRequest(BaseModel):
image_data: str
api_key: Optional[str] = None
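# Illustrative /api/classify request body (placeholder values only):
# {
#     "image_data": "data:image/jpeg;base64,/9j/4AAQ...",
#     "api_key": "key-alpha"
# }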
def preprocess_image(image):
"""Prétraite l'image pour améliorer la détection"""
    # Convert to RGB if needed
if image.mode != 'RGB':
image = image.convert('RGB')
    # Resize while keeping the aspect ratio
width, height = image.size
max_size = 512
if max(width, height) > max_size:
ratio = max_size / max(width, height)
new_size = (int(width * ratio), int(height * ratio))
image = image.resize(new_size, Image.Resampling.LANCZOS)
return image
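# Worked example: a 2048x1536 upload gives ratio = 512/2048 = 0.25 and is
# resized to 512x384; images already within 512 px pass through unchanged.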
def load_image_from_url(url):
"""Charge une image depuis une URL de manière robuste"""
try:
headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
}
response = requests.get(url, headers=headers, timeout=15)
response.raise_for_status()
        # Check that the response is actually an image
if not response.headers.get('content-type', '').startswith('image/'):
raise ValueError("URL does not point to an image")
image = Image.open(BytesIO(response.content))
return preprocess_image(image)
except Exception as e:
raise ValueError(f"❌ Cannot load image from URL: {str(e)}")
def analyze_fashion_item(image_input, url_input):
"""Analyse des vêtements - supporte image upload et URL"""
try:
        # Decide which input source to use
if image_input is not None:
            # The uploaded image takes priority
if isinstance(image_input, np.ndarray):
image = Image.fromarray(image_input)
else:
image = image_input
image = preprocess_image(image)
elif url_input and url_input.strip():
            # Load from the URL
image = load_image_from_url(url_input.strip())
else:
return "❌ Please upload an image or enter a URL first", None
        # 🔥 MAIN ANALYSIS WITH OPTIMIZED PARAMETERS
        try:
            # Fashion-CLIP was loaded as a plain image-classification pipeline
            # and is called without labels; the zero-shot CLIP fallbacks need
            # the candidate label list, so branch on the pipeline task first.
            if fashion_pipe.task == "image-classification":
                predictions = fashion_pipe(image)
                # Sort by score and keep reasonably confident predictions
                predictions = sorted(predictions, key=lambda x: x['score'], reverse=True)
                confident_predictions = [p for p in predictions if p['score'] > 0.05]
            else:
                # Zero-shot models score the image against the fashion categories
                predictions = fashion_pipe(
                    image,
                    candidate_labels=FASHION_CATEGORIES,
                    hypothesis_template="a clear photo of {}"
                )
                confident_predictions = [p for p in predictions if p['score'] > 0.1]
except Exception as model_error:
print(f"Model error: {model_error}")
return "❌ Model analysis failed. Please try another image.", image
if not confident_predictions:
return "❌ No confident prediction. Try a clearer image with better lighting.", image
        # Sort by descending score
confident_predictions.sort(key=lambda x: x['score'], reverse=True)
best_pred = confident_predictions[0]
        # Format the results
result_text = f"🎯 **Main item**: {best_pred['label'].title()}\n"
result_text += f"**Confidence**: {best_pred['score']*100:.1f}%\n\n"
if len(confident_predictions) > 1:
result_text += "**Other possibilities**:\n"
            for i, pred in enumerate(confident_predictions[1:6], 1):  # top 5 only
result_text += f"{i}. {pred['label'].title()} ({pred['score']*100:.1f}%)\n"
        # Confidence-based advice
        if best_pred['score'] < 0.7:
            result_text += "\n💡 **Tip**: Low confidence. Try a clearer image with the item centered and good lighting."
        else:
            result_text += f"\n✅ **High confidence detection**: This is very likely a {best_pred['label']}."
return result_text, image
except Exception as e:
error_msg = str(e)
if "URL" in error_msg:
return f"❌ URL Error: {error_msg}", None
else:
return f"❌ Analysis Error: {error_msg}", None
# =============================================================================
# GRADIO INTERFACE
# =============================================================================
with gr.Blocks(
title="Fashion AI Classifier",
theme=gr.themes.Soft(),
css="""
.container { max-width: 900px; margin: 0 auto; padding: 20px; }
.header { text-align: center; margin-bottom: 30px; }
.input-section { background: #f8f9fa; padding: 20px; border-radius: 10px; }
.output-section { background: white; padding: 20px; border-radius: 10px; }
.success { color: green; }
.warning { color: orange; }
.error { color: red; }
"""
) as demo:
with gr.Column(elem_classes="container"):
gr.Markdown("""
<div class='header'>
<h1>👗 Fashion AI Classifier</h1>
<p>Upload an image or enter URL to analyze clothing items</p>
</div>
""")
with gr.Row():
with gr.Column(elem_classes="input-section"):
            # Separate but simple inputs
image_input = gr.Image(
type="pil",
label="Upload Fashion Item",
height=200
)
url_input = gr.Textbox(
label="Or enter Image URL",
placeholder="https://example.com/image.jpg",
lines=2
)
gr.Markdown("""
**📝 Tips for better results:**
- Use clear, well-lit images
- Center the clothing item
- Use plain backgrounds when possible
- Avoid multiple items in one image
""")
analyze_btn = gr.Button(
"🔍 Analyze Item",
variant="primary",
size="lg"
)
with gr.Column(elem_classes="output-section"):
output_text = gr.Markdown(
label="Analysis Results",
value="**Results will appear here...**"
)
output_image = gr.Image(
label="Processed Image",
interactive=False,
height=300,
show_download_button=False
)
# Instructions
gr.Markdown("""
### 📋 Instructions:
- **Upload** an image **OR** enter a direct image URL
- Make sure the clothing item is clearly visible
- Well-lit images work best
- Avoid busy backgrounds
- For best results, show one item at a time
""")
    # Click event
analyze_btn.click(
fn=analyze_fashion_item,
inputs=[image_input, url_input],
outputs=[output_text, output_image]
)
# =============================================================================
# API ENDPOINTS FOR LOVABLE
# =============================================================================
app = FastAPI()
@app.post("/api/classify")
async def api_classify(request: ClassificationRequest):
"""Endpoint API pour Lovable avec support des clés API"""
try:
        # Check the API key (if any keys are configured)
        if API_KEYS and API_KEYS[0]:
if not request.api_key or request.api_key not in API_KEYS:
raise HTTPException(status_code=401, detail="Invalid API key")
if not request.image_data:
raise HTTPException(status_code=400, detail="No image data provided")
        # Decode the base64 image, stripping a data-URL prefix if present
        image_data = request.image_data
        if ',' in image_data:
            image_data = image_data.split(',')[1]
        image_bytes = base64.b64decode(image_data)
image = Image.open(BytesIO(image_bytes))
image = preprocess_image(image)
        # Run the analysis, passing an empty string for the URL input
result_text, processed_image = analyze_fashion_item(image, "")
if result_text.startswith("❌"):
raise HTTPException(status_code=400, detail=result_text)
        # Convert the processed image to base64 for the response
buffered = BytesIO()
processed_image.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode()
return {
"success": True,
"result": result_text,
"processed_image": f"data:image/jpeg;base64,{img_str}"
}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"API Error: {str(e)}")
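# Sketch of a client call, assuming a local run on the default port; <BASE64>
# and the key are placeholders:
#   curl -X POST http://localhost:7860/api/classify \
#        -H "Content-Type: application/json" \
#        -d '{"image_data": "<BASE64>", "api_key": "key-alpha"}'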
@app.get("/api/health")
async def health_check():
"""Endpoint de santé pour vérifier que l'API fonctionne"""
return {"status": "healthy", "model_loaded": True}
# Mount the Gradio interface at the root path
app = gr.mount_gradio_app(app, demo, path="/")
if __name__ == "__main__":
uvicorn.run(
app,
host="0.0.0.0",
port=7860,
reload=False
)