# Overreactingwallflower's picture
# Create app.py
# 83cbcad verified
import base64
import io
from typing import Optional

import aiohttp
import torch
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel
from transformers import AutoImageProcessor, AutoModelForImageClassification
# FastAPI application object; the route decorators further down register
# their handlers on it.
app = FastAPI()

# Model identifier handed to both ``from_pretrained`` calls below and echoed
# back by the /health endpoint.
MODEL_NAME = "skshmjn/Pokemon-classifier-gen9-1025"

# The processor and model are created once, at import time, so every request
# reuses the same objects instead of reloading them per call.
# NOTE(review): presumably the first run fetches the weights over the network
# -- confirm the deployment has network access or a pre-populated cache.
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
class ClassifyRequest(BaseModel):
    """Request body for ``POST /classify``.

    Both fields may be omitted or sent as ``null``. The handler uses
    ``image_url`` when it is set and otherwise falls back to ``image_data``;
    if neither is set it answers with HTTP 400.
    """

    # URL the image is downloaded from.
    image_url: Optional[str] = None
    # Image bytes encoded as a base64 string.
    image_data: Optional[str] = None  # base64
@app.get("/health")
async def health():
    """Report that the service is up and which model identifier it serves."""
    payload = dict(status="ok", model=MODEL_NAME)
    return payload
async def _load_image(request: ClassifyRequest) -> Image.Image:
    """Fetch or decode the image referenced by *request* and return it as RGB.

    Args:
        request: Parsed request body; ``image_url`` takes precedence over
            ``image_data`` when both are set.

    Returns:
        The decoded image, converted to RGB mode.

    Raises:
        HTTPException: 400 when neither ``image_url`` nor ``image_data`` is set.
    """
    if request.image_url:
        # NOTE(review): the URL is caller-supplied and fetched from the server
        # (SSRF risk). Consider restricting allowed hosts and adding a timeout
        # and a size limit before exposing this publicly.
        async with aiohttp.ClientSession() as session:
            async with session.get(request.image_url) as resp:
                # Fail on 4xx/5xx instead of passing an error page to PIL.
                resp.raise_for_status()
                image_bytes = await resp.read()
    elif request.image_data:
        image_bytes = base64.b64decode(request.image_data)
    else:
        raise HTTPException(400, "No image provided")

    # Normalise RGBA / palette / greyscale input to three channels before it
    # reaches the image processor.
    return Image.open(io.BytesIO(image_bytes)).convert("RGB")


@app.post("/classify")
async def classify(request: ClassifyRequest):
    """Classify the submitted image and return the top prediction.

    Args:
        request: Body carrying either ``image_url`` or base64 ``image_data``.

    Returns:
        A dict with ``"name"`` (the predicted label, lower-cased) and
        ``"confidence"`` (the softmax probability of that label).

    Raises:
        HTTPException: 400 when no image is provided; 500 for any other
            failure (download, decoding, or inference), with the underlying
            error message as the detail.
    """
    try:
        image = await _load_image(request)

        # Preprocess
        inputs = processor(images=image, return_tensors="pt")

        # Predict; no gradients are needed for inference.
        with torch.no_grad():
            logits = model(**inputs).logits
            probs = torch.softmax(logits, dim=1)
            top_prob, top_idx = torch.max(probs, dim=1)

        # Map the winning class index to its label.
        pokemon_name = model.config.id2label[top_idx.item()]
        return {
            "name": pokemon_name.lower(),
            "confidence": top_prob.item(),
        }
    except HTTPException:
        # Keep deliberate HTTP errors (e.g. the 400 above) from being
        # rewritten as a 500 by the generic handler below.
        raise
    except Exception as e:
        raise HTTPException(500, str(e)) from e