"""Streamlit front end for AdaDetectGPT.

The page lets a user paste text, pick a domain, run the detector, and then
report whether the result matched their expectation.

Streamlit re-executes this script from top to bottom on every widget
interaction, so anything that has to survive a rerun (the last detection
result, whether feedback was already sent) is kept in ``st.session_state``.
"""
import os
import time
import traceback
from pathlib import Path

# -----------------
# Get the directory where app.py is located
# -----------------
APP_DIR = Path(__file__).parent.resolve()

# True when running inside a Hugging Face Space (the platform sets SPACE_ID).
IS_HF_SPACE = bool(os.environ.get('SPACE_ID'))

# Writable cache location used on Hugging Face Spaces.
CACHE_DIR = '/tmp/huggingface_cache'

# -----------------
# Fix Streamlit Permission Issues
# -----------------
# On an HF Space, point Streamlit and the Hugging Face caches at writable
# locations. This must run BEFORE `import streamlit`, which is why the
# third-party imports below are not at the very top of the file.
if IS_HF_SPACE:
    os.environ['STREAMLIT_SERVER_FILE_WATCHER_TYPE'] = 'none'
    os.environ['STREAMLIT_BROWSER_GATHER_USAGE_STATS'] = 'false'
    os.environ['STREAMLIT_SERVER_ENABLE_CORS'] = 'false'

    # Send every Hugging Face cache to the writable directory.
    os.makedirs(CACHE_DIR, exist_ok=True)
    os.environ['HF_HOME'] = CACHE_DIR
    os.environ['TRANSFORMERS_CACHE'] = CACHE_DIR
    os.environ['HF_DATASETS_CACHE'] = CACHE_DIR
    os.environ['HUGGINGFACE_HUB_CACHE'] = CACHE_DIR

    # Create a writable Streamlit configuration directory.
    streamlit_dir = Path('/tmp/.streamlit')
    streamlit_dir.mkdir(exist_ok=True, parents=True)
    # os.environ['STREAMLIT_HOME'] = '/tmp/.streamlit'

import streamlit as st  # noqa: E402  (must follow the environment setup)

from FineTune.model import ComputeStat  # noqa: E402

# -----------------
# Page Configuration
# -----------------
st.set_page_config(
    page_title="AdaDetectGPT",
    page_icon="🔍",
    layout="wide",
)


# -----------------
# Model Loading (Cached)
# -----------------
def _is_hub_repo_id(path):
    """Return True when *path* looks like a Hub repo id (``user/repo``).

    Anything that exists on disk, or that is written as a relative, absolute
    or home-relative path, is treated as a local checkpoint.
    """
    if os.path.exists(path):
        return False
    return '/' in path and not path.startswith(('.', '/', '~'))


@st.cache_resource
def load_model(from_pretrained, base_model, cache_dir, device):
    """Load the detector once and cache it across reruns and sessions.

    Args:
        from_pretrained: Local checkpoint directory, or a Hugging Face Hub
            repo id (``username/repo-name``) to download first.
        base_model: Base model name passed to ``ComputeStat.from_pretrained``.
        cache_dir: Directory used for downloaded files.
        device: Device string passed to ``ComputeStat.from_pretrained``.

    Returns:
        The loaded ``ComputeStat`` model with its criterion set to ``'mean'``.

    Raises:
        Exception: Whatever the Hub download or the model loader raises; the
            caller shows the message on the page.
    """
    # The Space-specific overrides are deliberately switched off. Re-enable
    # them with: is_hf_space = os.environ.get('SPACE_ID') is not None
    is_hf_space = False
    if is_hf_space:
        cache_dir = CACHE_DIR
        os.makedirs(cache_dir, exist_ok=True)
        device = 'cpu'
        print("Using **CPU** now!")

    # HF token, needed for gated models and private repositories.
    hf_token = os.environ.get('HF_TOKEN', None)
    if hf_token:
        # Logging in is best effort: the token is also passed explicitly to
        # the download call below, so a login failure is only reported.
        try:
            from huggingface_hub import login
            login(token=hf_token)
            print("✅ Successfully authenticated with HF token")
        except Exception as e:
            print(f"⚠️ HF login warning: {e}")

    # Download the checkpoint first when it is given as a Hub repo id.
    if _is_hub_repo_id(from_pretrained):
        from huggingface_hub import snapshot_download

        print(f"📥 Downloading model from HuggingFace Hub: {from_pretrained}")
        try:
            # Fetch the whole repository and load from the local snapshot.
            local_model_path = snapshot_download(
                repo_id=from_pretrained,
                cache_dir=cache_dir,
                token=hf_token,
                repo_type="model",
            )
        except Exception as e:
            print(f"❌ Failed to download model: {e}")
            raise
        print(f"✅ Model downloaded to: {local_model_path}")
        from_pretrained = local_model_path

    with st.spinner("🔄 Loading model... This may take a moment on first launch."):
        model = ComputeStat.from_pretrained(
            from_pretrained,
            base_model,
            device=device,
            cache_dir=cache_dir,
        )
        model.set_criterion_fn('mean')
    return model


