|
|
from pathlib import Path |
|
|
from PIL import Image |
|
|
import torch |
|
|
from transformers import AutoImageProcessor, ViTForImageClassification |
|
|
from fastapi import FastAPI, File, UploadFile |
|
|
from io import BytesIO |
|
|
import shutil |
|
|
|
|
|
# Checkpoint directory. The path is relative, so it depends on the working
# directory the server process is started from.
model_path = "./best_model"

# Application object that the route handlers in this module register on.
app = FastAPI()

# Label maps handed to the classifier config: class index (as a string key)
# to label name, and label name back to the integer class index.
_ID2LABEL = {"0": "real", "1": "fake"}
_LABEL2ID = {"real": 0, "fake": 1}

# Both artifacts are loaded once at import time with local_files_only=True
# and shared by every request.
processor = AutoImageProcessor.from_pretrained(
    model_path, local_files_only=True, use_fast=True
)
model = ViTForImageClassification.from_pretrained(
    model_path,
    local_files_only=True,
    id2label=_ID2LABEL,
    label2id=_LABEL2ID,
)
|
|
|
|
|
def predict_image(image: Image.Image) -> str:
    """Classify one image and return the predicted label.

    Args:
        image: PIL image to classify. It is handed to the module-level
            ``processor`` to build the model inputs.

    Returns:
        The ``model.config.id2label`` entry for the class with the highest
        logit.

    Raises:
        KeyError: If the predicted class index is absent from
            ``model.config.id2label`` under both its int and its str form.
    """
    inputs = processor(image, return_tensors="pt")

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(**inputs)

    pred_id = torch.argmax(outputs.logits, dim=1).item()

    # id2label may be keyed by int (Transformers converts keys to int when it
    # builds a config from config.json) or by str (when the mapping is passed
    # as an override with string keys). Accept both so the lookup does not
    # depend on how the config was populated.
    labels = model.config.id2label
    if pred_id in labels:
        return labels[pred_id]
    return labels[str(pred_id)]
|
|
|
|
|
@app.post("/predict/")
async def upload_file(file: UploadFile = File(...)):
    """Classify an uploaded image.

    Args:
        file: The uploaded file from the multipart form field ``file``.

    Returns:
        A dict with the uploaded ``filename`` and the ``prediction`` label
        returned by ``predict_image``.

    Raises:
        HTTPException: With status 400 when the upload cannot be decoded as
            an image (empty body, unsupported format, truncated or oversized
            data).
    """
    # Imported locally so this handler does not depend on the module's
    # top-level import list being extended.
    from fastapi import HTTPException

    image_data = await file.read()

    try:
        # convert() forces a full decode, so corrupt or truncated payloads
        # fail here rather than later inside the model pipeline.
        image = Image.open(BytesIO(image_data)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        # PIL.UnidentifiedImageError is a subclass of OSError. Report bad
        # uploads as a client error instead of an unhandled HTTP 500.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from exc

    # NOTE: predict_image runs synchronously inside this coroutine.
    prediction = predict_image(image)

    return {"filename": file.filename, "prediction": prediction}