# Streamlit app: classify an uploaded photo as "real" or "AI-generated" using a
# ResNet-18 whose final layer was replaced by a 2-class head and whose weights
# are read from a local checkpoint.

import re
from io import StringIO

import altair as alt
import numpy as np
import pandas as pd
import streamlit as st
import torch
import torchvision
from PIL import Image

# Checkpoint path, resolved relative to the process working directory.
MODEL_PATH = "src/model/model.pth"

# Checkpoint keys matching this pattern are discarded before loading.
# NOTE(review): presumably leftover instance-norm running statistics from
# training that the plain torchvision ResNet-18 does not define — confirm.
_DROPPED_KEY_PATTERN = re.compile(r'in\d+\.running_(mean|var)$')

# Class index the UI reports as a real photo; any other index is reported as
# AI-generated.
_REAL_CLASS_INDEX = 1

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource
def load_model() -> torch.nn.Module:
    """Build the 2-class ResNet-18, load its checkpoint and return it in eval mode.

    Decorated with ``st.cache_resource`` so the checkpoint is not re-read on
    every Streamlit rerun.

    Returns:
        The model, moved to ``device`` and switched to evaluation mode.
    """
    # No pretrained weights are requested here; all weights come from the
    # checkpoint loaded below.
    model = torchvision.models.resnet18(pretrained=False)
    model.fc = torch.nn.Linear(model.fc.in_features, 2)

    # SECURITY NOTE: torch.load unpickles the file. Only load checkpoints from
    # a trusted source (consider `weights_only=True` where the installed torch
    # version supports it).
    state_dict = torch.load(MODEL_PATH, map_location=torch.device('cpu'))

    # Delete in place, iterating over a snapshot of the keys, so the very
    # mapping object returned by torch.load is what load_state_dict receives.
    for key in list(state_dict.keys()):
        if _DROPPED_KEY_PATTERN.search(key):
            del state_dict[key]

    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    return model


def _predict_class(model: torch.nn.Module, image: Image.Image) -> int:
    """Run one RGB image through ``model`` and return the highest-scoring class index."""
    preprocess = torchvision.transforms.Compose([
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
        # Standard ImageNet per-channel mean/std.
        # NOTE(review): assumes training used the same normalisation — confirm.
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])
    # unsqueeze(0) adds the batch dimension the model expects.
    batch = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
    # The batch holds exactly one image, so argmax over the class dimension
    # yields a single element that .item() turns into a plain Python int.
    return int(logits.argmax(dim=1).item())


st.title("AI vs Real Photo Classification")

# uploaded_video = st.file_uploader("Upload a video file you want to check", type=["mp4", "avi", "mov"])
uploaded_photo = st.file_uploader(
    "Upload a photo file you want to check",
    type=["png", "jpg", "jpeg"],
    accept_multiple_files=False,
)

if uploaded_photo is None:
    st.warning("Please upload a photo file to classify.")
else:
    # Decode before announcing success: the extension filter above does not
    # guarantee the bytes are a valid image. Pillow signals undecodable or
    # truncated files with OSError (UnidentifiedImageError is a subclass).
    try:
        image = Image.open(uploaded_photo).convert("RGB")
    except OSError:
        st.error("The uploaded file could not be read as an image.")
        st.stop()  # ends this script run; nothing below executes

    st.success("Photo uploaded successfully!")
    st.image(image, caption=uploaded_photo.name, use_container_width=True)

    clicked = st.button("Classify Image", key="classify_button", use_container_width=True)
    if clicked:
        model = load_model()
        predicted_class = _predict_class(model, image)

        st.header("Classification Result")
        if predicted_class == _REAL_CLASS_INDEX:
            st.success("This is a real photo!")
        else:
            st.error("This is an AI-generated photo!")