# src/app.py — Streamlit app for detecting AI-generated text.
import os
from pathlib import Path

# -----------------
# Resolve the directory that contains this app.py file.
# -----------------
APP_DIR = Path(__file__).parent.resolve()
account_name = 'mamba413'

# -----------------
# Fix Streamlit permission issues
# -----------------
# On a Hugging Face Space the home directory is read-only, so redirect every
# cache/config path to a writable location under /tmp before Streamlit and
# the HuggingFace libraries are imported.
if os.environ.get('SPACE_ID'):
    os.environ['STREAMLIT_SERVER_FILE_WATCHER_TYPE'] = 'none'
    os.environ['STREAMLIT_BROWSER_GATHER_USAGE_STATS'] = 'false'
    os.environ['STREAMLIT_SERVER_ENABLE_CORS'] = 'false'

    # Point every HuggingFace cache variable at the same writable directory.
    CACHE_DIR = '/tmp/huggingface_cache'
    os.makedirs(CACHE_DIR, exist_ok=True)
    for _cache_var in ('HF_HOME', 'TRANSFORMERS_CACHE',
                       'HF_DATASETS_CACHE', 'HUGGINGFACE_HUB_CACHE'):
        os.environ[_cache_var] = CACHE_DIR

    # Make sure a writable Streamlit config directory exists.
    streamlit_dir = Path('/tmp/.streamlit')
    streamlit_dir.mkdir(exist_ok=True, parents=True)
    # os.environ['STREAMLIT_HOME'] = '/tmp/.streamlit'
import streamlit as st
from FineTune.model import ComputeStat
import time
# Base widget styling: light background and subtle border for the text area,
# text inputs, and the domain selectbox, with a muted placeholder color.
_WIDGET_CSS = """
<style>
/* Text area & text input */
textarea, input[type="text"] {
    background-color: #f8fafc !important;
    border: 1px solid #e5e7eb !important;
    color: #111827 !important;
}
textarea::placeholder {
    color: #9ca3af !important;
}
/* Selectbox */
div[data-testid="stSelectbox"] > div {
    background-color: #f8fafc !important;
    border: 1px solid #e5e7eb !important;
}
</style>
"""
st.markdown(_WIDGET_CSS, unsafe_allow_html=True)
# Styling for the primary "Detect" button: orange fill, tall hit area,
# centered icon + label, and a darker orange on hover/active.
_DETECT_BUTTON_CSS = """
<style>
/* Detect button */
div.stButton > button[kind="primary"] {
    background-color: #fdae6b;
    border: white;
    color: black;
    font-weight: 600;
    height: 4.3rem;
    font-size: 1.1rem;
    display: flex;
    align-items: center;
    justify-content: center;
    gap: 0.55rem;
}
/* Icon inside Detect button */
div.stButton > button[kind="primary"] span {
    font-size: 1.25rem;
    line-height: 1;
}
div.stButton > button[kind="primary"]:hover {
    background-color: #fd8d3c;
    border-color: white;
}
div.stButton > button[kind="primary"]:active {
    background-color: #fd8d3c;
    border-color: white;
}
</style>
"""
st.markdown(_DETECT_BUTTON_CSS, unsafe_allow_html=True)
# -----------------
# Page Configuration
# -----------------
# Must run before any other Streamlit UI call in the script.
st.set_page_config(page_title="DetectGPTPro", page_icon="🕵️")
# -----------------
# Model Loading (Cached)
# -----------------
@st.cache_resource
def load_model(from_pretrained, base_model, cache_dir, device):
    """Load and cache the detection model.

    Wrapped in ``st.cache_resource`` so the expensive load runs only once
    per argument combination instead of on every user interaction.

    Args:
        from_pretrained: Local checkpoint path, or a HuggingFace Hub repo id
            of the form ``username/repo-name`` (only used when the HF-Space
            branch below is enabled).
        base_model: Base model identifier forwarded to ``ComputeStat``.
        cache_dir: Directory used for model downloads/caching.
        device: Device string forwarded to ``ComputeStat`` (e.g. 'cpu', 'mps').

    Returns:
        A ``ComputeStat`` instance with the 'mean' criterion selected.
    """
    # NOTE(review): the HF-Space branch is deliberately disabled; restore the
    # commented check to re-enable CPU + /tmp-cache behavior on Spaces.
    # is_hf_space = os.environ.get('SPACE_ID') is not None
    is_hf_space = False
    if is_hf_space:
        cache_dir = '/tmp/huggingface_cache'
        os.makedirs(cache_dir, exist_ok=True)
        device = 'cpu'
        print("Using **CPU** now!")

        # An HF token is required to access gated/private models.
        hf_token = os.environ.get('HF_TOKEN', None)
        if hf_token:
            try:
                from huggingface_hub import login
                login(token=hf_token)
                print("✅ Successfully authenticated with HF token")
            except Exception as e:
                print(f"⚠️ HF login warning: {e}")

        # Treat 'username/repo-name' (anything with '/' and no leading '.')
        # as a HuggingFace Hub repo id and download it locally first.
        is_hf_hub = '/' in from_pretrained and not from_pretrained.startswith('.')
        if is_hf_hub:
            from huggingface_hub import snapshot_download
            print(f"📥 Downloading model from HuggingFace Hub: {from_pretrained}")
            try:
                # Download the whole repository into cache_dir.
                local_model_path = snapshot_download(
                    repo_id=from_pretrained,
                    cache_dir=cache_dir,
                    token=hf_token,
                    repo_type="model"
                )
                print(f"✅ Model downloaded to: {local_model_path}")
                # Load from the downloaded local path instead of the repo id.
                from_pretrained = local_model_path
            except Exception as e:
                print(f"❌ Failed to download model: {e}")
                raise

    with st.spinner("🔄 Loading model... This may take a moment on first launch."):
        model = ComputeStat.from_pretrained(
            from_pretrained,
            base_model,
            device=device,
            cache_dir=cache_dir
        )
    model.set_criterion_fn('mean')
    return model
