Spaces:
Sleeping
Sleeping
File size: 2,151 Bytes
878c3a9 4f24bd0 fdcda3c 4f24bd0 fdcda3c 4f24bd0 878c3a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import altair as alt
import numpy as np
import pandas as pd
import streamlit as st
from io import StringIO
from PIL import Image
import torch
import torchvision
import re
# Run on the GPU when CUDA is usable in this process; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
@st.cache_resource
def load_model():
    """Build the two-class ResNet-18 classifier and load its trained weights.

    A ResNet-18 is created without pretrained weights, its final fully
    connected layer is replaced by a 2-output linear layer, and the state
    dict stored at ``src/model/model.pth`` is loaded into it. The model is
    then moved to the module-level ``device`` and put in evaluation mode.

    The function is wrapped in ``st.cache_resource`` so the Streamlit app can
    reuse the returned model instead of rebuilding it on every script rerun.

    Returns:
        The ResNet-18 module, on ``device`` and in eval mode.
    """
    # ``weights=None`` is the current spelling of the deprecated
    # ``pretrained=False``: the architecture is built with random init and
    # the real parameters come from the checkpoint below.
    model = torchvision.models.resnet18(weights=None)
    model.fc = torch.nn.Linear(model.fc.in_features, 2)

    # Deserialize onto the CPU first so the checkpoint also loads on hosts
    # without CUDA. ``weights_only=True`` restricts unpickling to tensors and
    # plain containers, which is all a state dict needs.
    state_dict = torch.load(
        "src/model/model.pth",
        map_location=torch.device("cpu"),
        weights_only=True,
    )

    # Drop every key ending in ``in<N>.running_mean`` / ``in<N>.running_var``
    # before loading. NOTE(review): presumably leftover normalization running
    # statistics from an older checkpoint format -- confirm they can occur.
    stale_key = re.compile(r'in\d+\.running_(mean|var)$')
    # Snapshot the matching keys first; the dict must not change size while
    # it is being iterated.
    for key in [name for name in state_dict if stale_key.search(name)]:
        del state_dict[key]

    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    return model
# --- Streamlit page: upload one photo and classify it as real or AI-generated ---
st.title("AI vs Real Photo Classification")

# uploaded_video = st.file_uploader("Upload a video file you want to check", type=["mp4", "avi", "mov"])
uploaded_photo = st.file_uploader(
    "Upload a photo file you want to check",
    type=["png", "jpg", "jpeg"],
    accept_multiple_files=False,
)

if uploaded_photo is None:
    st.warning("Please upload a photo file to classify.")
else:
    try:
        # Decode eagerly: a corrupt or mislabeled upload should be reported to
        # the user rather than crash the page with a traceback.
        image = Image.open(uploaded_photo).convert("RGB")
    except OSError:
        image = None
        st.error("The uploaded file could not be read as an image.")

    if image is not None:
        st.success("Photo uploaded successfully!")
        st.image(image, caption=uploaded_photo.name, use_container_width=True)

        if st.button("Classify Image", key="classify_button", use_container_width=True):
            model = load_model()

            # Resize to 224x224, convert to a tensor, then normalize each
            # channel with the mean/std constants below.
            preprocess = torchvision.transforms.Compose([
                torchvision.transforms.Resize((224, 224)),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ])

            # unsqueeze(0) adds the batch dimension the model expects.
            image_tensor = preprocess(image).unsqueeze(0).to(device)

            with torch.no_grad():
                output = model(image_tensor)

            # One image in the batch -> one class index; take the argmax over
            # the class dimension and unwrap it to a plain Python int.
            predicted_class = int(torch.argmax(output, dim=1).item())

            st.header("Classification Result")
            # Class index 1 is reported as "real"; any other index as AI-generated.
            if predicted_class == 1:
                st.success("This is a real photo!")
            else:
                st.error("This is an AI-generated photo!")
|