| | from fastapi import FastAPI, File, UploadFile, Request |
| | from fastapi.responses import HTMLResponse, JSONResponse |
| | from fastapi.staticfiles import StaticFiles |
| | from fastapi.templating import Jinja2Templates |
| | from PIL import Image |
| | import torch |
| | from transformers import AutoImageProcessor, AutoModelForImageClassification |
| | import io |
| |
|
# Identifier handed to both ``from_pretrained`` calls below, so the image
# processor and the classifier always come from the same checkpoint.
_MODEL_ID = "aashituli/promblemo"

# Application object that the route decorators in this module register on.
app = FastAPI()

# Built once at import time; the request handlers reuse these module-level
# objects instead of reloading them per request.
processor = AutoImageProcessor.from_pretrained(_MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(_MODEL_ID)

# Expose the "static" directory under the /static URL prefix and look up
# HTML templates in the "templates" directory.
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
| |
|
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Render the landing page.

    Args:
        request: The incoming request; it is exposed to the template under
            the ``"request"`` context key.

    Returns:
        The response produced by rendering ``index.html``.
    """
    context = {"request": request}
    page = templates.TemplateResponse("index.html", context)
    return page
| |
|
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the predicted label as JSON.

    Args:
        file: The uploaded file whose bytes are decoded as an image.

    Returns:
        A ``JSONResponse`` that is one of:

        * ``{"prediction": <label>}`` (HTTP 200) on success, where the label
          is looked up in ``model.config.id2label``;
        * ``{"error": <message>}`` with HTTP 400 when the upload cannot be
          decoded as an image (empty, truncated, unsupported or oversized);
        * ``{"error": <message>}`` with HTTP 500 for any other failure.
    """
    try:
        contents = await file.read()

        # A payload that is not a readable image is the client's mistake, so
        # report it as 400 instead of letting it fall through to the generic
        # 500 handler below.
        try:
            # ``convert`` returns an independent image, so the handle opened
            # on the in-memory buffer can be closed as soon as it is done.
            with Image.open(io.BytesIO(contents)) as uploaded:
                image = uploaded.convert("RGB")
        except (OSError, ValueError, Image.DecompressionBombError):
            return JSONResponse(
                content={"error": "Uploaded file is not a valid image"},
                status_code=400,
            )

        inputs = processor(images=image, return_tensors="pt")

        # Inference only: no gradients are needed.
        # NOTE(review): this forward pass is a synchronous call inside an
        # ``async def`` handler; consider offloading it to a worker thread if
        # concurrent requests matter.
        with torch.no_grad():
            outputs = model(**inputs)

        # Index of the highest logit, mapped to its label.
        predicted_class_idx = outputs.logits.argmax(-1).item()
        predicted_class = model.config.id2label[predicted_class_idx]

        return JSONResponse(content={"prediction": predicted_class})

    except Exception as e:
        # Top-level boundary: keep the original contract of answering with a
        # JSON error body rather than an unhandled exception.
        return JSONResponse(content={"error": str(e)}, status_code=500)
| |
|