"""FastAPI service that classifies an uploaded image as "real" or "fake".

A ViT image-classification checkpoint and its image processor are loaded once,
at import time, from ``./best_model`` (local files only). ``POST /predict/``
accepts one uploaded image and returns the predicted label.
"""

import shutil  # NOTE(review): unused in this module; kept in case something relies on it
from io import BytesIO
from pathlib import Path  # NOTE(review): unused in this module

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from transformers import AutoImageProcessor, ViTForImageClassification

# Relative path: resolved against the process working directory, not this file.
model_path = "./best_model"

app = FastAPI()

# Loaded once at import; local_files_only=True forbids downloading from the Hub.
processor = AutoImageProcessor.from_pretrained(
    model_path, local_files_only=True, use_fast=True
)
model = ViTForImageClassification.from_pretrained(
    model_path,
    local_files_only=True,
    # Config overrides: these label mappings replace the ones in the saved config.
    id2label={"0": "real", "1": "fake"},
    label2id={"real": 0, "fake": 1},
)


def _label_for(pred_id: int) -> str:
    """Map a predicted class index to its label name via ``model.config.id2label``.

    ``PretrainedConfig`` normally stores ``id2label`` with ``int`` keys, but the
    string keys passed to ``from_pretrained`` above may be kept verbatim
    depending on the load path, so both key types are accepted.

    Raises:
        KeyError: if ``pred_id`` has no entry under either key type.
    """
    id2label = model.config.id2label
    if pred_id in id2label:
        return id2label[pred_id]
    return id2label[str(pred_id)]


def predict_image(image: Image.Image) -> str:
    """Run the classifier on a single PIL image and return the predicted label."""
    inputs = processor(image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)
    # One image per call, so take the arg-max over the class dimension of row 0.
    pred_id = torch.argmax(outputs.logits, dim=1).item()
    return _label_for(pred_id)


@app.post("/predict/")
async def upload_file(file: UploadFile = File(...)):
    """Classify an uploaded image.

    Returns a JSON object with the upload's ``filename`` and the model's
    ``prediction`` label. Responds with HTTP 400 when the upload cannot be
    decoded as an image.
    """
    image_data = await file.read()
    try:
        image = Image.open(BytesIO(image_data)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # OSError covers UnidentifiedImageError (not an image / empty upload)
        # and truncated files; DecompressionBombError guards huge images.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err
    # The forward pass is CPU/GPU-bound and synchronous; running it in the
    # threadpool keeps the event loop free to serve other requests.
    prediction = await run_in_threadpool(predict_image, image)
    return {"filename": file.filename, "prediction": prediction}