# -----------------
# Result Feedback Module Import
# -----------------
from feedback import FeedbackManager

# Feedback is appended to a HuggingFace dataset repository.
# Override the target repo via the FEEDBACK_DATASET_ID environment variable;
# set HF_TOKEN in the environment to access a private dataset.
FEEDBACK_DATASET_ID = os.environ.get('FEEDBACK_DATASET_ID', f'{account_name}/user-feedback')
feedback_manager = FeedbackManager(
    dataset_repo_id=FEEDBACK_DATASET_ID,
    hf_token=os.environ.get('HF_TOKEN'),
    # Keep a local backup only when not on a HF Space (Space disks are
    # ephemeral, so a local copy would be pointless there).
    local_backup=not os.environ.get('SPACE_ID'),
)
# -----------------
# Configuration
# -----------------
# Arguments passed to load_model() at startup.
MODEL_CONFIG = {
    'from_pretrained': './src/FineTune/ckpt/',
    'base_model': 'gemma-1b',
    'cache_dir': '../cache',
    # CPU on HF Spaces; Apple 'mps' when running locally.
    'device': 'cpu' if os.environ.get('SPACE_ID') else 'mps',
    # 'device': 'cuda',
}

# Text domains offered in the selectbox; "General" is the default option.
DOMAINS = [
    "General",
    "Academia",
    "Finance",
    "Government",
    "Knowledge",
    "Legislation",
    "Medicine",
    "News",
    "UserReview",
]
# Load the model once at startup; remember any failure so the UI can
# surface it instead of crashing the whole script.
try:
    model = load_model(
        MODEL_CONFIG['from_pretrained'],
        MODEL_CONFIG['base_model'],
        MODEL_CONFIG['cache_dir'],
        MODEL_CONFIG['device'],
    )
except Exception as e:
    model_loaded = False
    error_message = str(e)
else:
    model_loaded = True

