File size: 4,364 Bytes
90bf132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03a7087
90bf132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import streamlit as st
import torch
import torch.nn as nn
from PIL import Image
from transformers import (
    AutoTokenizer, AutoConfig, AutoModel, PreTrainedModel, 
    pipeline, ViTImageProcessor, ViTForImageClassification
)

# --- DESKLIB TEXT DETECTOR ARCHITECTURE ---
class DesklibAIDetectionModel(PreTrainedModel):
    """Transformer backbone with a mean-pooled linear head emitting one logit.

    The backbone is built from ``config`` via ``AutoModel.from_config``; its
    final hidden states are averaged over the tokens selected by the attention
    mask and passed through a single-output ``nn.Linear`` classifier.
    """

    config_class = AutoConfig
    _tied_weights_keys = {}

    def __init__(self, config):
        super().__init__(config)
        self.model = AutoModel.from_config(config)
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Per-instance empty mapping so nothing can mutate the shared class-level dict.
        # NOTE(review): presumably a compatibility workaround for the installed
        # transformers version's weight-tying logic — confirm before removing.
        self._tied_weights_keys = {}
        if not hasattr(self, "_keys_to_ignore_on_save"):
            self._keys_to_ignore_on_save = []

        # post_init must run last, after every submodule has been registered.
        self.post_init()

    def forward(self, input_ids, attention_mask=None):
        """Return the raw detection logits for ``input_ids``.

        Args:
            input_ids: token-id tensor; the pooling below requires the backbone
                output to be (batch, seq_len, hidden), so this is presumably
                (batch, seq_len).
            attention_mask: optional 0/1 mask aligned with ``input_ids``. When
                None, every position is treated as a real token.

        Returns:
            Logits of shape (batch, 1); apply a sigmoid for a probability.
        """
        if attention_mask is None:
            # The signature allows None, but mean pooling needs a mask.
            attention_mask = torch.ones_like(input_ids)

        outputs = self.model(input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs[0]

        # Broadcast the mask over the hidden dimension so padding contributes 0.
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        summed = torch.sum(last_hidden_state * input_mask_expanded, 1)
        # Clamp avoids division by zero for an all-padding row.
        token_counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        mean_pooled = summed / token_counts

        # Pooling is done in float32; match the head's dtype so a reduced-precision
        # model does not fail with a dtype mismatch (no-op for float32 weights).
        return self.classifier(mean_pooled.to(self.classifier.weight.dtype))

# --- LOAD SPECIALIZED MODELS ---
@st.cache_resource
def load_assets():
    """Load the text and image detectors; st.cache_resource reuses them across reruns.

    Returns:
        tuple: ``(tokenizer, text_model, image_pipeline, device)`` where
        ``device`` is ``"cuda"`` if torch sees a GPU, otherwise ``"cpu"``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Text detector: Desklib checkpoint loaded into the custom pooling head.
    text_repo = "desklib/ai-text-detector-v1.01"
    text_tokenizer = AutoTokenizer.from_pretrained(text_repo)
    text_detector = DesklibAIDetectionModel.from_pretrained(text_repo)
    text_detector = text_detector.to(device)

    # Image detector: the pipeline takes a device index (0 = first GPU, -1 = CPU).
    image_repo = "capcheck/ai-image-detection"
    pipeline_device = 0 if device == "cuda" else -1
    image_detector = pipeline("image-classification", model=image_repo, device=pipeline_device)

    return text_tokenizer, text_detector, image_detector, device

# --- UI INTERFACE ---
# Page config comes first: Streamlit requires it before other page output, and
# the cached load_assets() call below may render a spinner on its first run.
st.set_page_config(page_title="AIGC Late Fusion Detector", layout="wide")

tokenizer, text_model, img_pipeline, device = load_assets()

st.title("OriSense")

col_in, col_out = st.columns([1, 1])

with col_in:
    st.subheader("Input Content")
    uploaded_file = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])
    user_text = st.text_area("Input Text", placeholder="Paste article or caption...", height=200)

    # Preview the upload so the user can confirm the right image was chosen.
    if uploaded_file:
        st.image(Image.open(uploaded_file), caption="Uploaded Image", use_container_width=True)

# --- PROCESSING ---
# Late-fusion weights for the text and image probabilities (balanced weighting).
TEXT_WEIGHT = 0.5
IMAGE_WEIGHT = 0.5
# Fused scores strictly above this value are reported as AI-generated.
DECISION_THRESHOLD = 0.5
# Token budget for the text detector.
# NOTE(review): 768 follows the max_len used in the Desklib model card example — confirm.
TEXT_MAX_LENGTH = 768

if st.button("Run Multi-Modal Detection"):
    if not uploaded_file or not user_text.strip():
        # Previously a click with a missing input silently did nothing.
        st.warning("Please upload an image and enter some text before running detection.")
    else:
        with st.spinner("Analyzing artifacts in text and pixels..."):
            # 1. Text score: single logit from the Desklib head -> sigmoid probability.
            t_inputs = tokenizer(
                user_text,
                return_tensors="pt",
                truncation=True,
                max_length=TEXT_MAX_LENGTH,
                padding=True,
            ).to(device)
            with torch.no_grad():
                t_logit = text_model(t_inputs["input_ids"], t_inputs["attention_mask"])
                p_text = torch.sigmoid(t_logit).item()

            # 2. Image score: probability the pipeline assigns to the 'FAKE' label
            # (matched case-insensitively).
            img_results = img_pipeline(Image.open(uploaded_file))
            p_image = next(
                (item["score"] for item in img_results if item["label"].upper() == "FAKE"),
                None,
            )
            if p_image is None:
                # Keep the best-effort 0.0 fallback, but surface it instead of hiding it.
                found_labels = ", ".join(str(item["label"]) for item in img_results)
                st.warning(
                    f"Image model returned no 'FAKE' label (labels: {found_labels}); "
                    "image AI probability set to 0%."
                )
                p_image = 0.0

            # 3. Late fusion: weighted average of the two modality probabilities.
            fused_score = (TEXT_WEIGHT * p_text) + (IMAGE_WEIGHT * p_image)
            is_ai = fused_score > DECISION_THRESHOLD

            with col_out:
                st.subheader("System Verdict")

                # Classification logic
                verdict = "AI-GENERATED" if is_ai else "HUMAN-ORIGIN"
                color = "red" if is_ai else "green"

                st.markdown(f"### Result: :{color}[{verdict}]")
                st.metric("Aggregate Confidence", f"{fused_score:.2%}")

                # Visual breakdown of each modality's probability
                st.write("**Modality Breakdown:**")
                st.progress(p_text, text=f"Text AI Probability: {p_text:.1%}")
                st.progress(p_image, text=f"Image AI Probability: {p_image:.1%}")

                # Brief forensic note
                if is_ai:
                    st.warning("Conclusion: High cross-modal artifact detection. The content shows patterns consistent with synthetic generation.")
                else:
                    st.success("Conclusion: Low probability of AI generation. Features align with natural human patterns.")