File size: 5,626 Bytes
2748153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import sys

# Bootstrap imports: fail fast with an actionable message if streamlit,
# torch, or the local SeqXGPT package is missing, instead of crashing with
# a raw traceback when Streamlit first touches an undefined name.
try:
    import streamlit as st
    import torch
    import os
    from SeqXGPT.model import TransformerOnlyClassifier
    from SeqXGPT.generate_features import FeatureExtractor
    from SeqXGPT.transform import predict_with_model
except ImportError as e:
    print(f"Error importing required modules: {e}")
    print("Please make sure all dependencies are installed by running: pip install -r requirements.txt")
    sys.exit(1)

# Set page configuration (must be the first Streamlit call in the script).
st.set_page_config(
    page_title="SeqXGPT - AI Text Detector",
    page_icon="🔍",
    layout="centered",
    initial_sidebar_state="collapsed"
)

# Classifier output index -> source label. Index 2 ('human') is the only
# non-AI class; the rest name the generating model family.
id2label = {0: 'gpt2', 1: 'llama', 2: 'human', 3: 'gpt3re'}

@st.cache_resource
def load_extractor():
    """Build the SeqXGPT feature extractor once; Streamlit caches it across reruns."""
    extractor = FeatureExtractor()
    return extractor

@st.cache_resource
def load_model(ckpt_name):
    """Load the transformer classifier from *ckpt_name*.

    Returns the model in eval mode on the selected device, or ``None`` when
    the checkpoint file is absent or fails to load. Cached by Streamlit so
    the checkpoint is only read once per session.
    """
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    classifier = TransformerOnlyClassifier(id2labels=id2label, seq_len=1024)

    if not os.path.exists(ckpt_name):
        print(f"Warning: Checkpoint {ckpt_name} not found")
        return None

    try:
        # weights_only=True restricts unpickling to plain tensors/containers.
        weights = torch.load(ckpt_name, map_location=device, weights_only=True)
        classifier.load_state_dict(weights)
        classifier.to(device)
        classifier.eval()
        print(f"Model loaded from {ckpt_name}")
        return classifier
    except Exception as e:
        print(f"Error loading model: {e}")
        return None

# Initialize feature extractor
# Eagerly load the (cached) extractor and classifier at startup so the
# first "Analyze" click does not pay the model-loading cost.
with st.spinner("Loading feature extractor and model..."):
    extractor = load_extractor()
    model = load_model('SeqXGPT/transformer_cls.pt')
    if model is None:
        # Surface the failure in the UI; the app still renders so the user
        # sees the error instead of a blank page.
        st.error("Failed to load model checkpoint")

# Custom CSS for better appearance.
# .result-box is the shared card style; .human-text / .ai-text tint it
# green or red depending on the verdict rendered below.
st.markdown("""
<style>
    .main-header {
        font-size: 2.5rem;
        color: #2563eb;
        text-align: center;
        text-shadow: 1px 1px 2px rgba(0,0,0,0.1);
    }
    .sub-header {
        font-size: 1.2rem;
        color: #64748b;
        text-align: center;
        margin-bottom: 2rem;
    }
    .result-box {
        padding: 1.5rem;
        border-radius: 0.75rem;
        margin: 1rem 0;
        box-shadow: 0 4px 6px -1px rgba(0,0,0,0.1);
        color: #333333;
    }
    .human-text {
        background-color: #c6f6d5;
        border-left: 5px solid #059669;
        border-bottom: 1px solid rgba(5,150,105,0.1);
    }
    .ai-text {
        background-color: #fed7d7;
        border-left: 5px solid #dc2626;
        border-bottom: 1px solid rgba(220,38,38,0.1);
    }
</style>
""", unsafe_allow_html=True)

# App header
st.markdown("<h1 class='main-header'>SeqXGPT</h1>", unsafe_allow_html=True)
st.markdown("<p class='sub-header'>AI Text Detection Made Simple</p>", unsafe_allow_html=True)

# Text input area (validated to >= 50 characters by the analyze handler below)
text_input = st.text_area("Enter text to analyze:", height=200, 
                          placeholder="Paste or type text here to check if it's AI-generated...")

# Analysis button: extract features, classify, and render the verdict card.
if st.button("Analyze Text", type="primary", use_container_width=True):
    if not text_input or len(text_input.strip()) < 50:
        st.error("Please enter at least 50 characters.")
    else:
        with st.spinner("Analyzing text..."):
            try:
                # Extract per-token features for the classifier
                features = extractor.extract_features(text_input)
                
                # Make prediction
                result = predict_with_model(model, features, id2label)
                
                # Text-level label plus the per-token label votes
                prediction = result['text_prediction']
                token_predictions = result['token_predictions']
                
                # Confidence = share of tokens that agree with the text-level
                # label. Guard the empty case: dividing by
                # len(token_predictions) would raise ZeroDivisionError on
                # degenerate inputs that produce no tokens.
                if token_predictions:
                    confidence = (token_predictions.count(prediction) / len(token_predictions)) * 100
                else:
                    confidence = 0.0
                
                # Display results
                st.markdown("### Analysis Results")
                
                # Progress bar for confidence (clamped to Streamlit's [0, 1])
                st.progress(min(max(confidence/100, 0.0), 1.0))
                
                # Verdict based on prediction
                if prediction in ['gpt2', 'llama', 'gpt3re']:
                    st.markdown(f"""
                    <div class='result-box ai-text'>
                        <h3>🤖 Likely AI-Generated ({confidence:.1f}%)</h3>
                        <p>This text shows strong indicators of being generated by {prediction.upper()}.</p>
                    </div>
                    """, unsafe_allow_html=True)
                elif prediction == 'human':
                    st.markdown(f"""
                    <div class='result-box human-text'>
                        <h3>👤 Likely Human-Written ({confidence:.1f}%)</h3>
                        <p>This text appears to be written by a human.</p>
                    </div>
                    """, unsafe_allow_html=True)
                else:
                    # Defensive fallback for labels outside the known set.
                    st.markdown(f"""
                    <div class='result-box'>
                        <h3>⚠️ Uncertain ({confidence:.1f}%)</h3>
                        <p>Unable to confidently determine if this text is AI-generated or human-written.</p>
                    </div>
                    """, unsafe_allow_html=True)
                
            except Exception as e:
                # Boundary handler: surface the error in the UI rather than
                # crashing the Streamlit script run.
                st.error(f"An error occurred during analysis: {str(e)}")
                st.info("Try with a different text or check if the model is properly loaded.")

# Footer: horizontal rule plus a one-line description of the tool.
st.markdown("---")
st.markdown("SeqXGPT is an AI text detection tool that analyzes text patterns to identify AI-generated content.")