File size: 2,151 Bytes
878c3a9
 
 
 
4f24bd0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fdcda3c
4f24bd0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fdcda3c
4f24bd0
 
 
 
 
 
 
 
 
 
878c3a9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import altair as alt
import numpy as np
import pandas as pd
import streamlit as st
from io import StringIO
from PIL import Image
import torch
import torchvision
import re

# Prefer a CUDA GPU for inference when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

@st.cache_resource
def load_model(checkpoint_path="src/model/model.pth"):
    """Build the 2-class ResNet-18 classifier and load its trained weights.

    Wrapped in ``st.cache_resource`` so the checkpoint is read once and the
    same model object is reused across Streamlit reruns.

    Args:
        checkpoint_path: Path of the saved ``state_dict``. Defaults to the path
            the app has always used, so ``load_model()`` behaves as before.

    Returns:
        A ResNet-18 whose final ``fc`` layer emits 2 logits, moved to the
        module-level ``device`` and switched to eval mode.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` does not exist.
        RuntimeError: if the remaining checkpoint keys do not match the
            architecture (``load_state_dict`` is strict by default).
    """
    # `weights=None` is the current spelling of the deprecated
    # `pretrained=False`: random init, and every tensor is overwritten below.
    model = torchvision.models.resnet18(weights=None)
    # Swap the stock classification head for a 2-logit head.
    model.fc = torch.nn.Linear(model.fc.in_features, 2)

    # Deserialize onto the CPU so a GPU-saved checkpoint also loads on CPU-only hosts.
    state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))

    # Drop running-statistic buffers named "...in<N>.running_mean/var". The stock
    # ResNet-18 has no such keys, so strict loading would otherwise reject them.
    # NOTE(review): presumably leftovers from norm layers used only during
    # training — confirm against the training code.
    # Keys are deleted in place (not rebuilt into a new dict) so any `_metadata`
    # attribute carried by the checkpoint, which load_state_dict consults, survives.
    stale_stat_key = re.compile(r"in\d+\.running_(mean|var)$")
    for key in [k for k in state_dict if stale_stat_key.search(k)]:
        del state_dict[key]

    model.load_state_dict(state_dict)
    # Module.to() and Module.eval() both return the module itself.
    return model.to(device).eval()

# Preprocessing applied to each uploaded image before inference: resize to a
# fixed 224x224 input, convert to a float tensor, and normalise per channel with
# the ImageNet mean/std.
# NOTE(review): assumes the model was trained with this exact pipeline — confirm.
PREPROCESS = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Output index that the app reports as "real photo"; any other index is reported
# as AI-generated.
# NOTE(review): taken from the original branch (1 -> real); confirm it matches
# the label order used in training.
REAL_CLASS_INDEX = 1


def _open_uploaded_image(uploaded_file):
    """Decode an uploaded file into an RGB PIL image, or return None if it cannot be read."""
    try:
        return Image.open(uploaded_file).convert("RGB")
    except (OSError, Image.DecompressionBombError):
        # UnidentifiedImageError (not an image) and truncated/corrupt data are
        # OSError subclasses; DecompressionBombError rejects absurdly large images.
        return None


st.title("AI vs Real Photo Classification")
# uploaded_video = st.file_uploader("Upload a video file you want to check", type=["mp4", "avi", "mov"])
uploaded_photo = st.file_uploader("Upload a photo file you want to check", type=["png", "jpg", "jpeg"], accept_multiple_files=False)

if uploaded_photo is None:
    st.warning("Please upload a photo file to classify.")
else:
    image = _open_uploaded_image(uploaded_photo)
    if image is None:
        # Previously an unreadable upload crashed the page with a traceback.
        st.error("The uploaded file could not be read as an image. Please upload a valid PNG or JPEG file.")
    else:
        st.success("Photo uploaded successfully!")
        st.image(image, caption=uploaded_photo.name, use_container_width=True)
        clicked = st.button("Classify Image", key="classify_button", use_container_width=True)

        if clicked:
            model = load_model()
            # Add a batch dimension of 1 and move the input to the model's device.
            image_tensor = PREPROCESS(image).unsqueeze(0).to(device)

            with torch.no_grad():
                logits = model(image_tensor)
            # Batch size is 1, so the argmax over the class dimension is a single element.
            predicted_class = int(torch.argmax(logits, dim=1).item())

            st.header("Classification Result")
            if predicted_class == REAL_CLASS_INDEX:
                st.success("This is a real photo!")
            else:
                st.error("This is an AI-generated photo!")