# classifierskins / main.py
# Author: sheikh987 — last change "Update main.py" (commit 43f775b, 1.19 kB).
import io

import torch
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

model_id = "sheikh987/Skin_Cancer-Image_Classification"

# The preprocessing pipeline and the classifier weights are both pulled from
# the same Hugging Face Hub repository, once, when this module is imported.
processor, model = (
    AutoImageProcessor.from_pretrained(model_id),
    AutoModelForImageClassification.from_pretrained(model_id),
)

app = FastAPI(title="Skin Cancer Classifier API")
def _classify(image):
    """Run the module-level classifier on one RGB PIL image.

    Returns a ``(label, confidence)`` pair, where ``label`` comes from
    ``model.config.id2label`` for the arg-max class and ``confidence`` is that
    class's softmax probability. Synchronous and compute-bound, so async
    callers should run it off the event loop.
    """
    inputs = processor(images=image, return_tensors="pt")
    # torch.no_grad() is thread-local: it must be entered in the thread that
    # executes the forward pass (i.e. here), not around the awaiting coroutine.
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    idx = logits.argmax(-1).item()
    label = model.config.id2label[idx]
    # Exactly one image was submitted, so row 0 holds its class scores.
    confidence = torch.nn.functional.softmax(logits, dim=-1)[0][idx].item()
    return label, confidence


@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image with the skin-cancer model.

    Returns a JSON object ``{"label": str, "confidence": float}``, with
    ``confidence`` rounded to 4 decimal places.

    Raises:
        HTTPException(400): the upload's content type is missing or not
            ``image/*``, or its bytes cannot be decoded as an image.
    """
    # content_type is None when the client omits the header. Media types are
    # case-insensitive, so normalise before the prefix check.
    content_type = (file.content_type or "").lower()
    if not content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="Invalid image file")
    try:
        image_bytes = await file.read()
        # convert() forces PIL's lazy decoder to parse the data now, so corrupt
        # payloads fail inside this try block rather than during inference.
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Could not decode image") from exc
    # Model inference blocks; run it in the worker threadpool so the event loop
    # keeps serving other requests meanwhile.
    label, confidence = await run_in_threadpool(_classify, image)
    return {"label": label, "confidence": round(confidence, 4)}