"""FastAPI backend that classifies an uploaded image as a Pokémon.

Exposes a single ``POST /classify`` endpoint backed by the Hugging Face ViT
checkpoint ``skshmjn/Pokemon-classifier-gen9-1025``.
"""

import io

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor

app = FastAPI()

# A browser's ``Origin`` header is only scheme://host[:port] -- it never
# carries a path -- and CORSMiddleware matches it against this list exactly.
# The GitHub Pages frontend served under https://cosmicsheep42.github.io/Pokedex
# therefore sends ``Origin: https://cosmicsheep42.github.io``; entries that
# include "/Pokedex" could never match.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:5500",
        "https://cosmicsheep42.github.io",
        "https://cosmicsheep42-backend.hf.space",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

MAX_FILE_SIZE = 500 * 1024 * 1024  # 500MB

model_id = "skshmjn/Pokemon-classifier-gen9-1025"
# Loaded once at import time and shared by every request.
model = ViTForImageClassification.from_pretrained(model_id)
processor = ViTImageProcessor.from_pretrained(model_id)


def _decode_image(contents: bytes) -> Image.Image:
    """Decode raw upload bytes into an RGB PIL image.

    Args:
        contents: The raw bytes of the uploaded file.

    Returns:
        The decoded image, converted to RGB.

    Raises:
        HTTPException: 400 if the bytes cannot be decoded as an image.
    """
    try:
        # convert() forces a full decode, so truncated files also fail here
        # rather than later inside the processor.
        return Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError is an OSError subclass. Without this handler
        # a bad upload would surface to the client as an opaque 500.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err


def _predict_label(image: Image.Image) -> str:
    """Run the classifier on *image* and return the top-1 class label."""
    inputs = processor(images=image, return_tensors="pt")
    # Inference only: skip autograd graph construction to save memory/time.
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_id = outputs.logits.argmax(-1).item()
    return model.config.id2label[predicted_id]


@app.post("/classify")
async def classify(file: UploadFile = File(...)):
    """Classify an uploaded image and return the predicted Pokémon name.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        ``{"pokemonName": <label>}`` with the model's top-1 prediction.

    Raises:
        HTTPException: 413 if the upload exceeds ``MAX_FILE_SIZE``;
            400 if it cannot be decoded as an image.
    """
    # Fast path: reject early when the client reported the size.
    if file.size and file.size > MAX_FILE_SIZE:
        raise HTTPException(status_code=413, detail="File too large. Maximum size is 500MB.")
    # Read at most one byte past the limit, so an upload whose size was not
    # reported can never be pulled into memory in full.
    contents = await file.read(MAX_FILE_SIZE + 1)
    if len(contents) > MAX_FILE_SIZE:
        raise HTTPException(status_code=413, detail="File too large. Maximum size is 500MB.")
    # Decoding and the forward pass are synchronous and CPU-bound. Run them
    # in the threadpool so they don't stall the event loop for other requests.
    image = await run_in_threadpool(_decode_image, contents)
    pokemon_name = await run_in_threadpool(_predict_label, image)
    return {"pokemonName": pokemon_name}