# class / main.py — author: sheikh987
# Last commit 5f2ae70 "Update main.py" (verified); file size 1.44 kB.
import io

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Load the image processor and classifier once, at import time, so every
# request reuses the same objects. `from_pretrained` fetches the files for
# `model_id` (the start message below announces this download).
# flush=True: when stdout is a pipe (e.g. container logs) it is block-buffered,
# and these progress messages would otherwise not appear until much later.
print("🚀 Starting model download...", flush=True)
model_id = "sheikh987/Skin_Cancer-Image_Classification"
processor = AutoImageProcessor.from_pretrained(model_id)
model = AutoModelForImageClassification.from_pretrained(model_id)
print("✅ Model loaded successfully.", flush=True)

# FastAPI application exposing the endpoints defined below.
app = FastAPI(title="Skin Cancer Classifier API")
# Health check endpoint
@app.get("/status")
def status():
    """Health check: report "ok" together with the id of the served model."""
    return dict(status="ok", model=model_id)
def _decode_image(image_bytes: bytes) -> Image.Image:
    """Decode raw uploaded bytes into an RGB PIL image.

    Raises:
        HTTPException: 400 if Pillow cannot open or convert the bytes.
    """
    try:
        return Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except Exception as exc:
        # The bytes come straight from the client, so any decoding failure is
        # reported as a client error; chain the cause so it is not lost.
        raise HTTPException(status_code=400, detail="Could not decode image") from exc


def _classify(image: Image.Image) -> dict:
    """Classify one RGB image with the module-level `processor` and `model`.

    Returns:
        {"label": class name from model.config.id2label,
         "confidence": softmax probability of that class, rounded to 4 places}
    """
    inputs = processor(images=image, return_tensors="pt")
    # Grad mode is thread-local, so no_grad must be entered on the worker
    # thread that actually runs the forward pass.
    with torch.no_grad():
        logits = model(**inputs).logits
    idx = logits.argmax(-1).item()
    # A single image is passed in, so row 0 holds its class scores.
    confidence = torch.nn.functional.softmax(logits, dim=-1)[0, idx].item()
    return {"label": model.config.id2label[idx], "confidence": round(confidence, 4)}


# Image classification endpoint
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return its top-1 label and confidence.

    Responds with HTTP 400 when the upload is not declared as ``image/*`` or
    when its bytes cannot be decoded as an image.
    """
    # UploadFile.content_type is Optional (absent Content-Type header); treat a
    # missing type as "not an image" instead of failing with AttributeError.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Invalid image file")
    image_bytes = await file.read()
    # Decoding and inference are synchronous CPU work; running them in the
    # threadpool keeps this async handler from blocking the event loop.
    image = await run_in_threadpool(_decode_image, image_bytes)
    return await run_in_threadpool(_classify, image)