Spaces:
Sleeping
Sleeping
# Guard the imports so a missing dependency produces a clear, actionable
# message (and a non-zero exit code) instead of a raw traceback.
import sys

try:
    import streamlit as st
    import torch
    import os
    from SeqXGPT.model import TransformerOnlyClassifier
    from SeqXGPT.generate_features import FeatureExtractor
    from SeqXGPT.transform import predict_with_model
except ImportError as e:
    print(f"Error importing required modules: {e}")
    print("Please make sure all dependencies are installed by running: pip install -r requirements.txt")
    sys.exit(1)
# Set page configuration (Streamlit requires this to be the first st.* call
# in the script).
st.set_page_config(
    page_title="SeqXGPT - AI Text Detector",
    page_icon="🔍",
    layout="centered",
    initial_sidebar_state="collapsed"
)

# Mapping from classifier output index to source label.
# NOTE(review): the index order presumably matches the label order used at
# training time — confirm against the checkpoint's training config.
id2label = {0: 'gpt2', 1: 'llama', 2: 'human', 3: 'gpt3re'}
@st.cache_resource
def load_extractor():
    """Build the feature extractor once per server process.

    Streamlit re-executes this whole script on every widget interaction;
    without caching, a fresh ``FeatureExtractor`` would be constructed on
    every click. ``st.cache_resource`` memoizes the instance across reruns
    while keeping the call signature unchanged.

    Returns:
        FeatureExtractor: the shared feature-extraction backend.
    """
    return FeatureExtractor()
@st.cache_resource
def load_model(ckpt_name):
    """Load the sequence classifier from a checkpoint (cached per process).

    Streamlit reruns the script on each interaction; caching prevents the
    torch checkpoint from being re-read and the model re-built every time.

    Args:
        ckpt_name: Path to a ``state_dict`` checkpoint file.

    Returns:
        The model in eval mode on the best available device, or ``None``
        when the checkpoint is missing or fails to load.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = TransformerOnlyClassifier(id2labels=id2label, seq_len=1024)

    if not os.path.exists(ckpt_name):
        print(f"Warning: Checkpoint {ckpt_name} not found")
        return None

    try:
        # weights_only=True restricts torch.load to tensor data, avoiding
        # arbitrary-object unpickling from an untrusted file.
        state_dict = torch.load(
            ckpt_name,
            map_location=device,
            weights_only=True
        )
        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()
        print(f"Model loaded from {ckpt_name}")
        return model
    except Exception as e:
        # Best-effort: surface the failure and let the caller handle None.
        print(f"Error loading model: {e}")
        return None
# Initialize feature extractor and model once at startup.
with st.spinner("Loading feature extractor and model..."):
    extractor = load_extractor()
    model = load_model('SeqXGPT/transformer_cls.pt')

if model is None:
    st.error("Failed to load model checkpoint")
    # Without a model every later analysis would raise; halt the script run
    # here so the user sees this error rather than a generic traceback.
    st.stop()
# Custom CSS for better appearance — injected as raw HTML, so the string
# below is runtime content (class names are referenced by the result boxes
# rendered further down the script).
st.markdown("""
<style>
.main-header {
    font-size: 2.5rem;
    color: #2563eb;
    text-align: center;
    text-shadow: 1px 1px 2px rgba(0,0,0,0.1);
}
.sub-header {
    font-size: 1.2rem;
    color: #64748b;
    text-align: center;
    margin-bottom: 2rem;
}
.result-box {
    padding: 1.5rem;
    border-radius: 0.75rem;
    margin: 1rem 0;
    box-shadow: 0 4px 6px -1px rgba(0,0,0,0.1);
    color: #333333;
}
.human-text {
    background-color: #c6f6d5;
    border-left: 5px solid #059669;
    border-bottom: 1px solid rgba(5,150,105,0.1);
}
.ai-text {
    background-color: #fed7d7;
    border-left: 5px solid #dc2626;
    border-bottom: 1px solid rgba(220,38,38,0.1);
}
</style>
""", unsafe_allow_html=True)
# App header
st.markdown("<h1 class='main-header'>SeqXGPT</h1>", unsafe_allow_html=True)
st.markdown("<p class='sub-header'>AI Text Detection Made Simple</p>", unsafe_allow_html=True)

# Text input area — text_input is read by the analysis handler below.
text_input = st.text_area("Enter text to analyze:", height=200,
                          placeholder="Paste or type text here to check if it's AI-generated...")
# Analysis button — the detection pipeline runs only when clicked.
if st.button("Analyze Text", type="primary", use_container_width=True):
    if not text_input or len(text_input.strip()) < 50:
        # Very short texts give the token-level classifier too little signal.
        st.error("Please enter at least 50 characters.")
    else:
        with st.spinner("Analyzing text..."):
            try:
                # Extract features from the raw text.
                features = extractor.extract_features(text_input)

                # Run the classifier; result carries a document-level label
                # ('text_prediction') plus per-token predictions.
                result = predict_with_model(model, features, id2label)
                prediction = result['text_prediction']
                token_preds = result['token_predictions']

                # Confidence = share of tokens agreeing with the document
                # label. Guard len()==0 to avoid ZeroDivisionError on
                # degenerate inputs. (Removed unused 'token_logits' read.)
                if token_preds:
                    confidence = (token_preds.count(prediction) / len(token_preds)) * 100
                else:
                    confidence = 0.0

                # Display results
                st.markdown("### Analysis Results")

                # Progress bar for confidence (clamped into [0, 1]).
                st.progress(min(max(confidence / 100, 0.0), 1.0))

                # Verdict based on prediction
                if prediction in ['gpt2', 'llama', 'gpt3re']:
                    st.markdown(f"""
                    <div class='result-box ai-text'>
                        <h3>🤖 Likely AI-Generated ({confidence:.1f}%)</h3>
                        <p>This text shows strong indicators of being generated by {prediction.upper()}.</p>
                    </div>
                    """, unsafe_allow_html=True)
                elif prediction == 'human':
                    st.markdown(f"""
                    <div class='result-box human-text'>
                        <h3>👤 Likely Human-Written ({confidence:.1f}%)</h3>
                        <p>This text appears to be written by a human.</p>
                    </div>
                    """, unsafe_allow_html=True)
                else:
                    # Defensive fallback for any label outside the known set.
                    st.markdown(f"""
                    <div class='result-box'>
                        <h3>⚠️ Uncertain ({confidence:.1f}%)</h3>
                        <p>Unable to confidently determine if this text is AI-generated or human-written.</p>
                    </div>
                    """, unsafe_allow_html=True)
            except Exception as e:
                # Top-level UI boundary: report instead of crashing the app.
                st.error(f"An error occurred during analysis: {str(e)}")
                st.info("Try with a different text or check if the model is properly loaded.")
# Footer
st.markdown("---")
st.markdown("SeqXGPT is an AI text detection tool that analyzes text patterns to identify AI-generated content.")