Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from transformers import ViTForImageClassification, ViTImageProcessor | |
| import torch | |
| from PIL import Image | |
| import io | |
# Application instance that the middleware below is attached to.
app = FastAPI()

# CORS allow-list entries are compared verbatim against the browser's
# `Origin` header, which is always scheme://host[:port] with no path.
# Entries such as "https://cosmicsheep42.github.io/Pokedex" can therefore
# never match, so the GitHub Pages site must be listed by its bare origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:5500",  # local development (port 5500)
        "https://cosmicsheep42.github.io",  # GitHub Pages site hosting /Pokedex
        "https://cosmicsheep42-backend.hf.space",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Largest accepted upload, in bytes: 500 MiB (500 * 1024 * 1024).
MAX_FILE_SIZE = 500 * 2**20
# Hugging Face Hub repository id of the checkpoint this service runs.
model_id = "skshmjn/Pokemon-classifier-gen9-1025"

# Load the image preprocessor and the classifier once, at import time,
# so every request reuses the same module-level objects.
processor = ViTImageProcessor.from_pretrained(model_id)
model = ViTForImageClassification.from_pretrained(model_id)
# NOTE(review): no route decorator precedes this function in this copy, so
# FastAPI would never expose it — confirm the intended path and add e.g.
# `@app.post("/classify")`.
async def classify(file: UploadFile = File(...)):
    """Classify an uploaded image and return the predicted Pokemon name.

    Args:
        file: The uploaded image, received as the multipart form field ``file``.

    Returns:
        ``{"pokemonName": label}``, where ``label`` is
        ``model.config.id2label`` of the highest-scoring logit.

    Raises:
        HTTPException: 413 if the upload exceeds ``MAX_FILE_SIZE`` bytes;
            400 if the bytes cannot be decoded as an image.
    """
    # Reject early when the upload's size is already known.
    if file.size and file.size > MAX_FILE_SIZE:
        raise HTTPException(status_code=413, detail="File too large. Maximum size is 500MB.")

    # Read at most one byte past the limit: enough to detect an oversized
    # upload without pulling the whole body into memory.
    contents = await file.read(MAX_FILE_SIZE + 1)
    if len(contents) > MAX_FILE_SIZE:
        raise HTTPException(status_code=413, detail="File too large. Maximum size is 500MB.")

    # Garbage, truncated, or decompression-bomb images are client errors.
    # Report 400 instead of letting PIL's exception surface as a 500.
    # UnidentifiedImageError and truncation errors are OSError subclasses.
    try:
        with Image.open(io.BytesIO(contents)) as img:
            image = img.convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        raise HTTPException(status_code=400, detail="Invalid image file.") from err

    inputs = processor(images=image, return_tensors="pt")
    # Inference only: don't record an autograd graph for the forward pass.
    # NOTE(review): this CPU-bound call runs on the event-loop thread and
    # blocks other requests meanwhile; consider offloading to a thread pool.
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_id = outputs.logits.argmax(-1).item()
    pokemon_name = model.config.id2label[predicted_id]
    return {"pokemonName": pokemon_name}