Joergenator commited on
Commit
b7b15fe
·
verified ·
1 Parent(s): dbcc5a7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -0
app.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Streamlit demo: real vs AI-generated image classifier.
2
+
3
+ Run locally with:
4
+ streamlit run app.py
5
+
6
+ Deployed on HuggingFace Spaces — model weights are pulled from
7
+ HF Hub on first use and cached to disk. See README.md for details.
8
+ """
9
+
10
+ import streamlit as st
11
+ from PIL import Image
12
+
13
+ from src.predict import MODEL_REGISTRY, load_model, predict_image
14
+
15
+
16
+ st.set_page_config(
17
+ page_title="Real vs AI-generated image classifier",
18
+ page_icon="\U0001F5BC️",
19
+ layout="centered",
20
+ )
21
+
22
+ st.title("Real vs AI-generated image classifier")
23
+ st.write(
24
+ "Course project for DAT255 — Deep Learning Engineering. "
25
+ "Pick a model, upload an image, and see whether the model thinks "
26
+ "it's a real photograph or AI-generated."
27
+ )
28
+
29
+
30
+ @st.cache_resource(show_spinner="Loading model weights...")
31
+ def _get_model(tag: str):
32
+ return load_model(tag, device="cpu")
33
+
34
+
35
+ tag_by_label = {spec.display_name: tag for tag, spec in MODEL_REGISTRY.items()}
36
+
37
+ chosen_label = st.selectbox(
38
+ "Model",
39
+ list(tag_by_label.keys()),
40
+ index=0,
41
+ help="Test AUC on the held-out test set is shown in the caption below.",
42
+ )
43
+ chosen_tag = tag_by_label[chosen_label]
44
+ chosen_spec = MODEL_REGISTRY[chosen_tag]
45
+ st.caption(f"Test AUC: {chosen_spec.test_auc:.4f}")
46
+
47
+ uploaded = st.file_uploader(
48
+ "Upload an image (JPG, PNG, WebP)",
49
+ type=["jpg", "jpeg", "png", "webp"],
50
+ )
51
+
52
+ if uploaded is not None:
53
+ image = Image.open(uploaded)
54
+ st.image(image, caption="Your image", use_column_width=True)
55
+
56
+ model = _get_model(chosen_tag)
57
+ with st.spinner("Running inference..."):
58
+ prob_ai, label = predict_image(model, image, device="cpu")
59
+
60
+ if label == "AI-generated":
61
+ st.error(f"Prediction: **{label}**")
62
+ else:
63
+ st.success(f"Prediction: **{label}**")
64
+
65
+ st.write(f"Probability the image is AI-generated: **{prob_ai:.2%}**")
66
+ st.progress(prob_ai)
67
+
68
+ with st.expander("What does this number mean?"):
69
+ st.write(
70
+ "The model outputs a single number between 0 and 1 "
71
+ "(a sigmoid of its internal logit). 0 means confidently real, "
72
+ "1 means confidently AI-generated. The label above uses a "
73
+ "threshold of 0.5."
74
+ )
75
+
76
+ st.divider()
77
+ st.caption(
78
+ "Models were trained on a 60 000-image dataset split 80/10/10. "
79
+ "The three transfer-learning models fine-tune ImageNet-pretrained "
80
+ "backbones; the scratch ResNet-50 was trained from random "
81
+ "initialisation with ReLU replaced by GELU throughout the network."
82
+ )