import io

import torch
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
|
|
| |
# Load the preprocessing pipeline and the classifier once, at import time, so
# every request reuses the same objects. `from_pretrained` with a Hub id may
# download the weights if they are not already cached, so importing this
# module can block on network I/O.
# Status messages are plain ASCII: emoji can raise UnicodeEncodeError on
# consoles that use a legacy code page. flush=True makes them show up
# immediately even when stdout is a pipe (e.g. container logs).
print("Starting model download...", flush=True)
model_id = "sheikh987/Skin_Cancer-Image_Classification"
processor = AutoImageProcessor.from_pretrained(model_id)
model = AutoModelForImageClassification.from_pretrained(model_id)
print("Model loaded successfully.", flush=True)

app = FastAPI(title="Skin Cancer Classifier API")
|
|
| |
@app.get("/status")
def status():
    """Health check: confirm the service is up and name the model it serves."""
    payload = {"status": "ok"}
    payload["model"] = model_id
    return payload
|
|
| |
def _decode_image(image_bytes):
    """Decode raw upload bytes into an RGB PIL image; raises if the data is not a readable image."""
    # convert() forces PIL's lazy decode, so corrupt data fails here; the
    # context manager closes the source image once the RGB copy exists.
    with Image.open(io.BytesIO(image_bytes)) as source:
        return source.convert("RGB")


def _classify(image):
    """Run the model on one RGB image and return ``(label, probability)`` for the top class."""
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    # One image in -> batch of one out; row 0 holds this image's class scores.
    idx = logits.argmax(-1).item()
    confidence = torch.nn.functional.softmax(logits, dim=-1)[0][idx].item()
    return model.config.id2label[idx], confidence


@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and report the top label with its probability.

    Args:
        file: Multipart upload whose declared content type must start with
            ``image/``.

    Returns:
        ``{"label": <class name>, "confidence": <softmax probability of that
        class, rounded to 4 decimals>}``.

    Raises:
        HTTPException: 400 when the upload is not declared as an image or
            its bytes cannot be decoded as one.
    """
    # content_type is None when the client sends no Content-Type for the
    # part; treat that as "not an image" instead of crashing with a 500.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Invalid image file")

    image_bytes = await file.read()

    # Decoding and inference are synchronous, CPU-bound work. Running them in
    # the worker threadpool keeps the event loop (and every other request,
    # including /status) responsive during a forward pass.
    try:
        image = await run_in_threadpool(_decode_image, image_bytes)
    except Exception as exc:  # any decode failure means the payload is unusable
        raise HTTPException(status_code=400, detail="Could not decode image") from exc

    label, confidence = await run_in_threadpool(_classify, image)
    return {"label": label, "confidence": round(confidence, 4)}