Spaces:
Sleeping
Sleeping
import io
import traceback

import torch
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
app = FastAPI()

# Hugging Face Hub id of the image-classification checkpoint served by this app.
MODEL_NAME = "varun1505/face-characteristics"


def _load_model(name):
    """Load the image processor and classification model for checkpoint *name*.

    Returns:
        ``(processor, model)`` on success, or ``(None, None)`` if either one
        fails to load. Failure is reported but not raised, so the app still
        starts and the analysis endpoint can answer "Model is not available."
    """
    try:
        loaded_processor = AutoImageProcessor.from_pretrained(name)
        loaded_model = AutoModelForImageClassification.from_pretrained(name)
    except Exception as e:  # deliberate best-effort boundary: keep the app up
        print(f"Error loading model: {e}")
        # The one-line message alone hides where loading failed (network,
        # missing files, incompatible config, ...); dump the full traceback
        # to stderr too.
        traceback.print_exc()
        return None, None
    print(f"Successfully loaded model: {name}")
    return loaded_processor, loaded_model


processor, model = _load_model(MODEL_NAME)
# No route decorator was present, so this handler was never exposed; it is
# registered at "/" as a liveness check.
@app.get("/")
def read_root():
    """Report that the analyzer service process is up.

    Returns a fixed status payload. It does not check whether the model
    loaded successfully; the analysis endpoint reports that case itself.
    """
    return {"status": "AuraSkin Face Characteristics Analyzer is running"}
def _decode_rgb_image(data):
    """Decode raw upload bytes into a fully loaded RGB PIL image.

    Raises:
        OSError: (including ``PIL.UnidentifiedImageError``) if the bytes are
            not a readable image or the image data is truncated.
        PIL.Image.DecompressionBombError: if the image exceeds Pillow's
            pixel-count safety limit.
    """
    # Image.open is lazy; convert() forces the full decode to happen here,
    # inside the caller's error handling. It also normalises grayscale,
    # palette and RGBA uploads to 3 channels (image processors typically
    # normalise with per-channel RGB mean/std - assumed for this checkpoint).
    with Image.open(io.BytesIO(data)) as img:
        return img.convert("RGB")


def _predict_probabilities(image):
    """Return a 1-D tensor of per-label probabilities for a single image.

    Blocking and compute-heavy; the endpoint runs it in a worker thread.
    """
    inputs = processor(images=image, return_tensors="pt")
    # torch.no_grad is thread-local, so it must be entered in the thread
    # that actually runs the forward pass.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Multi-label model: sigmoid gives each label an independent probability
    # in [0, 1]; unlike softmax, the scores need not sum to 1.
    return torch.sigmoid(logits)[0]


# NOTE(review): no route decorator was visible in this source, so the
# endpoint was never registered. "/analyze" is an assumed path - confirm it
# matches what the client calls.
@app.post("/analyze")
async def analyze_image(image_file: UploadFile = File(...)):
    """Classify an uploaded face image and return a score for every label.

    Args:
        image_file: Multipart upload; its declared content type must start
            with ``image/``.

    Returns:
        A list of ``{"label": str, "score": float}`` dicts covering every
        label in ``model.config.id2label``, sorted by descending score. No
        threshold is applied; the client decides which labels to display
        (e.g. only scores > 0.5).

    Raises:
        HTTPException: 400 if the upload is not declared as an image or
            cannot be decoded as one; 500 if the model failed to load at
            startup or inference fails.
    """
    if model is None or processor is None:
        raise HTTPException(status_code=500, detail="Model is not available.")
    # content_type can be None when the client sends no Content-Type for the
    # part; treat that as "not an image" instead of crashing on None.
    if not (image_file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Invalid file type. Please upload an image.")

    try:
        image_bytes = await image_file.read()
        try:
            image = _decode_rgb_image(image_bytes)
        except (OSError, Image.DecompressionBombError) as e:
            # Corrupt, truncated or oversized uploads are client errors.
            raise HTTPException(status_code=400, detail="Could not decode the uploaded image.") from e
        # Inference is synchronous; running it in the threadpool keeps this
        # async endpoint from blocking the event loop for other requests.
        probabilities = await run_in_threadpool(_predict_probabilities, image)
        results = [
            {"label": model.config.id2label[i], "score": score}
            for i, score in enumerate(probabilities.tolist())
        ]
    except HTTPException:
        # Already a deliberate client-facing error; don't rewrap it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return sorted(results, key=lambda r: r["score"], reverse=True)