File size: 1,933 Bytes
681547b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import io
import logging

import fitz  # PyMuPDF
import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from transformers import AutoModelForImageClassification, AutoProcessor

app = FastAPI()

# CORS policy for the browser frontend. "*" admits every origin; narrow
# "allow_origins" to the frontend's domain(s) for a real deployment.
# NOTE(review): the FastAPI CORS docs say wildcard values are not meant to be
# combined with allow_credentials=True. This endpoint reads no cookies or auth
# headers, so confirm whether the frontend actually sends credentials before
# keeping it enabled.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)

# Hugging Face checkpoint backing the /classify endpoint. The model and its
# matching processor are loaded once, when this module is imported, and are
# shared by every request.
model_name = "AsmaaElnagger/Diabetic_RetinoPathy_detection"
model, processor = (
    AutoModelForImageClassification.from_pretrained(model_name),
    AutoProcessor.from_pretrained(model_name),
)

# Convert PDF to images
def pdf_to_images(pdf_data, max_pages=None):
    """Render the pages of a PDF to JPEG-encoded bytes.

    Args:
        pdf_data: Raw PDF file contents.
        max_pages: If given, render at most this many pages, starting from
            the first page. ``None`` (the default) renders every page.

    Returns:
        A list with one JPEG ``bytes`` object per rendered page, in page
        order. It is empty when the document has no pages or ``max_pages``
        is 0.
    """
    # The context manager closes the document even if rendering a page raises.
    # The previous version never closed it.
    with fitz.open(stream=pdf_data, filetype="pdf") as pdf_document:
        page_count = pdf_document.page_count
        if max_pages is not None:
            page_count = min(page_count, max(max_pages, 0))
        return [
            pdf_document.load_page(page_num).get_pixmap().tobytes("jpeg")
            for page_num in range(page_count)
        ]

_logger = logging.getLogger(__name__)

# Upload extensions decoded directly with Pillow. "pdf" is handled separately.
_IMAGE_EXTENSIONS = frozenset({"jpg", "jpeg", "png", "gif"})


def _file_extension(filename):
    """Return the lower-cased text after the last '.' in *filename*.

    Returns '' when *filename* is None, empty, or has no '.'. Such uploads
    then get the "Unsupported file type" response instead of crashing.
    """
    _, dot, extension = (filename or "").rpartition(".")
    return extension.lower() if dot else ""


def _decode_image(file_type, file_data):
    """Decode uploaded bytes into an RGB PIL image.

    For PDFs only the first rendered page is used. Returns None when the PDF
    renders no pages. Decoding errors from Pillow or PyMuPDF propagate to
    the caller.
    """
    if file_type == "pdf":
        pages = pdf_to_images(file_data)
        if not pages:
            return None
        file_data = pages[0]  # only the first page is classified
    with Image.open(io.BytesIO(file_data)) as img:
        return img.convert("RGB")


def _predict_label(image):
    """Run the classifier on *image* and return its top-scoring label."""
    inputs = processor(images=image, return_tensors="pt")
    # Inference only: disabling autograd avoids building a gradient graph
    # for every request.
    with torch.inference_mode():
        logits = model(**inputs).logits
    return model.config.id2label[logits.argmax(-1).item()]


@app.post("/classify")
async def classify(file: UploadFile = File(...)):
    """Classify an uploaded image (jpg/jpeg/png/gif) or the first page of a PDF.

    Returns:
        ``{"result": label}`` on success, otherwise ``{"error": message}``.
        Failures are reported in the response body rather than raised as
        HTTP errors, which keeps the original response contract.
    """
    file_type = _file_extension(file.filename)
    if file_type != "pdf" and file_type not in _IMAGE_EXTENSIONS:
        return {"error": "Unsupported file type"}

    file_data = await file.read()
    try:
        # Decoding and model inference are synchronous. Running them in the
        # threadpool keeps the event loop free to serve other requests.
        image = await run_in_threadpool(_decode_image, file_type, file_data)
        if image is None:
            return {"error": "PDF conversion failed"}
        label = await run_in_threadpool(_predict_label, image)
    except Exception as e:
        # Boundary handler: keep the {"error": ...} body callers receive, but
        # record the traceback in the server log instead of discarding it.
        _logger.exception("Classification failed for upload %r", file.filename)
        return {"error": str(e)}
    return {"result": label}