# Session-state defaults used by the detection/feedback workflow.
for _state_key, _state_default in (('last_detection', None),
                                   ('feedback_given', False)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# -----------------
# Streamlit Layout
# -----------------
_TITLE_HTML = "<h1 style='text-align: center;'> Detect AI-Generated Texts 🕵️ </h1>"
st.markdown(_TITLE_HTML, unsafe_allow_html=True)

# Abort early with a visible error if the model failed to load at startup.
if not model_loaded:
    st.error(f"❌ Failed to load model: {error_message}")
    st.stop()
# -----------------
# Main Interface
# -----------------
# Input text area plus a row of three controls: domain selector,
# detect button, and significance-level slider.
text_input = st.text_area(
    label="📝 Input Text to be Detected",
    placeholder="Paste your text here",
    height=240,
    label_visibility="hidden",
)

col_domain, col_detect, col_alpha = st.columns((1, 1, 1))
selected_domain = col_domain.selectbox(
    label="💡 Domain that matches your text",
    options=DOMAINS,
    index=0,  # default to "General"
)
detect_clicked = col_detect.button(
    "🔍 Detect", type="primary", use_container_width=True
)
selected_level = col_alpha.slider(
    label="Significance level (α)",
    min_value=0.01,
    max_value=0.2,
    value=0.05,
    step=0.005,
)
# col2, col3, col4 = st.columns((1, 1, 2))
# with col2:
# statistics_ph = st.empty()
# statistics_ph.text_input(
# label="Statistic",
# value="",
# placeholder="",
# disabled=True,
# )
# with col3:
# pvalue_ph = st.empty()
# pvalue_ph.text_input(
# label="p-value",
# value="",
# placeholder="",
# disabled=True,
# )
# with col4:
# conclusion_ph = st.empty()
# conclusion_ph.text_input(
# label="Conclusion",
# value="",
# placeholder="",
# disabled=True,
# )
# -----------------
# Detection Logic
# -----------------
# Runs only on the rerun triggered by the Detect button: performs inference,
# shows the conclusion, and offers expected/unexpected feedback buttons.
if detect_clicked:
    if not text_input.strip():
        st.warning("⚠️ Please enter some text before detecting.")
    else:
        # ========== Reset feedback state for this new detection ==========
        st.session_state.feedback_given = False
        # =================================================================
        # Start timing to decide whether to show progress bar
        start_time = time.time()
        # Use placeholders so status/results can be updated in place.
        status_placeholder = st.empty()
        result_placeholder = st.empty()
        try:
            # Show spinner for quick operations (< 2 seconds expected)
            with status_placeholder:
                with st.spinner(f"🔍 Analyzing text in {selected_domain} domain..."):
                    # Perform inference: returns the test statistic and p-value.
                    crit, p_value = model.compute_p_value(text_input, selected_domain)
                    elapsed_time = time.time() - start_time
            # Convert tensors to Python scalars if needed
            if hasattr(crit, 'item'):
                crit = crit.item()
            if hasattr(p_value, 'item'):
                p_value = p_value.item()
            # Clear status and show results
            status_placeholder.empty()
            # Save the detection result into session_state so it survives the
            # rerun that a feedback-button click triggers.
            st.session_state.last_detection = {
                'text': text_input,
                'domain': selected_domain,
                'statistics': crit,
                'p_value': p_value,
                'elapsed_time': elapsed_time
            }
            # Human-readable hypothesis-test conclusion.
            st.info(
                f"""
**Conclusion**:
{'Text is likely LLM-generated.' if p_value < selected_level else 'Fail to reject hypothesis that text is human-written.'}
based on the observation that $p$-value {p_value:.3f} is {'less' if p_value < selected_level else 'greater'} than significance level {selected_level:.2f} 📊
""",
                icon="💡"
            )
            st.markdown(
                """
<style>
/* Tighten spacing inside Clarification / Citation expanders */
div[data-testid="stExpander"] {
    margin-top: -1.3rem;
}
div[data-testid="stExpander"] p,
div[data-testid="stExpander"] li {
    line-height: 1.35;
    margin-bottom: 0.1rem;
}
div[data-testid="stExpander"] ul {
    margin-top: 0.1rem;
}
</style>
""",
                unsafe_allow_html=True
            )
            with st.expander("📋 Interpretation and Suggestions"):
                st.markdown(
                    """
+ Interpretation:
    - $p$-value: Lower $p$-value (closer to 0) indicates text is **more likely AI-generated**; Higher $p$-value (closer to 1) indicates text is **more likely human-written**.
    - Significance Level (α): a threshold set by the user to determine the sensitivity of the detection. Lower α means stricter criteria for claiming the text is AI-generated.
+ Suggestions for better detection:
    - Provide longer text inputs for more reliable detection results.
    - Select the domain that best matches the content of your text to improve detection accuracy.
"""
                )
            # ========== Feedback buttons (moved here for better UX) ==========
            # NOTE(review): these buttons live inside `if detect_clicked:`;
            # clicking one triggers a rerun in which detect_clicked is False,
            # so the save branches below may never execute on the click —
            # confirm, and consider rendering this section from
            # st.session_state.last_detection instead.
            st.markdown("**📝 Result Feedback**: Does this detection result meet your expectations?")
            current_text = text_input
            current_domain = selected_domain
            current_statistics = crit
            current_pvalue = p_value
            feedback_col1, feedback_col2 = st.columns(2)
            with feedback_col1:
                if st.button("✅ Expected", use_container_width=True, type="secondary", key=f"expected_btn_{hash(text_input[:50])}"):
                    try:
                        success, message = feedback_manager.save_feedback(
                            current_text,
                            current_domain,
                            current_statistics,
                            current_pvalue,
                            'expected'
                        )
                        if success:
                            st.success("✅ Thank you for your feedback!")
                            st.caption(f"💾 {message}")
                        else:
                            st.error(f"Failed to save feedback: {message}")
                    except Exception as e:
                        st.error(f"Failed to save feedback: {str(e)}")
                        import traceback
                        st.code(traceback.format_exc())
            with feedback_col2:
                if st.button("❌ Unexpected", use_container_width=True, type="secondary", key=f"unexpected_btn_{hash(text_input[:50])}"):
                    try:
                        success, message = feedback_manager.save_feedback(
                            current_text,
                            current_domain,
                            current_statistics,
                            current_pvalue,
                            'unexpected'
                        )
                        if success:
                            st.warning("❌ Feedback recorded! This will help us improve.")
                            st.caption(f"💾 {message}")
                        else:
                            st.error(f"Failed to save feedback: {message}")
                    except Exception as e:
                        st.error(f"Failed to save feedback: {str(e)}")
                        import traceback
                        st.code(traceback.format_exc())
            if st.session_state.feedback_given:
                st.success("✅ Feedback submitted successfully!")
            # ============================================
            # Show detailed results
            with result_placeholder:
                st.caption(f"⏱️ Processing time: {elapsed_time:.2f} seconds")
        except Exception as e:
            status_placeholder.empty()
            st.error(f"❌ Error during detection: {str(e)}")
            st.exception(e)
# st.markdown("<div style='height: 4rem;'></div>", unsafe_allow_html=True)
# st.markdown(
# """
# <style>
# /* Tighten spacing inside Clarification / Citation expanders */
# div[data-testid="stExpander"] p,
# div[data-testid="stExpander"] li {
# line-height: 1.35;
# margin-bottom: 0.3rem;
# }
# div[data-testid="stExpander"] ul {
# margin-top: 0.3rem;
# margin-bottom: 0.3rem;
# }
# </style>
# """,
# unsafe_allow_html=True
# )
# with st.expander("📋 Illustration and Suggestions"):
# st.markdown(
# """
# + Illustration:
# - Statistic: a numerical measure computed from the input text. Higher values typically indicate a greater likelihood of being AI-generated.
# - p-value: Lower p-value (closer to 0) indicates text is **more likely AI-generated**; Higher p-value (closer to 1) indicates text is **more likely human-written**.
# - Conclusion: when the p-value is less than the significance level (α), the text is classified as AI-generated.
# - Significance Level (α): a threshold set by the user to determine the sensitivity of the detection. Lower α means stricter criteria for claiming the text is AI-generated.
# + Suggestions for better detection:
# - Provide longer text inputs for more reliable detection results.
# - Select the domain that best matches the content of your text to improve detection accuracy.
# """
# )
# with st.expander("📋 Citation"):
# st.markdown(
# """
# If you find this tool useful for you, please cite our paper: **[AdaDetectGPT: Adaptive Detection of LLM-Generated Text with Statistical Guarantees](https://arxiv.org/abs/2510.01268)**
# """
# )
# st.code(
# """
# @inproceedings{zhou2024adadetectgpt,
# title={AdaDetectGPT: Adaptive Detection of LLM-Generated Text with Statistical Guarantees},
# author={Hongyi Zhou and Jin Zhu and Pingfan Su and Kai Ye and Ying Yang and Shakeel A O B Gavioli-Akilagun and Chengchun Shi},
# booktitle={The Thirty-Ninth Annual Conference on Neural Information Processing Systems},
# year={2025},
# }
# """,
# language="bibtex"
# )
# -----------------
# Footer
# -----------------
# Fixed disclaimer bar pinned to the bottom of the viewport.
_FOOTER_HTML = """
<style>
.footer {
    position: fixed;
    left: 0;
    bottom: 0;
    width: 100%;
    background-color: white;
    color: gray;
    text-align: center;
    padding: 1px;
    border-top: 1px solid #e0e0e0;
    z-index: 999;
}
/* Add padding to main content to prevent overlap with fixed footer */
.main .block-container {
    padding-bottom: 1px;
}
</style>
<div class='footer'>
<small> This tool is developed for research purposes only. The detection results are not 100% accurate and should not be used as the sole basis for any critical decisions. Users are advised to use this tool responsibly and ethically. </small>
</div>
"""
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)