# -----------------
# Result Feedback Module Import
# -----------------
from feedback import FeedbackManager  # noqa: E402

# Initialize the feedback manager backed by an HF dataset.
# Set FEEDBACK_DATASET_ID to your own dataset repo id, and make sure HF_TOKEN
# is set in the environment so a private dataset can be accessed.
FEEDBACK_DATASET_ID = os.environ.get('FEEDBACK_DATASET_ID', 'mamba413/user-feedback')
feedback_manager = FeedbackManager(
    dataset_repo_id=FEEDBACK_DATASET_ID,
    hf_token=os.environ.get('HF_TOKEN'),
    local_backup=not IS_HF_SPACE,  # keep a local backup outside of Spaces
)

# -----------------
# Configuration
# -----------------
MODEL_CONFIG = {
    'from_pretrained': './src/FineTune/ckpt/',
    'base_model': 'gemma-1b',
    'cache_dir': '../cache',
    # NOTE(review): 'mps' only exists on Apple-silicon machines; local runs
    # elsewhere need 'cuda' or 'cpu' here.
    'device': 'cpu' if IS_HF_SPACE else 'mps',
    # 'device': 'cuda',
}

DOMAINS = [
    "General",
    "Academia",
    "Finance",
    "Government",
    "Knowledge",
    "Legislation",
    "Medicine",
    "News",
    "UserReview",
]

# Explanation shown under every detection result.
P_VALUE_HELP = (
    "**📊 p-value:**\n"
    "- **Lower p-value** (closer to 0) indicates text is **more likely AI-generated**\n"
    "- **Higher p-value** (closer to 1) indicates text is **more likely human-written**\n"
    "- Generally, p-value < 0.05 suggests the text may be LLM-generated"
)

# Load the model once at startup; a failure is reported on the page below.
model = None
error_message = ""
try:
    model = load_model(
        MODEL_CONFIG['from_pretrained'],
        MODEL_CONFIG['base_model'],
        MODEL_CONFIG['cache_dir'],
        MODEL_CONFIG['device'],
    )
    model_loaded = True
except Exception as e:
    model_loaded = False
    error_message = str(e)

# =========== session_state ===========
# last_detection: dict describing the most recent successful detection.
# feedback_given: True once feedback for that detection has been saved.
if 'last_detection' not in st.session_state:
    st.session_state.last_detection = None
if 'feedback_given' not in st.session_state:
    st.session_state.feedback_given = False
# =====================================


def _submit_feedback(detection, label, notify, thanks):
    """Save one feedback vote for *detection* and report the outcome.

    Args:
        detection: The dict stored in ``st.session_state.last_detection``.
        label: ``'expected'`` or ``'unexpected'``.
        notify: Streamlit call used for the confirmation (e.g. ``st.success``).
        thanks: Confirmation text shown when the vote was saved.
    """
    try:
        success, message = feedback_manager.save_feedback(
            detection['text'],
            detection['domain'],
            detection['statistics'],
            detection['p_value'],
            label,
        )
    except Exception as e:
        st.error(f"Failed to save feedback: {str(e)}")
        st.code(traceback.format_exc())
        return

    if success:
        st.session_state.feedback_given = True
        notify(thanks)
        st.caption(f"💾 {message}")
    else:
        st.error(f"Failed to save feedback: {message}")


# -----------------
# Streamlit Layout
# -----------------
_, col0, _ = st.columns((1, 5, 1))
with col0:
    # NOTE(review): the HTML wrapper of the title is an assumption (a
    # centered <h1>); confirm against the intended design.
    st.markdown(
        "<h1 style='text-align: center;'>"
        "AdaDetectGPT: Adaptive LLM's Texts Detection"
        "</h1>",
        unsafe_allow_html=True,
    )
    st.markdown(
        """Pasted the text to be detected below and click the 'Detect' button to get the p-value. Use a better option may improve detection."""
    )

# Display model loading status; nothing below can work without the model.
if not model_loaded:
    st.error(f"❌ Failed to load model: {error_message}")
    st.stop()

# -----------------
# Main Interface
# -----------------
# --- Two columns: Input text & button | Result displays ---
col1, col2 = st.columns((1, 1))

with col1:
    text_input = st.text_area(
        # An empty label triggers a Streamlit accessibility warning; give the
        # widget a real label and hide it instead.
        label="Text to detect",
        label_visibility="collapsed",
        placeholder="Paste your text to be detected here",
        help="Typically, providing text with a longer content would get a more reliable result.",
        height=200,
    )

    detect_clicked = st.button("Detect", type="primary", use_container_width=True)

    selected_domain = st.selectbox(
        label="⚙️ Domain (Optional)",
        options=DOMAINS,
        index=0,  # Default to General
        help="💡 **Tip:** Select the domain that best matches your text for improving detection accuracy. Default is 'General' that means consider all domains.",
    )

