chuodinity commited on
Commit
90bf132
·
verified ·
1 Parent(s): b1d9107

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +104 -0
  2. requirements.txt +7 -2
app.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn as nn
4
+ from PIL import Image
5
+ from transformers import (
6
+ AutoTokenizer, AutoConfig, AutoModel, PreTrainedModel,
7
+ pipeline, ViTImageProcessor, ViTForImageClassification
8
+ )
9
+
10
+ # --- DESKLIB TEXT DETECTOR ARCHITECTURE ---
11
+ class DesklibAIDetectionModel(PreTrainedModel):
12
+ config_class = AutoConfig
13
+ # NEW: Add this line to satisfy the latest Transformers internal checks
14
+ _tied_weights_keys = {}
15
+
16
+ def __init__(self, config):
17
+ super().__init__(config)
18
+ self.model = AutoModel.from_config(config)
19
+ self.classifier = nn.Linear(config.hidden_size, 1)
20
+
21
+ # NEW: Always call post_init at the end of __init__
22
+ self._tied_weights_keys = {}
23
+ if not hasattr(self, "_keys_to_ignore_on_save"):
24
+ self._keys_to_ignore_on_save = []
25
+
26
+ self.post_init()
27
+
28
+ def forward(self, input_ids, attention_mask=None):
29
+ outputs = self.model(input_ids, attention_mask=attention_mask)
30
+ last_hidden_state = outputs[0]
31
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
32
+ mean_pooled = torch.sum(last_hidden_state * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
33
+ return self.classifier(mean_pooled)
34
+
35
+ # --- LOAD SPECIALIZED MODELS ---
36
+ @st.cache_resource
37
+ def load_assets():
38
+ device = "cuda" if torch.cuda.is_available() else "cpu"
39
+
40
+ # Text Model (Desklib)
41
+ text_model_id = "desklib/ai-text-detector-v1.01"
42
+ t_tokenizer = AutoTokenizer.from_pretrained(text_model_id)
43
+ t_model = DesklibAIDetectionModel.from_pretrained(text_model_id).to(device)
44
+
45
+ # Image Model (Specialized ViT for AIGC)
46
+ img_model_id = "capcheck/ai-image-detection"
47
+ img_pipe = pipeline("image-classification", model=img_model_id, device=0 if device == "cuda" else -1)
48
+
49
+ return t_tokenizer, t_model, img_pipe, device
50
+
51
+ tokenizer, text_model, img_pipeline, device = load_assets()
52
+
53
+ # --- UI INTERFACE ---
54
+ st.set_page_config(page_title="AIGC Late Fusion Detector", layout="wide")
55
+ st.title("🛡️ Specialized Multimodal AIGC Detector")
56
+
57
+ col_in, col_out = st.columns([1, 1])
58
+
59
+ with col_in:
60
+ st.subheader("Input Content")
61
+ uploaded_file = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])
62
+ user_text = st.text_area("Input Text", placeholder="Paste article or caption...", height=200)
63
+
64
+ if uploaded_file:
65
+ st.image(Image.open(uploaded_file), caption="Uploaded Image", use_container_width=True)
66
+
67
+ # --- PROCESSING ---
68
+ if st.button("Run Multi-Modal Detection") and uploaded_file and user_text:
69
+ with st.spinner("Analyzing artifacts in text and pixels..."):
70
+ # 1. Text Score (Logit -> Sigmoid)
71
+ t_inputs = tokenizer(user_text, return_tensors="pt", truncation=True, padding=True).to(device)
72
+ with torch.no_grad():
73
+ t_logit = text_model(t_inputs['input_ids'], t_inputs['attention_mask'])
74
+ p_text = torch.sigmoid(t_logit).item()
75
+
76
+ # 2. Image Score (AIGC ViT)
77
+ img_results = img_pipeline(Image.open(uploaded_file))
78
+ # Find the score for 'FAKE' (AI-generated), case-insensitive, with safe fallback
79
+ p_image = next((item['score'] for item in img_results if item['label'].upper() == 'FAKE'), 0.0)
80
+
81
+ # 3. Late Fusion (Weighted Average)
82
+ # Using 0.5/0.5 for balanced multimodal detection
83
+ fused_score = (0.5 * p_text) + (0.5 * p_image)
84
+
85
+ with col_out:
86
+ st.subheader("System Verdict")
87
+
88
+ # Classification logic
89
+ verdict = "AI-GENERATED" if fused_score > 0.5 else "HUMAN-ORIGIN"
90
+ color = "red" if verdict == "AI-GENERATED" else "green"
91
+
92
+ st.markdown(f"### Result: :{color}[{verdict}]")
93
+ st.metric("Aggregate Confidence", f"{fused_score:.2%}")
94
+
95
+ # Visual Breakdown
96
+ st.write("**Modality Breakdown:**")
97
+ st.progress(p_text, text=f"Text AI Probability: {p_text:.1%}")
98
+ st.progress(p_image, text=f"Image AI Probability: {p_image:.1%}")
99
+
100
+ # Brief Forensic Note
101
+ if fused_score > 0.5:
102
+ st.warning("Conclusion: High cross-modal artifact detection. The content shows patterns consistent with synthetic generation.")
103
+ else:
104
+ st.success("Conclusion: Low probability of AI generation. Features align with natural human patterns.")
requirements.txt CHANGED
@@ -1,3 +1,8 @@
1
- altair
 
 
 
 
2
  pandas
3
- streamlit
 
 
1
+ streamlit
2
+ torch
3
+ transformers
4
+ Pillow
5
+ plotly
6
  pandas
7
+ numpy
8
+ accelerate