# File size: 1,682 Bytes
import io

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
app = FastAPI()

# The classifier checkpoint (a Hub repo id or local path for from_pretrained).
# The preprocessor and the model are loaded once, at import time, and every
# request handled by this module reuses the same objects.
# NOTE(review): recent transformers releases favour AutoImageProcessor over
# AutoFeatureExtractor for vision models — consider migrating.
model_name = "mmuratarat/kvasir-v2-classifier"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)
# Class names in model-output order: position i in this tuple is the label for
# logit index i.
# NOTE(review): presumably mirrors the checkpoint's config.id2label — verify the
# order against the model card if the checkpoint changes.
_KVASIR_CLASS_NAMES = (
    "dyed-lifted-polyps",
    "dyed-resection-margins",
    "esophagitis",
    "normal-cecum",
    "normal-pylorus",
    "normal-z-line",
    "polyps",
    "ulcerative-colitis",
)

# Class index -> label, keyed by the *string* form of the index because the
# predict endpoint looks labels up with str(index).
id2label = {str(index): name for index, name in enumerate(_KVASIR_CLASS_NAMES)}
# Collapse each fine-grained class into a binary polyp verdict.
# NOTE(review): "dyed-resection-margins" is counted as polyp-present; resection
# margin images may show the site after removal — confirm this is intended.
_POLYP_PRESENT_CLASSES = (
    "dyed-lifted-polyps",
    "dyed-resection-margins",
    "polyps",
)
# Ulcerative colitis is an inflammatory finding and does not imply a polyp.
_POLYP_ABSENT_CLASSES = (
    "ulcerative-colitis",
    "esophagitis",
    "normal-cecum",
    "normal-pylorus",
    "normal-z-line",
)

# Class label -> "Polyp Present" / "Polyp Absent".
polyp_mapping = {
    **dict.fromkeys(_POLYP_PRESENT_CLASSES, "Polyp Present"),
    **dict.fromkeys(_POLYP_ABSENT_CLASSES, "Polyp Absent"),
}
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and report whether a polyp is present.

    The upload is decoded with Pillow and converted to RGB. It is then
    preprocessed by ``feature_extractor`` and scored by ``model``. The arg-max
    class index is translated through ``id2label`` and then ``polyp_mapping``.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        A dict with ``"predicted_class"`` (an ``id2label`` value) and
        ``"polyp_status"`` (the corresponding ``polyp_mapping`` value).

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    image_bytes = await file.read()
    try:
        # Image.open is lazy; convert() forces the full decode, so truncated
        # data also fails inside this try.
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # Empty, non-image, truncated or oversized uploads are client errors,
        # not server faults (PIL.UnidentifiedImageError subclasses OSError).
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err

    inputs = feature_extractor(image, return_tensors="pt")
    # Pure inference: disable autograd tracking so no graph is built per request.
    # NOTE(review): this forward pass is synchronous work inside an async
    # handler and blocks the event loop while it runs; consider offloading it
    # (e.g. fastapi.concurrency.run_in_threadpool) if latency under load matters.
    with torch.inference_mode():
        logits = model(**inputs).logits
        predicted_label = logits.argmax(-1).item()

    # id2label is keyed by the string form of the class index.
    predicted_class = id2label[str(predicted_label)]
    polyp_status = polyp_mapping[predicted_class]
    return {
        "predicted_class": predicted_class,
        "polyp_status": polyp_status,
    }
|