# Dermex / api.py
# Uploaded by Hassan73 (commit 2a42654, 2.58 kB).
import io
import logging

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
app = FastAPI(title="Skin Disease Classifier API")

print("Loading model... This may take a while the first time as it downloads from Hugging Face.")

# Hugging Face repository from which both the image processor and the classifier are loaded.
model_name = "HotJellyBean/skin-disease-classifier"

# Both objects are created once, when this module is imported, and are then shared by
# every request handled by the endpoints below.
processor, model = (
    AutoImageProcessor.from_pretrained(model_name),
    AutoModelForImageClassification.from_pretrained(model_name),
)

print("Model loaded successfully!")
@app.get("/")
async def root():
    """Report that the service is up and tell clients how to request a prediction."""
    status_message = (
        "Skin Disease Classifier API is running. "
        "Send a POST request with an image to /predict."
    )
    return {"message": status_message}
_logger = logging.getLogger(__name__)


def _load_rgb_image(contents: bytes) -> Image.Image:
    """Decode raw upload bytes into an RGB Pillow image.

    Raises:
        OSError: if Pillow cannot identify or fully decode the data
            (``PIL.UnidentifiedImageError`` is a subclass of ``OSError``).
        PIL.Image.DecompressionBombError: if the image exceeds Pillow's pixel limit.
    """
    with Image.open(io.BytesIO(contents)) as img:
        # convert() returns a new, fully loaded image, so closing the source afterwards is safe.
        return img.convert("RGB")


def _classify(image: Image.Image) -> dict:
    """Run the classifier on one RGB image.

    Returns:
        A dict with ``prediction`` (the top label), ``confidence`` (its softmax
        probability) and ``details`` (every label's probability, highest first).
    """
    # Preprocess the image into model-ready tensors.
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    labels = model.config.id2label
    # A single image was passed, so argmax over the last axis yields one index.
    predicted_class_idx = logits.argmax(-1).item()
    # Row 0 is the probability distribution for that single image.
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
    all_probs = {labels[i]: float(prob) for i, prob in enumerate(probabilities)}
    return {
        "prediction": labels[predicted_class_idx],
        "confidence": float(probabilities[predicted_class_idx]),
        "details": dict(sorted(all_probs.items(), key=lambda item: item[1], reverse=True)),
    }


@app.post("/predict")
async def predict_skin_disease(file: UploadFile = File(...)):
    """Classify an uploaded skin image and return the model's ranked predictions.

    Args:
        file: The multipart-uploaded image; its declared content type must start
            with ``image/``.

    Returns:
        A JSONResponse. On success (200) the body is
        ``{"success": True, "prediction": <label>, "confidence": <float>,
        "details": {<label>: <probability>, ...}}`` with ``details`` ordered from
        most to least likely. A missing or non-image content type, or data that
        cannot be decoded as an image, yields 400; any other failure yields 500.
        Both error cases use the body ``{"success": False, "error": <message>}``.
    """
    # content_type may be None when the client sends no Content-Type for the part; treat that
    # as "not an image" rather than raising AttributeError outside the error handling below.
    if not (file.content_type or "").startswith("image/"):
        return JSONResponse(content={"success": False, "error": "File must be an image"}, status_code=400)
    try:
        contents = await file.read()
        try:
            # Decoding is CPU-bound; keep it off the event loop.
            image = await run_in_threadpool(_load_rgb_image, contents)
        except (OSError, Image.DecompressionBombError) as exc:
            # Undecodable, truncated or oversized data is a client error, not a server fault.
            return JSONResponse(content={"success": False, "error": str(exc)}, status_code=400)
        # The forward pass is CPU-bound too; running it in a worker thread keeps this async
        # endpoint from blocking every other request for the duration of inference.
        result = await run_in_threadpool(_classify, image)
    except Exception as e:
        _logger.exception("Prediction failed")
        return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
    return JSONResponse(content={"success": True, **result})
if __name__ == "__main__":
    # uvicorn is only needed when this file is executed directly as a script.
    from uvicorn import run as serve_app

    # Listen on all interfaces at 7860, the default port on Hugging Face.
    serve_app(app, host="0.0.0.0", port=7860)