# Rujit's picture
# Update app.py
# 0fe1a7a verified
import gradio as gr
import torch
from torchvision import models, transforms
from PIL import Image
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
# Output labels, looked up by the index of the highest-scoring logit
# (index 0 -> 'fake', index 1 -> 'real'). Assumes this order matches the
# label order used when best_model.pth was trained — TODO confirm.
class_names = ['fake', 'real']

# Per-channel normalization statistics (the standard ImageNet mean/std).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every input image before inference:
# 1. resize to 224x224 (aspect ratio is not preserved),
# 2. convert to a float CHW tensor,
# 3. normalize each channel with the statistics above.
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])
# Load model
def load_model(checkpoint_path="best_model.pth"):
    """Build the DenseNet-121 fake/real classifier and load its fine-tuned weights.

    The torchvision DenseNet-121 classifier head is replaced by a small MLP
    (1024 -> 512 -> 2, with ReLU and dropout) that matches the checkpoint.
    All weights are then loaded from ``checkpoint['model_state_dict']``.

    Args:
        checkpoint_path: Path to the checkpoint file. The default is
            ``"best_model.pth"``, resolved relative to the working directory.

    Returns:
        A ``(model, device)`` tuple. The model is on CPU and in eval mode.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
        RuntimeError: If the state dict does not match the architecture.
    """
    # Build the architecture without pretrained weights. load_state_dict
    # below is strict, so it replaces every parameter and buffer anyway.
    # Requesting IMAGENET1K_V1 here would download weights over the network
    # on every cold start, only to discard them (and fail when offline).
    model = models.densenet121(weights=None)
    model.classifier = nn.Sequential(
        nn.Linear(1024, 512),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(512, 2),
    )
    device = torch.device('cpu')  # Use CPU for Hugging Face
    model = model.to(device)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Disable dropout for inference.
    model.eval()
    return model, device


model, device = load_model()
# Inference function
def predict(image):
    """Classify an image as 'fake' or 'real' and report the confidence.

    Args:
        image: The input image, either a NumPy array (Gradio's default image
            type) or a ``PIL.Image.Image``. Any PIL mode is accepted; the
            image is converted to RGB before preprocessing.

    Returns:
        A string of the form ``"<label> (<confidence>%)"``, e.g.
        ``"real (97.31%)"``. The confidence is the softmax probability of the
        predicted class, formatted with two decimals.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        # Gradio passes None when the user submits without uploading an image.
        raise gr.Error("Please upload an image.")
    # Convert numpy array to PIL Image if needed
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # The model and Normalize step expect exactly 3 channels. Convert every
    # non-RGB mode (RGBA, grayscale 'L', palette 'P', 'LA', 'CMYK', ...), not
    # just RGBA; otherwise the wrong channel count crashes preprocessing.
    if image.mode != "RGB":
        image = image.convert("RGB")
    # Preprocess and add a batch dimension of 1.
    batch = data_transforms(image).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(batch)
        probs = F.softmax(outputs, dim=1)
        conf, pred = torch.max(probs, 1)
    label = class_names[pred.item()]
    confidence = f"{conf.item() * 100:.2f}%"
    return f"{label} ({confidence})"
# Gradio interface: a single image input mapped to a single text output,
# served under the API endpoint name "predict". Launched at import time.
demo = gr.Interface(
    fn=predict,
    inputs="image",
    outputs="text",
    api_name="predict",
)
demo.launch()