# KVASIR / app.py
# Author: 3v324v23
# Commit d29727d: "Add application file"
import io

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

# Model identifier handed to ``from_pretrained`` (a Hugging Face Hub repo id).
model_name = "mmuratarat/kvasir-v2-classifier"

# Load the classifier and its matching preprocessor once, at import time, so
# every request reuses the same instances instead of reloading weights.
model, feature_extractor = (
    AutoModelForImageClassification.from_pretrained(model_name),
    AutoFeatureExtractor.from_pretrained(model_name),
)

app = FastAPI()
# Class ID to Label mapping.
# Keys are the classifier's output indices as *strings* (lookups elsewhere use
# ``str(index)``); values are the Kvasir-v2 class names.
# NOTE(review): this order must match the order the checkpoint was trained
# with -- compare against ``model.config.id2label`` if the model changes.
id2label = {
    str(index): label
    for index, label in enumerate(
        (
            'dyed-lifted-polyps',
            'dyed-resection-margins',
            'esophagitis',
            'normal-cecum',
            'normal-pylorus',
            'normal-z-line',
            'polyps',
            'ulcerative-colitis',
        )
    )
}
# Classes whose prediction is reported as "Polyp Present"; every other class
# is reported as "Polyp Absent".
# NOTE(review): treating 'dyed-resection-margins' as polyp-positive is a
# labeling choice -- confirm it with the clinical owner of this service.
_POLYP_POSITIVE_CLASSES = frozenset(
    {'dyed-lifted-polyps', 'dyed-resection-margins', 'polyps'}
)

# Mapping for polyp presence: class name -> human-readable polyp status.
polyp_mapping = {
    class_name: (
        "Polyp Present" if class_name in _POLYP_POSITIVE_CLASSES else "Polyp Absent"
    )
    for class_name in (
        'dyed-lifted-polyps',
        'dyed-resection-margins',
        'polyps',
        'ulcerative-colitis',  # UC does not mean polyps
        'esophagitis',
        'normal-cecum',
        'normal-pylorus',
        'normal-z-line',
    )
}
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and report whether a polyp is present.

    The upload is decoded with Pillow and converted to RGB. It is then
    preprocessed by ``feature_extractor`` and passed through ``model``. The
    arg-max class index is mapped to a class name via ``id2label``, and that
    name is mapped to a polyp status via ``polyp_mapping``.

    Args:
        file: The uploaded image, in any format Pillow can open.

    Returns:
        A dict with ``"predicted_class"`` (the class name) and
        ``"polyp_status"`` (the value from ``polyp_mapping``).

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    image_bytes = await file.read()
    try:
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError (non-image or empty upload) and truncated-file
        # errors are OSError subclasses; report bad input as a client error
        # instead of letting it surface as a 500.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err

    inputs = feature_extractor(image, return_tensors="pt")
    # Inference only: disable autograd so no graph is built per request.
    # NOTE(review): this forward pass runs on the event-loop thread and blocks
    # other requests while it executes; consider offloading it to a worker.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Single image -> one arg-max index; id2label is keyed by string indices.
    predicted_label = logits.argmax(-1).item()
    predicted_class = id2label[str(predicted_label)]
    polyp_status = polyp_mapping[predicted_class]
    return {
        "predicted_class": predicted_class,
        "polyp_status": polyp_status,
    }