hamzenium's picture
Upload 6 files
c8c46cf verified
import shutil
from io import BytesIO
from pathlib import Path

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from transformers import AutoImageProcessor, ViTForImageClassification
# Directory holding the fine-tuned checkpoint and its processor config.
# Resolved next to this source file rather than against the process's
# working directory, so the same checkpoint is found no matter where the
# server is launched from. Kept as a str for any code expecting one.
model_path = str(Path(__file__).resolve().parent / "best_model")

app = FastAPI()

# Image preprocessor saved with the checkpoint. local_files_only=True means
# it is loaded from model_path only, never fetched from the Hub.
processor = AutoImageProcessor.from_pretrained(
    model_path,
    local_files_only=True,
    use_fast=True,
)

# ViT classifier. The label maps passed here override those in the saved
# config; predict_image reads model.config.id2label to name the predicted
# class. The string keys are kept as-is because predict_image may look them
# up as strings.
model = ViTForImageClassification.from_pretrained(
    model_path,
    local_files_only=True,
    id2label={"0": "real", "1": "fake"},
    label2id={"real": 0, "fake": 1},
)
def predict_image(image: Image.Image):
    """Classify a single PIL image and return the predicted label.

    Args:
        image: The image to classify. The upload endpoint in this module
            passes an RGB-converted image.

    Returns:
        The entry of ``model.config.id2label`` for the class with the
        highest logit.

    Raises:
        KeyError: If the predicted class index has no entry in
            ``model.config.id2label`` under either an int or a str key.
    """
    inputs = processor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # One image in -> presumably a batch of one, so .item() yields a single
    # class index (it would raise for a larger batch).
    pred_id = torch.argmax(outputs.logits, dim=-1).item()

    id2label = model.config.id2label
    # Depending on how the config was built, id2label keys may be ints
    # (transformers' usual form) or the strings passed to from_pretrained.
    # Accept either instead of assuming one representation.
    if pred_id in id2label:
        return id2label[pred_id]
    return id2label[str(pred_id)]
@app.post("/predict/")
async def upload_file(file: UploadFile = File(...)):
    """Classify an uploaded image file.

    Args:
        file: The multipart upload; its bytes must be an image PIL can open.

    Returns:
        A dict with the upload's ``filename`` and the ``prediction`` label
        returned by ``predict_image``.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    image_data = await file.read()
    try:
        image = Image.open(BytesIO(image_data)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        # UnidentifiedImageError (not an image / empty body) and truncated
        # image errors are OSError subclasses; report them as a client error
        # instead of letting them surface as a 500.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a readable image.",
        ) from exc
    # Model inference is synchronous; run it in the thread pool so it does
    # not block the event loop (and every other request) while it runs.
    prediction = await run_in_threadpool(predict_image, image)
    return {"filename": file.filename, "prediction": prediction}