# Skin Disease Classifier API — FastAPI service wrapping a Hugging Face
# image-classification model.
import io

import torch
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

app = FastAPI(title="Skin Disease Classifier API")

print("Loading model... This may take a while the first time as it downloads from Hugging Face.")
model_name = "HotJellyBean/skin-disease-classifier"

# Load the processor and model once at import time so every request reuses them.
processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)
print("Model loaded successfully!")
@app.get("/")
async def root():
    """Health-check endpoint: confirms the API is running and points to /predict."""
    status_message = (
        "Skin Disease Classifier API is running. "
        "Send a POST request with an image to /predict."
    )
    return {"message": status_message}
@app.post("/predict")
async def predict_skin_disease(file: UploadFile = File(...)):
    """Classify an uploaded skin image.

    Reads the uploaded file, runs it through the pretrained classifier, and
    returns the top predicted class, its confidence, and the full
    probability distribution over all classes (sorted descending).

    Returns 400 for non-image uploads and 500 on any processing failure.
    """
    # FIX: content_type may be None (multipart part without a Content-Type
    # header); calling .startswith on None raised AttributeError and the
    # client got a 500 instead of the intended 400. Guard None first.
    if not file.content_type or not file.content_type.startswith('image/'):
        return JSONResponse(content={"success": False, "error": "File must be an image"}, status_code=400)
    try:
        contents = await file.read()
        # Normalize to RGB so grayscale/RGBA/palette uploads are accepted.
        image = Image.open(io.BytesIO(contents)).convert("RGB")

        # Preprocess the image into model-ready tensors
        inputs = processor(images=image, return_tensors="pt")

        # Make prediction; inference only, so skip gradient tracking
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits
            predicted_class_idx = logits.argmax(-1).item()

        # Get human-readable class labels from the model config
        labels = model.config.id2label
        predicted_class = labels[predicted_class_idx]

        # Calculate confidence probabilities using softmax
        probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]

        # Map probabilities to class names
        all_probs = {labels[i]: float(prob) for i, prob in enumerate(probabilities)}

        # Sort probabilities from highest to lowest
        sorted_probs = dict(sorted(all_probs.items(), key=lambda item: item[1], reverse=True))

        return JSONResponse(content={
            "success": True,
            "prediction": predicted_class,
            "confidence": float(probabilities[predicted_class_idx]),
            "details": sorted_probs
        })
    except Exception as e:
        # Top-level boundary: report the failure to the client rather than
        # crash the worker (e.g. corrupt image data, decode errors).
        return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
if __name__ == "__main__":
    import uvicorn

    # Hugging Face Spaces expects the server to listen on port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)