"""Streamlit front end for DetectGPTPro: detect AI-generated text with a p-value.

The script redirects caches/config to writable locations when running inside a
Hugging Face Space, loads the fine-tuned ``ComputeStat`` detector once (cached
by ``st.cache_resource``), and renders a text box, a domain selector, a
significance-level slider, the detection result and a feedback widget.
"""

import os
import time
import traceback
from pathlib import Path

# -----------------
# Get the directory where app.py is located
# -----------------
APP_DIR = Path(__file__).parent.resolve()
account_name = 'mamba413'

# -----------------
# Fix Streamlit Permission Issues
# -----------------
# Inside a HF Space (SPACE_ID is set) redirect Streamlit settings and all
# Hugging Face caches to writable locations. This runs before streamlit and the
# model code are imported so the settings are in place when those modules load.
if os.environ.get('SPACE_ID'):
    os.environ['STREAMLIT_SERVER_FILE_WATCHER_TYPE'] = 'none'
    os.environ['STREAMLIT_BROWSER_GATHER_USAGE_STATS'] = 'false'
    os.environ['STREAMLIT_SERVER_ENABLE_CORS'] = 'false'

    # Point the Hugging Face caches at a writable directory.
    CACHE_DIR = '/tmp/huggingface_cache'
    os.makedirs(CACHE_DIR, exist_ok=True)
    os.environ['HF_HOME'] = CACHE_DIR
    os.environ['TRANSFORMERS_CACHE'] = CACHE_DIR
    os.environ['HF_DATASETS_CACHE'] = CACHE_DIR
    os.environ['HUGGINGFACE_HUB_CACHE'] = CACHE_DIR

    # Writable Streamlit config directory.
    # NOTE(review): the directory is created but STREAMLIT_HOME is not set, so
    # nothing in this file currently points Streamlit at it.
    streamlit_dir = Path('/tmp/.streamlit')
    streamlit_dir.mkdir(exist_ok=True, parents=True)
    # os.environ['STREAMLIT_HOME'] = '/tmp/.streamlit'

import streamlit as st

from FineTune.model import ComputeStat
from feedback import FeedbackManager

# -----------------
# Page Configuration
# -----------------
# set_page_config must come before every other Streamlit command (older
# Streamlit releases raise StreamlitAPIException otherwise), so it runs before
# the markdown calls below.
st.set_page_config(
    page_title="DetectGPTPro",
    page_icon="🕵️",
)

# NOTE(review): these markdown blocks are empty in this copy of the file; they
# look like placeholders for custom CSS — restore the styles or remove them.
st.markdown(
    """ """,
    unsafe_allow_html=True
)
st.markdown(
    """ """,
    unsafe_allow_html=True
)


# -----------------
# Model Loading (Cached)
# -----------------
@st.cache_resource
def load_model(from_pretrained, base_model, cache_dir, device):
    """Load and cache the detector so it is not reloaded on every interaction.

    ``st.cache_resource`` keeps the result per distinct argument tuple, so the
    body runs once at startup and again only when the arguments change.

    Args:
        from_pretrained: A local checkpoint path, or a Hugging Face Hub repo id
            of the form ``username/repo-name`` that is downloaded first.
        base_model: Base-model identifier forwarded to
            ``ComputeStat.from_pretrained``.
        cache_dir: Cache directory forwarded to the Hub download and the loader.
        device: Device string forwarded to the loader (e.g. ``'cpu'``).

    Returns:
        The loaded ``ComputeStat`` model with its criterion set to ``'mean'``.

    Raises:
        Exception: Any failure from ``snapshot_download`` is printed and re-raised.
    """
    # Manual switch: set to True to force the Space cache directory and CPU.
    # is_hf_space = os.environ.get('SPACE_ID') is not None
    is_hf_space = False
    if is_hf_space:
        cache_dir = '/tmp/huggingface_cache'
        os.makedirs(cache_dir, exist_ok=True)
        device = 'cpu'
        print("Using **CPU** now!")

    # HF token, used to access gated models / private repos.
    hf_token = os.environ.get('HF_TOKEN', None)
    if hf_token:
        # Best-effort global login; the token is also passed explicitly below.
        try:
            from huggingface_hub import login
            login(token=hf_token)
            print("✅ Successfully authenticated with HF token")
        except Exception as e:
            print(f"⚠️ HF login warning: {e}")

    # Treat the argument as a Hub repo id ("username/repo-name") only if it is
    # neither an existing local path nor an explicit relative/absolute path.
    is_hf_hub = (
        '/' in from_pretrained
        and not from_pretrained.startswith(('.', '/'))
        and not os.path.exists(from_pretrained)
    )
    if is_hf_hub:
        from huggingface_hub import snapshot_download
        print(f"📥 Downloading model from HuggingFace Hub: {from_pretrained}")
        try:
            # Download the whole repository and load from the local snapshot.
            local_model_path = snapshot_download(
                repo_id=from_pretrained,
                cache_dir=cache_dir,
                token=hf_token,
                repo_type="model",
            )
        except Exception as e:
            print(f"❌ Failed to download model: {e}")
            raise
        print(f"✅ Model downloaded to: {local_model_path}")
        from_pretrained = local_model_path

    with st.spinner("🔄 Loading model... This may take a moment on first launch."):
        model = ComputeStat.from_pretrained(
            from_pretrained,
            base_model,
            device=device,
            cache_dir=cache_dir,
        )
    model.set_criterion_fn('mean')
    return model


