|
|
from fastapi import FastAPI, UploadFile, File |
|
|
from transformers import AutoModelForImageClassification, AutoFeatureExtractor |
|
|
from PIL import Image |
|
|
import torch |
|
|
import io |
|
|
|
|
|
# FastAPI application instance; the prediction route is registered on it.
app = FastAPI()

# Identifier handed to both from_pretrained() calls below (presumably a
# Hugging Face Hub repo id or a locally cached copy of it).
model_name = "mmuratarat/kvasir-v2-classifier"

# Classifier weights and matching preprocessing config, loaded once at import
# time and shared by every request (tuple elements are evaluated left to
# right, so the model is still loaded before the feature extractor).
model, feature_extractor = (
    AutoModelForImageClassification.from_pretrained(model_name),
    AutoFeatureExtractor.from_pretrained(model_name),
)
|
|
|
|
|
|
|
|
# Maps the classifier's output index, as a *string* key ('0'..'7'), to the
# class name. NOTE(review): assumed to agree with model.config.id2label for
# the loaded checkpoint — verify if the checkpoint is ever changed.
id2label = {
    str(index): label
    for index, label in enumerate(
        (
            'dyed-lifted-polyps',
            'dyed-resection-margins',
            'esophagitis',
            'normal-cecum',
            'normal-pylorus',
            'normal-z-line',
            'polyps',
            'ulcerative-colitis',
        )
    )
}
|
|
|
|
|
|
|
|
# Collapses each class name from the classifier into a binary polyp status.
# NOTE(review): 'dyed-resection-margins' is counted as "Polyp Present";
# confirm that is the intended clinical interpretation.
polyp_mapping = {
    **dict.fromkeys(
        ('dyed-lifted-polyps', 'dyed-resection-margins', 'polyps'),
        "Polyp Present",
    ),
    **dict.fromkeys(
        (
            'ulcerative-colitis',
            'esophagitis',
            'normal-cecum',
            'normal-pylorus',
            'normal-z-line',
        ),
        "Polyp Absent",
    ),
}
|
|
|
|
|
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and report whether a polyp class was predicted.

    The upload is decoded with Pillow and converted to RGB. It is then
    preprocessed with ``feature_extractor`` and scored by ``model``. The
    arg-max class index is translated to a class name via ``id2label`` and
    then to a polyp status via ``polyp_mapping``.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        dict: ``{"predicted_class": <class name>, "polyp_status": <status>}``.

    Raises:
        HTTPException: 400 if the upload is empty or cannot be decoded as an
            image.
    """
    # Imported locally so this endpoint stays self-contained; both names ship
    # with FastAPI itself.
    from fastapi import HTTPException
    from fastapi.concurrency import run_in_threadpool

    image_bytes = await file.read()
    if not image_bytes:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    try:
        # .convert() forces Pillow to decode the pixel data, so truncated or
        # corrupt files fail here instead of inside the model call.
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as err:
        # PIL.UnidentifiedImageError is a subclass of OSError; ValueError
        # covers unsupported mode conversions.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err

    # Preprocessing and the forward pass are synchronous, compute-bound work.
    # Running them in a worker thread keeps the event loop free to serve
    # other requests in the meantime.
    predicted_label = await run_in_threadpool(_predict_label_index, image)

    # NOTE(review): id2label / polyp_mapping are assumed to cover every index
    # the model can emit; an unknown index raises KeyError (HTTP 500).
    predicted_class = id2label[str(predicted_label)]
    polyp_status = polyp_mapping[predicted_class]

    return {
        "predicted_class": predicted_class,
        "polyp_status": polyp_status,
    }


def _predict_label_index(image):
    """Run the classifier on a single RGB PIL image and return the arg-max index."""
    # Autograd is not needed for inference; disabling it saves memory and time.
    # Grad mode is thread-local in PyTorch, so it is entered here, in the
    # thread that actually runs the forward pass.
    with torch.no_grad():
        inputs = feature_extractor(image, return_tensors="pt")
        logits = model(**inputs).logits
    return logits.argmax(-1).item()
|
|
|
|
|
|