import base64
import io
from typing import Optional

import aiohttp
import torch
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel
from transformers import AutoImageProcessor, AutoModelForImageClassification
|
|
# Identifier passed to ``from_pretrained`` for both the image processor and
# the classification model.
MODEL_NAME = "skshmjn/Pokemon-classifier-gen9-1025"

# ASGI application object; the route decorators below attach to it.
app = FastAPI()

# Loaded once at import time so every request reuses the same instances
# instead of re-loading them per call.
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
|
|
class ClassifyRequest(BaseModel):
    """Request body for ``POST /classify``.

    Both fields may be omitted or sent as ``null``. They are typed
    ``Optional[str]`` because a plain ``str`` annotation with a ``None``
    default makes Pydantic v2 reject an explicit JSON ``null``.
    """

    # URL the image is downloaded from.
    image_url: Optional[str] = None
    # Base64-encoded image bytes.
    image_data: Optional[str] = None
|
|
@app.get("/health")
async def health():
    """Return a static liveness payload that also names the served model."""
    payload = dict(status="ok", model=MODEL_NAME)
    return payload
|
|
@app.post("/classify")
async def classify(request: ClassifyRequest):
    """Classify the image referenced by *request* and return the top label.

    The image is downloaded from ``request.image_url`` when that field is
    set; otherwise it is base64-decoded from ``request.image_data``.

    Returns:
        A dict with ``"name"`` (the lower-cased entry of
        ``model.config.id2label`` for the highest-probability class) and
        ``"confidence"`` (that class's softmax probability as a float).

    Raises:
        HTTPException: 400 when neither ``image_url`` nor ``image_data`` is
            provided; 500 carrying the error text for any other failure
            (download, decoding, or inference).
    """
    try:
        if request.image_url:
            # NOTE(security): the URL is caller-supplied and fetched from
            # the server side, so it can target internal hosts (SSRF).
            # Restrict allowed hosts/schemes before exposing this publicly.
            async with aiohttp.ClientSession() as session:
                async with session.get(request.image_url) as resp:
                    # Fail on 4xx/5xx instead of handing an error page to PIL.
                    resp.raise_for_status()
                    image_bytes = await resp.read()
        elif request.image_data:
            image_bytes = base64.b64decode(request.image_data)
        else:
            raise HTTPException(400, "No image provided")

        # Force 3-channel RGB so RGBA / palette / greyscale inputs do not
        # reach the processor with an unexpected channel count.
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        inputs = processor(images=image, return_tensors="pt")

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            logits = model(**inputs).logits
            probs = torch.softmax(logits, dim=1)
            top_prob, top_idx = torch.max(probs, dim=1)

        pokemon_name = model.config.id2label[top_idx.item()]

        return {
            "name": pokemon_name.lower(),
            "confidence": top_prob.item(),
        }
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) must pass through
        # unchanged rather than being re-wrapped as a 500 below.
        raise
    except Exception as e:
        raise HTTPException(500, str(e)) from e