Spaces:
Sleeping
Sleeping
| import altair as alt | |
| import numpy as np | |
| import pandas as pd | |
| import streamlit as st | |
| from io import StringIO | |
| from PIL import Image | |
| import torch | |
| import torchvision | |
| import re | |
# Inference device: the first CUDA GPU visible to PyTorch when available,
# otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
@st.cache_resource
def load_model():
    """Build the ResNet-18 classifier and load its fine-tuned weights.

    The network is a torchvision ResNet-18 created without pretrained
    weights, with its final fully-connected layer replaced by a 2-output
    head. Weights are read from ``src/model/model.pth`` (relative to the
    working directory).

    Wrapped in ``st.cache_resource`` so the checkpoint is read and the model
    is built once per server process, not on every button click or script
    rerun. The cached model is shared between sessions, so callers must treat
    it as read-only (this app only runs inference with it).

    Returns:
        torch.nn.Module: the model in eval mode, moved to the module-level
        ``device``.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
        RuntimeError: if the remaining state-dict keys do not match the model.
    """
    # `weights=None` is the torchvision >= 0.13 replacement for the
    # deprecated `pretrained=False`.
    model = torchvision.models.resnet18(weights=None)
    model.fc = torch.nn.Linear(model.fc.in_features, 2)

    # Load onto the CPU first so a checkpoint saved on a GPU also loads on
    # CPU-only hosts; the model is moved to `device` afterwards.
    # weights_only=True restricts unpickling to tensors/containers.
    state_dict = torch.load(
        "src/model/model.pth",
        map_location=torch.device("cpu"),
        weights_only=True,
    )

    # Drop running-stat buffers whose keys end in "in<N>.running_mean" or
    # "in<N>.running_var" before the strict load below.
    # NOTE(review): presumably left over from normalization layers that this
    # ResNet-18 does not define; confirm against the training code.
    for key in list(state_dict.keys()):
        if re.search(r'in\d+\.running_(mean|var)$', key):
            del state_dict[key]

    model.load_state_dict(state_dict)
    return model.to(device).eval()
st.title("AI vs Real Photo Classification")

# Resize to the 224x224 input size and apply the standard ImageNet mean/std
# normalization. NOTE(review): presumably the same preprocessing used at
# training time; confirm. Built once per script run instead of per click.
preprocess = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

uploaded_photo = st.file_uploader(
    "Upload a photo file you want to check",
    type=["png", "jpg", "jpeg"],
    accept_multiple_files=False,
)

if uploaded_photo is None:
    st.warning("Please upload a photo file to classify.")
else:
    try:
        # convert() forces a full decode, so corrupt or truncated files fail
        # here rather than later inside the model pipeline.
        image = Image.open(uploaded_photo).convert("RGB")
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) for files it
        # cannot parse, and plain OSError for truncated images.
        st.error("Could not read this file as an image. Please upload a valid PNG or JPEG.")
    else:
        st.success("Photo uploaded successfully!")
        st.image(image, caption=uploaded_photo.name, use_container_width=True)

        if st.button("Classify Image", key="classify_button", use_container_width=True):
            model = load_model()
            # Add a batch dimension: the model expects (batch, C, H, W).
            image_tensor = preprocess(image).unsqueeze(0).to(device)
            with torch.no_grad():
                output = model(image_tensor)
            # One row of 2 class scores for the single image; take the
            # highest-scoring class index as a plain Python int.
            pred = int(torch.argmax(output, dim=1).item())

            st.header("Classification Result")
            # Class index 1 is reported as a real photo, anything else as AI-generated.
            if pred == 1:
                st.success("This is a real photo!")
            else:
                st.error("This is an AI-generated photo!")