Overreactingwallflower commited on
Commit
83cbcad
·
verified ·
1 Parent(s): c40464f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -0
app.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ import torch
4
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
5
+ from PIL import Image
6
+ import aiohttp
7
+ import io
8
+ import base64
9
+
10
+ app = FastAPI()
11
+
12
+ # Load model (skshmjn/Pokemon-classifier-gen9-1025)
13
+ MODEL_NAME = "skshmjn/Pokemon-classifier-gen9-1025"
14
+ processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
15
+ model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
16
+
17
+ class ClassifyRequest(BaseModel):
18
+ image_url: str = None
19
+ image_data: str = None # base64
20
+
21
+ @app.get("/health")
22
+ async def health():
23
+ return {"status": "ok", "model": MODEL_NAME}
24
+
25
+ @app.post("/classify")
26
+ async def classify(request: ClassifyRequest):
27
+ try:
28
+ # Load image
29
+ if request.image_url:
30
+ async with aiohttp.ClientSession() as session:
31
+ async with session.get(request.image_url) as resp:
32
+ image_bytes = await resp.read()
33
+ image = Image.open(io.BytesIO(image_bytes))
34
+ elif request.image_data:
35
+ image_bytes = base64.b64decode(request.image_data)
36
+ image = Image.open(io.BytesIO(image_bytes))
37
+ else:
38
+ raise HTTPException(400, "No image provided")
39
+
40
+ # Preprocess
41
+ inputs = processor(images=image, return_tensors="pt")
42
+
43
+ # Predict
44
+ with torch.no_grad():
45
+ outputs = model(**inputs)
46
+ logits = outputs.logits
47
+ probs = torch.softmax(logits, dim=1)
48
+ top_prob, top_idx = torch.max(probs, dim=1)
49
+
50
+ # Get Pokemon name
51
+ pokemon_name = model.config.id2label[top_idx.item()]
52
+ confidence = top_prob.item()
53
+
54
+ return {
55
+ "name": pokemon_name.lower(),
56
+ "confidence": confidence
57
+ }
58
+
59
+ except Exception as e:
60
+ raise HTTPException(500, str(e))