"""FastAPI service that classifies skin-condition images with a Hugging Face model.

The image processor and model are loaded once at import time (downloaded from
Hugging Face on first run). ``POST /predict`` accepts an uploaded image and
returns the top predicted label, its softmax probability, and the full label
distribution sorted from most to least likely.
"""

import io
import logging

import torch
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image, UnidentifiedImageError
from transformers import AutoImageProcessor, AutoModelForImageClassification

logger = logging.getLogger(__name__)

app = FastAPI(title="Skin Disease Classifier API")

print("Loading model... This may take a while the first time as it downloads from Hugging Face.")
model_name = "HotJellyBean/skin-disease-classifier"

# Load the processor and model once; every request reuses these objects.
processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)
# Make inference mode explicit so layers like dropout never run in training mode.
model.eval()
print("Model loaded successfully!")


@app.get("/")
async def root():
    """Health-check endpoint that tells callers how to use the API."""
    return {"message": "Skin Disease Classifier API is running. Send a POST request with an image to /predict."}


def _classify(contents: bytes) -> dict:
    """Decode raw image bytes, run the classifier, and build the success payload.

    This is blocking work (PIL decoding plus a model forward pass), so async
    callers should run it in a worker thread rather than on the event loop.

    Args:
        contents: Raw bytes of the uploaded file.

    Returns:
        A JSON-serialisable dict with ``success``, ``prediction``,
        ``confidence`` and ``details`` (label -> probability, descending).

    Raises:
        UnidentifiedImageError: if ``contents`` cannot be decoded as an image.
    """
    image = Image.open(io.BytesIO(contents)).convert("RGB")

    # Preprocess the image into model-ready tensors.
    inputs = processor(images=image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_class_idx = logits.argmax(-1).item()

    # Class labels come from the model's own config.
    labels = model.config.id2label

    # One image was submitted, so take the first (only) row of the batch.
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]

    all_probs = {labels[i]: float(prob) for i, prob in enumerate(probabilities)}
    sorted_probs = dict(sorted(all_probs.items(), key=lambda item: item[1], reverse=True))

    return {
        "success": True,
        "prediction": labels[predicted_class_idx],
        "confidence": float(probabilities[predicted_class_idx]),
        "details": sorted_probs,
    }


@app.post("/predict")
async def predict_skin_disease(file: UploadFile = File(...)):
    """Classify an uploaded image.

    Responses:
        200: the payload built by ``_classify``.
        400: ``{"success": False, "error": ...}`` when the upload is not
             declared as an image or its bytes cannot be decoded as one.
        500: ``{"success": False, "error": str(e)}`` for any other failure.
    """
    # content_type is None when the client omits the part's Content-Type
    # header; treat that as "not an image" instead of crashing.
    if not (file.content_type or "").startswith("image/"):
        return JSONResponse(content={"success": False, "error": "File must be an image"}, status_code=400)

    try:
        contents = await file.read()
        # Offload decoding + inference so the event loop keeps serving requests.
        result = await run_in_threadpool(_classify, contents)
    except UnidentifiedImageError:
        # Declared as an image but undecodable: a client error, not a server fault.
        return JSONResponse(
            content={"success": False, "error": "Uploaded file is not a valid image"},
            status_code=400,
        )
    except Exception as e:
        # Boundary handler: record the traceback, then report a structured 500.
        logger.exception("Prediction failed")
        return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)

    return JSONResponse(content=result)


if __name__ == "__main__":
    import uvicorn

    # Run the API on Hugging Face default port 7860
    uvicorn.run(app, host="0.0.0.0", port=7860)