# -----------------
# Result Feedback
# -----------------
# Feedback goes to a HF dataset repo; override it with FEEDBACK_DATASET_ID.
# HF_TOKEN must be set in the environment to access a private dataset.
FEEDBACK_DATASET_ID = os.environ.get('FEEDBACK_DATASET_ID', f'{account_name}/user-feedback')
feedback_manager = FeedbackManager(
    dataset_repo_id=FEEDBACK_DATASET_ID,
    hf_token=os.environ.get('HF_TOKEN'),
    local_backup=not os.environ.get('SPACE_ID'),  # keep a local backup except on HF Spaces
)

# -----------------
# Configuration
# -----------------
MODEL_CONFIG = {
    'from_pretrained': './src/FineTune/ckpt/',
    'base_model': 'gemma-1b',
    'cache_dir': '../cache',
    'device': 'cpu' if os.environ.get('SPACE_ID') else 'mps',
    # 'device': 'cuda',
}

DOMAINS = [
    "General", "Academia", "Finance", "Government", "Knowledge",
    "Legislation", "Medicine", "News", "UserReview",
]

# Load model once at startup
try:
    model = load_model(
        MODEL_CONFIG['from_pretrained'],
        MODEL_CONFIG['base_model'],
        MODEL_CONFIG['cache_dir'],
        MODEL_CONFIG['device'],
    )
    model_loaded = True
except Exception as e:
    model_loaded = False
    error_message = str(e)

# -----------------
# Session state
# -----------------
# last_detection: dict describing the most recent successful detection, or None.
# feedback_given: whether feedback was already submitted for last_detection.
# detection_count: monotonically increasing id used to key feedback buttons.
if 'last_detection' not in st.session_state:
    st.session_state.last_detection = None
if 'feedback_given' not in st.session_state:
    st.session_state.feedback_given = False
if 'detection_count' not in st.session_state:
    st.session_state.detection_count = 0


def _save_feedback(detection, label):
    """Persist feedback for ``detection`` and report the outcome in the UI.

    Args:
        detection: The record stored in ``st.session_state.last_detection``.
        label: ``'expected'`` or ``'unexpected'``.
    """
    try:
        success, message = feedback_manager.save_feedback(
            detection['text'],
            detection['domain'],
            detection['statistics'],
            detection['p_value'],
            label,
        )
    except Exception as e:
        st.error(f"Failed to save feedback: {str(e)}")
        st.code(traceback.format_exc())
        return
    if not success:
        st.error(f"Failed to save feedback: {message}")
        return
    st.session_state.feedback_given = True
    if label == 'expected':
        st.success("✅ Thank you for your feedback!")
    else:
        st.warning("❌ Feedback recorded! This will help us improve.")
    st.caption(f"💾 {message}")


def _render_feedback(detection):
    """Render the Expected/Unexpected feedback buttons for ``detection``.

    The buttons are rendered on every rerun while the detection is shown, so a
    click is handled on the rerun it triggers; they are disabled once feedback
    has been submitted for this detection.
    """
    st.markdown("**📝 Result Feedback**: Does this detection result meet your expectations?")
    already_given = st.session_state.feedback_given
    feedback_col1, feedback_col2 = st.columns(2)
    with feedback_col1:
        if st.button("✅ Expected", use_container_width=True, type="secondary",
                     key=f"expected_btn_{detection['id']}", disabled=already_given):
            _save_feedback(detection, 'expected')
    with feedback_col2:
        if st.button("❌ Unexpected", use_container_width=True, type="secondary",
                     key=f"unexpected_btn_{detection['id']}", disabled=already_given):
            _save_feedback(detection, 'unexpected')
    if already_given:
        st.success("✅ Feedback submitted successfully!")


