# Extracted from a Hugging Face Space file view (revision 7b84d44, file size 1,484 bytes).
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from transformers import ViTForImageClassification, ViTImageProcessor
import torch
from PIL import Image
import io
app = FastAPI()

# CORS origins are compared against the browser's ``Origin`` header, which is
# scheme + host (+ port) only: it never carries a path or a trailing slash.
# The GitHub Pages frontend served under /Pokedex therefore sends
# "https://cosmicsheep42.github.io"; entries that include the path can never
# match and would leave the frontend blocked by CORS.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:5500",
        "https://cosmicsheep42.github.io",
        "https://cosmicsheep42-backend.hf.space",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Upper bound, in bytes, on an uploaded file accepted by /classify.
MAX_FILE_SIZE = 500 * 1024 * 1024  # 500MB

# The classifier and its image processor are loaded once at import time so
# that every request reuses the same objects.
model_id = "skshmjn/Pokemon-classifier-gen9-1025"
model = ViTForImageClassification.from_pretrained(model_id)
processor = ViTImageProcessor.from_pretrained(model_id)
@app.post("/classify")
async def classify(file: UploadFile = File(...)):
    """Classify an uploaded image and return the predicted label.

    Args:
        file: The multipart upload containing the image to classify.

    Returns:
        A dict ``{"pokemonName": <label>}`` where the label is looked up in
        ``model.config.id2label`` for the highest-scoring class.

    Raises:
        HTTPException: 413 if the upload exceeds ``MAX_FILE_SIZE``;
            400 if the payload cannot be decoded as an image.
    """
    # Reject early when the declared size is known, before reading the body.
    if file.size is not None and file.size > MAX_FILE_SIZE:
        raise HTTPException(status_code=413, detail="File too large. Maximum size is 500MB.")

    contents = await file.read()
    # The declared size may be missing, so re-check what was actually read.
    if len(contents) > MAX_FILE_SIZE:
        raise HTTPException(status_code=413, detail="File too large. Maximum size is 500MB.")

    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        # Pillow raises UnidentifiedImageError (an OSError) for empty or
        # non-image payloads; report a client error instead of a bare 500.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from exc

    inputs = processor(images=image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    predicted_id = outputs.logits.argmax(-1).item()
    pokemon_name = model.config.id2label[predicted_id]
    return {"pokemonName": pokemon_name}
|