File size: 5,295 Bytes
66791ba
 
 
cff0bfe
5f1bae3
 
b13493b
5cf61c7
 
5f1bae3
 
 
5cf61c7
919d79d
5cf61c7
b13493b
 
 
 
 
 
 
 
 
 
5f1bae3
 
b1cba22
 
acd685d
 
5f1bae3
b1cba22
919d79d
5f1bae3
 
 
b1cba22
5f1bae3
b13493b
5f1bae3
acd685d
 
5f1bae3
 
 
 
 
 
 
 
 
 
 
 
 
 
b1cba22
1a0da4b
b13493b
5f1bae3
b13493b
 
 
 
5f1bae3
 
b13493b
b1cba22
5f1bae3
 
 
 
 
dbadef3
5f1bae3
 
cff0bfe
5f1bae3
 
 
 
 
ca655fa
919d79d
5f1bae3
 
 
 
ca655fa
5f1bae3
ca655fa
5f1bae3
919d79d
e1eace0
5f1bae3
 
e1eace0
 
 
ca655fa
 
5f1bae3
ca655fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1eace0
ca655fa
 
 
b13493b
5f1bae3
 
b13493b
ca655fa
 
 
 
e1eace0
5f1bae3
b1cba22
5f1bae3
 
ca655fa
5f1bae3
 
 
aa56d44
5f1bae3
 
 
b1cba22
5f1bae3
 
919d79d
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import os
os.environ['HF_HOME'] = '/tmp/cache'
os.environ['TORCH_HOME'] = '/tmp/cache'

import json
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
import torch
import requests
from io import BytesIO
from transformers import CLIPProcessor, CLIPModel

app = FastAPI(title="Fashion Classification API")

# Middleware CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"]
)

# --- Configuration du modèle ---
print("🔄 Chargement du modèle Fashion CLIP...")
model = None
processor = None

def load_model():
    global model, processor
    try:
        model_name = "patrickjohncyh/fashion-clip"
        model = CLIPModel.from_pretrained(model_name)
        processor = CLIPProcessor.from_pretrained(model_name)
        print("✅ Modèle chargé avec succès!")
    except Exception as e:
        print(f"❌ Erreur de chargement: {e}")

# Charger le modèle au démarrage
load_model()

# Catégories en français avec mapping vers anglais
CATEGORIES_FR = {
    "haut": ["a t-shirt", "a shirt", "a sweater", "a blouse", "a top"],
    "pantalon": ["jeans", "pants", "trousers", "leggings"],
    "robe": ["a dress", "a gown", "a sundress"],
    "jupe": ["a skirt"],
    "short": ["shorts", "bermuda shorts"],
    "veste": ["a jacket", "a blazer", "a leather jacket"],
    "manteau": ["a coat", "a winter coat", "a parka"],
    "chaussures": ["sneakers", "high heels", "boots", "sandals"],
    "sac": ["a handbag", "a purse", "a backpack"],
    "accessoire": ["a hat", "sunglasses", "a scarf", "a belt"],
    "autre": ["clothing", "fashion item"]
}

@app.get("/")
def read_root():
    return {"message": "Fashion Classification API is running!", "status": "OK"}

@app.get("/health")
def health_check():
    return {
        "model_loaded": model is not None,
        "status": "ready" if model else "loading"
    }

@app.post("/classify")
async def classify_fashion(image_data: dict):
    """
    Endpoint pour Lovable - accepte une URL d'image
    """
    try:
        if not model or not processor:
            raise HTTPException(status_code=503, detail="Model not loaded yet")
        
        # Vérifier et extraire l'URL de l'image
        image_url = image_data.get("imageUrl")
        if not image_url:
            raise HTTPException(status_code=400, detail="imageUrl is required")
        
        # Télécharger l'image
        response = requests.get(image_url, timeout=30)
        response.raise_for_status()
        
        # Ouvrir et préparer l'image
        image = Image.open(BytesIO(response.content)).convert("RGB")
        image = image.resize((224, 224))  # Taille standard pour CLIP
        
        # Préparer les catégories
        all_english_categories = []
        category_mapping = {}
        
        for fr_cat, en_categories in CATEGORIES_FR.items():
            all_english_categories.extend(en_categories)
            for en_cat in en_categories:
                category_mapping[en_cat] = fr_cat
        
        # === NOUVELLE APPROCHE : Traitement séquentiel ===
        results = {}
        
        for category in all_english_categories:
            try:
                # Traiter chaque catégorie individuellement
                inputs = processor(
                    text=[category],  # Une seule catégorie à la fois
                    images=image,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=77,
                    return_overflowing_tokens=False
                )
                
                with torch.no_grad():
                    outputs = model(**inputs)
                    results[category] = outputs.logits_per_image.item()
                    
            except Exception as e:
                print(f"Erreur avec la catégorie {category}: {e}")
                results[category] = -10.0  # Valeur très basse en cas d'erreur
        
        # Trouver la meilleure catégorie
        if not results:
            raise HTTPException(status_code=500, detail="Aucun résultat obtenu")
        
        best_english_category = max(results, key=results.get)
        confidence = results[best_english_category]
        
        # Convertir le score en probabilité (approximative)
        confidence_normalized = 1 / (1 + torch.exp(torch.tensor(-confidence))).item()
        
        # Catégorie française
        best_french_category = category_mapping.get(best_english_category, "autre")
        
        return {
            "success": True,
            "category": best_french_category,
            "confidence": round(confidence_normalized, 4),
            "colorHex": "#000000",
            "originalCategory": best_english_category,
            "method": "modli-api"
        }
        
    except requests.exceptions.RequestException as e:
        raise HTTPException(status_code=400, detail=f"Invalid image URL: {str(e)}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Classification error: {str(e)}")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)