def _render_detection(detection, alpha):
    """Render the conclusion, interpretation and feedback for ``detection``.

    Args:
        detection: The record stored in ``st.session_state.last_detection``.
        alpha: Significance level the p-value is compared against.
    """
    p_value = detection['p_value']
    rejected = p_value < alpha
    conclusion = ('Text is likely LLM-generated.' if rejected
                  else 'Fail to reject hypothesis that text is human-written.')
    comparison = 'less than' if rejected else 'greater than or equal to'

    st.caption(f"⏱️ Processing time: {detection['elapsed_time']:.2f} seconds")
    st.info(
        f"""
**Conclusion**: {conclusion}
based on the observation that $p$-value {p_value:.3f} is {comparison} significance level {alpha:g} 📊
""",
        icon="💡"
    )
    # NOTE(review): empty markdown block — likely a placeholder for custom CSS.
    st.markdown(
        """ """,
        unsafe_allow_html=True
    )
    with st.expander("📋 Interpretation and Suggestions"):
        st.markdown(
            """
+ Interpretation:
    - $p$-value: Lower $p$-value (closer to 0) indicates text is **more likely AI-generated**; Higher $p$-value (closer to 1) indicates text is **more likely human-written**.
    - Significance Level (α): a threshold set by the user to determine the sensitivity of the detection. Lower α means stricter criteria for claiming the text is AI-generated.
+ Suggestions for better detection:
    - Provide longer text inputs for more reliable detection results.
    - Select the domain that best matches the content of your text to improve detection accuracy.
"""
        )
    _render_feedback(detection)


# -----------------
# Streamlit Layout
# -----------------
# NOTE(review): the original title markup was not recoverable from this copy;
# this centered <h1> is a stand-in — confirm the intended HTML.
st.markdown(
    "<h1 style='text-align: center;'>Detect AI-Generated Texts 🕵️</h1>",
    unsafe_allow_html=True,
)

# Display model loading status
if not model_loaded:
    st.error(f"❌ Failed to load model: {error_message}")
    st.stop()

# -----------------
# Main Interface
# -----------------
text_input = st.text_area(
    label="📝 Input Text to be Detected",
    placeholder="Paste your text here",
    height=240,
    label_visibility="hidden",
)

subcol11, subcol12, subcol13 = st.columns((1, 1, 1))
selected_domain = subcol11.selectbox(
    label="💡 Domain that matches your text",
    options=DOMAINS,
    index=0,  # Default to General
)
detect_clicked = subcol12.button("🔍 Detect", type="primary", use_container_width=True)
selected_level = subcol13.slider(
    label="Significance level (α)",
    min_value=0.01,
    max_value=0.2,
    value=0.05,
    step=0.005,
)

# -----------------
# Detection Logic
# -----------------
if detect_clicked:
    if not text_input.strip():
        st.warning("⚠️ Please enter some text before detecting.")
    else:
        # A new detection resets the feedback state.
        st.session_state.feedback_given = False
        start_time = time.time()
        try:
            with st.spinner(f"🔍 Analyzing text in {selected_domain} domain..."):
                crit, p_value = model.compute_p_value(text_input, selected_domain)
            elapsed_time = time.time() - start_time
            # Convert tensors to Python scalars if needed
            if hasattr(crit, 'item'):
                crit = crit.item()
            if hasattr(p_value, 'item'):
                p_value = p_value.item()
        except Exception as e:
            st.session_state.last_detection = None
            st.error(f"❌ Error during detection: {str(e)}")
            st.exception(e)
        else:
            st.session_state.detection_count += 1
            st.session_state.last_detection = {
                'id': st.session_state.detection_count,
                'text': text_input,
                'domain': selected_domain,
                'statistics': crit,
                'p_value': p_value,
                'elapsed_time': elapsed_time,
            }

# Results are rendered from session_state (not inside ``if detect_clicked``) so
# they survive the rerun triggered by a feedback click or a slider change. They
# are hidden once the text or domain no longer matches what was analysed.
detection = st.session_state.last_detection
if (
    detection is not None
    and detection['text'] == text_input
    and detection['domain'] == selected_domain
):
    _render_detection(detection, selected_level)

# with st.expander("📋 Citation"):
#     st.markdown(
#         """
#         If you find this tool useful for you, please cite our paper: **[AdaDetectGPT: Adaptive Detection of LLM-Generated Text with Statistical Guarantees](https://arxiv.org/abs/2510.01268)**
#         """
#     )
#     st.code(
#         """
# @inproceedings{zhou2024adadetectgpt,
#   title={AdaDetectGPT: Adaptive Detection of LLM-Generated Text with Statistical Guarantees},
#   author={Hongyi Zhou and Jin Zhu and Pingfan Su and Kai Ye and Ying Yang and Shakeel A O B Gavioli-Akilagun and Chengchun Shi},
#   booktitle={The Thirty-Ninth Annual Conference on Neural Information Processing Systems},
#   year={2025},
# }
# """,
#         language="bibtex"
#     )

# -----------------
# Footer
# -----------------
# NOTE(review): empty footer markup in this copy of the file — restore or remove.
st.markdown(
    """ """,
    unsafe_allow_html=True
)