# Hugging Face Spaces deployment (Space status banner removed — scrape artifact).
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import torch
import io

app = FastAPI(title="Skin Disease Classifier API")

print("Loading model... This may take a while the first time as it downloads from Hugging Face.")
model_name = "HotJellyBean/skin-disease-classifier"

# Load the processor and model once at import time so every request reuses
# the same in-memory weights instead of reloading per request.
processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)
print("Model loaded successfully!")
@app.get("/")
async def root():
    """Health-check / landing endpoint.

    Returns a short JSON message confirming the service is up and pointing
    clients at the POST /predict endpoint.
    """
    # NOTE(review): the route decorator was missing in the pasted source —
    # restored as GET "/" since this is the API's index message.
    return {"message": "Skin Disease Classifier API is running. Send a POST request with an image to /predict."}
@app.post("/predict")
async def predict_skin_disease(file: UploadFile = File(...)):
    """Classify an uploaded skin image with the pretrained model.

    Returns JSON containing the top predicted class, its confidence, and the
    full per-class probability distribution sorted from most to least likely.
    Non-image uploads get a 400; any processing failure gets a 500.
    """
    # content_type may be None for some clients; treat that as "not an image"
    # instead of raising AttributeError on .startswith.
    if not (file.content_type or "").startswith('image/'):
        return JSONResponse(content={"success": False, "error": "File must be an image"}, status_code=400)
    try:
        contents = await file.read()
        # Force RGB so grayscale/RGBA uploads match the model's expected input.
        image = Image.open(io.BytesIO(contents)).convert("RGB")
        # Preprocess the image into model-ready tensors.
        inputs = processor(images=image, return_tensors="pt")
        # Inference only — no gradients needed.
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits
        predicted_class_idx = logits.argmax(-1).item()
        # Human-readable class names come from the model config.
        labels = model.config.id2label
        predicted_class = labels[predicted_class_idx]
        # Softmax over the class dimension -> per-class confidence scores.
        probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
        # Map probabilities to class names.
        all_probs = {labels[i]: float(prob) for i, prob in enumerate(probabilities)}
        # Sort probabilities from highest to lowest for a readable response.
        sorted_probs = dict(sorted(all_probs.items(), key=lambda item: item[1], reverse=True))
        return JSONResponse(content={
            "success": True,
            "prediction": predicted_class,
            "confidence": float(probabilities[predicted_class_idx]),
            "details": sorted_probs
        })
    except Exception as e:
        # Top-level request boundary: surface any failure as a JSON 500
        # rather than an unhandled server error.
        return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
if __name__ == "__main__":
    import uvicorn

    # Bind to all interfaces on port 7860, the default port Hugging Face
    # Spaces expects a web app to listen on.
    uvicorn.run(app, host="0.0.0.0", port=7860)