# File size: 1,682 Bytes
# d29727d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import io

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

app = FastAPI()

# Load model and feature extractor
# NOTE(review): this runs at import time, so server startup blocks until the
# checkpoint is downloaded/loaded from the HuggingFace hub cache. If startup
# latency becomes an issue, move this into a FastAPI startup event.
model_name = "mmuratarat/kvasir-v2-classifier"
model = AutoModelForImageClassification.from_pretrained(model_name)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)

# Class id (stringified, to match the str() lookup in the predict endpoint)
# -> Kvasir-v2 label. List order encodes the model's output class indices.
id2label = {
    str(class_id): label
    for class_id, label in enumerate([
        'dyed-lifted-polyps',
        'dyed-resection-margins',
        'esophagitis',
        'normal-cecum',
        'normal-pylorus',
        'normal-z-line',
        'polyps',
        'ulcerative-colitis',
    ])
}

# Label -> coarse polyp-presence verdict. Only the polyp/dye-marked classes
# count as present; ulcerative colitis does not imply polyps.
_POLYP_CLASSES = frozenset({'dyed-lifted-polyps', 'dyed-resection-margins', 'polyps'})
polyp_mapping = {
    label: ("Polyp Present" if label in _POLYP_CLASSES else "Polyp Absent")
    for label in id2label.values()
}

@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded endoscopy image and report polyp presence.

    Runs the Kvasir-v2 classifier on a multipart-uploaded image.

    Returns:
        dict with "predicted_class" (a Kvasir-v2 label from ``id2label``) and
        "polyp_status" ("Polyp Present" / "Polyp Absent").

    Raises:
        HTTPException: 400 if the uploaded bytes cannot be decoded as an image.
    """
    image_bytes = await file.read()
    try:
        # convert("RGB") normalizes palette/grayscale/alpha inputs to 3 channels.
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except (OSError, ValueError) as exc:
        # PIL raises UnidentifiedImageError (an OSError) on undecodable data;
        # a bad upload is a client error, not a 500.
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image.") from exc

    # Process the image
    inputs = feature_extractor(image, return_tensors="pt")
    # Inference only: disable autograd so no gradient graph is built per request.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Get predicted class id and map it through the string-keyed label table.
    predicted_label = logits.argmax(-1).item()
    predicted_class = id2label[str(predicted_label)]

    # Determine polyp presence
    polyp_status = polyp_mapping[predicted_class]

    return {
        "predicted_class": predicted_class,
        "polyp_status": polyp_status,
    }