# seqxgpt_demo / app.py
# Author: Jinglong Xiong
# change structure for huggingface space (commit 43d8695)
# Guarded imports: the app depends on streamlit, torch, and the local
# SeqXGPT package; fail fast with an actionable message when any is missing.
import sys

try:
    import streamlit as st
    import torch
    import os
    from SeqXGPT.model import TransformerOnlyClassifier
    from SeqXGPT.generate_features import FeatureExtractor
    from SeqXGPT.transform import predict_with_model
except ImportError as e:
    print(f"Error importing required modules: {e}")
    print("Please make sure all dependencies are installed by running: pip install -r requirements.txt")
    sys.exit(1)
# Set page configuration (must be the first Streamlit call in the script)
st.set_page_config(
    page_title="SeqXGPT - AI Text Detector",
    page_icon="🔍",
    layout="centered",
    initial_sidebar_state="collapsed"
)

# Mapping from classifier output indices to source labels.
# NOTE(review): 'gpt3re' presumably denotes GPT-3 re-generated text — confirm
# against the SeqXGPT training setup.
id2label = {0: 'gpt2', 1: 'llama', 2: 'human', 3: 'gpt3re'}
@st.cache_resource
def load_extractor():
    """Build the SeqXGPT feature extractor once; cached across Streamlit reruns."""
    extractor = FeatureExtractor()
    return extractor
@st.cache_resource
def load_model(ckpt_name):
    """Load the TransformerOnlyClassifier from the checkpoint at *ckpt_name*.

    Returns the model in eval mode on the best available device, or ``None``
    when the checkpoint file is missing or fails to load. Cached across
    Streamlit reruns via ``st.cache_resource``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Check for the checkpoint BEFORE constructing the model so a missing
    # file doesn't pay the cost of building the full network for nothing.
    if not os.path.exists(ckpt_name):
        print(f"Warning: Checkpoint {ckpt_name} not found")
        return None
    model = TransformerOnlyClassifier(id2labels=id2label, seq_len=1024)
    try:
        # weights_only=True restricts torch.load to tensor/primitive data,
        # avoiding arbitrary-code execution from untrusted pickles.
        state_dict = torch.load(
            ckpt_name,
            map_location=device,
            weights_only=True
        )
        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()
        print(f"Model loaded from {ckpt_name}")
        return model
    except Exception as e:
        # Broad catch is deliberate: any load failure should degrade to the
        # caller's "model is None" path rather than crash the whole app.
        print(f"Error loading model: {e}")
        return None
# Initialize feature extractor and model eagerly at startup so the first
# analysis request doesn't pay the load cost; both calls hit the
# st.cache_resource cache on subsequent reruns.
with st.spinner("Loading feature extractor and model..."):
    extractor = load_extractor()
    model = load_model('SeqXGPT/transformer_cls.pt')
# load_model returns None when the checkpoint is missing or unreadable.
if model is None:
    st.error("Failed to load model checkpoint")
# Custom CSS for better appearance
st.markdown("""
<style>
.main-header {
font-size: 2.5rem;
color: #2563eb;
text-align: center;
text-shadow: 1px 1px 2px rgba(0,0,0,0.1);
}
.sub-header {
font-size: 1.2rem;
color: #64748b;
text-align: center;
margin-bottom: 2rem;
}
.result-box {
padding: 1.5rem;
border-radius: 0.75rem;
margin: 1rem 0;
box-shadow: 0 4px 6px -1px rgba(0,0,0,0.1);
color: #333333;
}
.human-text {
background-color: #c6f6d5;
border-left: 5px solid #059669;
border-bottom: 1px solid rgba(5,150,105,0.1);
}
.ai-text {
background-color: #fed7d7;
border-left: 5px solid #dc2626;
border-bottom: 1px solid rgba(220,38,38,0.1);
}
</style>
""", unsafe_allow_html=True)
# App header (HTML so the custom CSS classes above apply)
st.markdown("<h1 class='main-header'>SeqXGPT</h1>", unsafe_allow_html=True)
st.markdown("<p class='sub-header'>AI Text Detection Made Simple</p>", unsafe_allow_html=True)

# Text input area for the text to classify
text_input = st.text_area("Enter text to analyze:", height=200,
                          placeholder="Paste or type text here to check if it's AI-generated...")
# Analysis button: extract features, classify, and render a verdict card.
if st.button("Analyze Text", type="primary", use_container_width=True):
    # Require a minimum length — very short inputs give unreliable features.
    if not text_input or len(text_input.strip()) < 50:
        st.error("Please enter at least 50 characters.")
    else:
        with st.spinner("Analyzing text..."):
            try:
                # Extract features
                features = extractor.extract_features(text_input)
                # Make prediction
                result = predict_with_model(model, features, id2label)
                # Text-level prediction plus the per-token votes behind it
                prediction = result['text_prediction']
                token_preds = result['token_predictions']
                # Confidence = share of tokens that agree with the text-level
                # prediction; guard against an empty token list so we never
                # divide by zero.
                if token_preds:
                    confidence = (token_preds.count(prediction) / len(token_preds)) * 100
                else:
                    confidence = 0.0
                # Display results
                st.markdown("### Analysis Results")
                # Progress bar for confidence (clamped into [0, 1])
                st.progress(min(max(confidence / 100, 0.0), 1.0))
                # Verdict based on prediction
                if prediction in ['gpt2', 'llama', 'gpt3re']:
                    st.markdown(f"""
                    <div class='result-box ai-text'>
                        <h3>🤖 Likely AI-Generated ({confidence:.1f}%)</h3>
                        <p>This text shows strong indicators of being generated by {prediction.upper()}.</p>
                    </div>
                    """, unsafe_allow_html=True)
                elif prediction == 'human':
                    st.markdown(f"""
                    <div class='result-box human-text'>
                        <h3>👤 Likely Human-Written ({confidence:.1f}%)</h3>
                        <p>This text appears to be written by a human.</p>
                    </div>
                    """, unsafe_allow_html=True)
                else:
                    # Unknown label — fall through to a neutral card.
                    st.markdown(f"""
                    <div class='result-box'>
                        <h3>⚠️ Uncertain ({confidence:.1f}%)</h3>
                        <p>Unable to confidently determine if this text is AI-generated or human-written.</p>
                    </div>
                    """, unsafe_allow_html=True)
            except Exception as e:
                # Boundary catch: surface any analysis failure in the UI
                # instead of crashing the Streamlit session.
                st.error(f"An error occurred during analysis: {str(e)}")
                st.info("Try with a different text or check if the model is properly loaded.")
# Footer
st.markdown("---")
st.markdown("SeqXGPT is an AI text detection tool that analyzes text patterns to identify AI-generated content.")