with col2:
    # Placeholders so the empty boxes can be replaced once results exist.
    statistics_ph = st.empty()
    statistics_ph.text_input(
        label="Statistics",
        value="",
        disabled=True,
        help="Statistics will appear here after clicking the Detect button.",
    )

    pvalue_ph = st.empty()
    pvalue_ph.text_input(
        label="p-value",
        value="",
        disabled=True,
        help="p-value will appear here after clicking the Detect button.",
    )

# -----------------
# Detection Logic
# -----------------
# Only the inference runs inside the click branch. The result is stored in
# session_state and rendered further down, because the feedback buttons cause
# a rerun in which `detect_clicked` is False again.
if detect_clicked:
    if not text_input.strip():
        st.warning("⚠️ Please enter some text before detecting.")
    else:
        # A new detection starts with no feedback recorded.
        st.session_state.feedback_given = False

        start_time = time.time()
        status_placeholder = st.empty()
        try:
            with status_placeholder:
                with st.spinner(f"🔍 Analyzing text in {selected_domain} domain..."):
                    crit, p_value = model.compute_p_value(text_input, selected_domain)
        except Exception as e:
            status_placeholder.empty()
            # Do not keep showing a stale result next to the error.
            st.session_state.last_detection = None
            st.error(f"❌ Error during detection: {str(e)}")
            st.exception(e)
        else:
            elapsed_time = time.time() - start_time
            status_placeholder.empty()

            # Convert tensors to Python scalars if needed.
            if hasattr(crit, 'item'):
                crit = crit.item()
            if hasattr(p_value, 'item'):
                p_value = p_value.item()

            st.session_state.last_detection = {
                'text': text_input,
                'domain': selected_domain,
                'statistics': crit,
                'p_value': p_value,
                'elapsed_time': elapsed_time,
            }

# -----------------
# Results & Feedback
# -----------------
detection = st.session_state.last_detection
if detection is not None:
    # Replace the empty score boxes with the stored values.
    statistics_ph.text_input(
        label="Statistics",
        value=f"{detection['statistics']:.6f}",
        disabled=True,
        help="Detection statistics will appear here after clicking Detect.",
    )
    pvalue_ph.text_input(
        label="p-value",
        value=f"{detection['p_value']:.6f}",
        disabled=True,
        help="p-value will appear here after clicking Detect.",
    )

    st.caption(f"⏱️ Processing time: {detection['elapsed_time']:.2f} seconds")

    st.info(P_VALUE_HELP, icon="💡")

    # ========== Feedback buttons ==========
    st.markdown("**📝 Result Feedback**: Does this detection result meet your expectations?")

    # Snapshot the flag before the buttons run: a click in this very run
    # shows its own confirmation, later reruns show the generic one.
    already_submitted = st.session_state.feedback_given

    feedback_col1, feedback_col2 = st.columns(2)
    with feedback_col1:
        if st.button(
            "✅ Expected",
            use_container_width=True,
            type="secondary",
            key="expected_btn",
            disabled=already_submitted,
        ):
            _submit_feedback(
                detection, 'expected', st.success,
                "✅ Thank you for your feedback!",
            )
    with feedback_col2:
        if st.button(
            "❌ Unexpected",
            use_container_width=True,
            type="secondary",
            key="unexpected_btn",
            disabled=already_submitted,
        ):
            _submit_feedback(
                detection, 'unexpected', st.warning,
                "❌ Feedback recorded! This will help us improve.",
            )

    if already_submitted:
        st.success("✅ Feedback submitted successfully!")
    # ======================================

# ========== Citation and paper reference section (currently disabled) ==========
# st.markdown("---")
# st.markdown(
#     """
#     📄 **Citation** If you find this tool useful for you, please cite our paper: **[AdaDetectGPT: Adaptive Detection of LLM-Generated Text with Statistical Guarantees](https://arxiv.org/abs/2510.01268)**
#     """
# )
# with st.expander("📋 BibTeX Citation"):
#     st.code(
#         """
#         @inproceedings{zhou2024adadetectgpt,
#           title={AdaDetectGPT: Adaptive Detection of LLM-Generated Text with Statistical Guarantees},
#           author={Hongyi Zhou and Jin Zhu and Pingfan Su and Kai Ye and Ying Yang and Shakeel A O B Gavioli-Akilagun and Chengchun Shi},
#           booktitle={The Thirty-Ninth Annual Conference on Neural Information Processing Systems (Accepted)},
#           year={2025},
#         }
#         """,
#         language="bibtex"
#     )

# -----------------
# Footer
# -----------------
# NOTE(review): the footer markup is blank in this file; restore the
# intended HTML here if a footer is wanted.
st.markdown(
    """ """,
    unsafe_allow_html=True,
)