|
|
import os |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Absolute path of the directory that contains this script.
APP_DIR = Path(__file__).parent.resolve()

# NOTE(review): all of the setup below is assumed to apply only when the
# SPACE_ID environment variable is set (Hugging Face Spaces) — confirm.
if os.environ.get('SPACE_ID'):
    # Streamlit server options, supplied through environment variables.
    os.environ.update({
        'STREAMLIT_SERVER_FILE_WATCHER_TYPE': 'none',
        'STREAMLIT_BROWSER_GATHER_USAGE_STATS': 'false',
        'STREAMLIT_SERVER_ENABLE_CORS': 'false',
    })

    # Point every Hugging Face cache variable at one directory under /tmp.
    CACHE_DIR = '/tmp/huggingface_cache'
    os.makedirs(CACHE_DIR, exist_ok=True)
    os.environ.update(dict.fromkeys(
        (
            'HF_HOME',
            'TRANSFORMERS_CACHE',
            'HF_DATASETS_CACHE',
            'HUGGINGFACE_HUB_CACHE',
        ),
        CACHE_DIR,
    ))

    # Make sure a Streamlit directory exists under /tmp as well.
    streamlit_dir = Path('/tmp/.streamlit')
    streamlit_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
import streamlit as st |
|
|
from FineTune.model import ComputeStat |
|
|
import time |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Browser-tab title and icon, plus the wide page layout.
_PAGE_SETTINGS = {
    "page_title": "AdaDetectGPT",
    "page_icon": "🔍",
    "layout": "wide",
}
st.set_page_config(**_PAGE_SETTINGS)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@st.cache_resource
def load_model(from_pretrained, base_model, cache_dir, device):
    """
    Load the detection model once and let Streamlit cache it across reruns.

    Args:
        from_pretrained: Local checkpoint path, or a Hugging Face Hub repo id
            (``owner/name``). A repo id is first downloaded with
            ``snapshot_download`` and the local snapshot path is loaded.
        base_model: Base model identifier forwarded to
            ``ComputeStat.from_pretrained``.
        cache_dir: Directory used for Hub downloads and forwarded to
            ``ComputeStat.from_pretrained``.
        device: Device string forwarded to ``ComputeStat.from_pretrained``.

    Returns:
        The loaded ``ComputeStat`` model with its criterion function set to
        ``'mean'``.

    Raises:
        Exception: Anything raised by ``snapshot_download`` (logged, then
            re-raised) or by ``ComputeStat.from_pretrained``.
    """
    # NOTE(review): hard-coded off, so the override below never runs.
    # Confirm whether this should follow the SPACE_ID environment variable.
    is_hf_space = False
    if is_hf_space:
        cache_dir = '/tmp/huggingface_cache'
        os.makedirs(cache_dir, exist_ok=True)
        device = 'cpu'
        print("Using **CPU** now!")

    hf_token = os.environ.get('HF_TOKEN')
    if hf_token:
        # Best effort: a failed login is only reported, because the token is
        # also passed explicitly to the download below.
        try:
            from huggingface_hub import login
            login(token=hf_token)
            print("✅ Successfully authenticated with HF token")
        except Exception as e:
            print(f"⚠️ HF login warning: {e}")

    # Treat the value as a Hub repo id only when it cannot be a local path:
    # explicit relative ('./x', '../x') and absolute ('/x') paths, as well as
    # anything that already exists on disk, are loaded directly.
    looks_local = (
        from_pretrained.startswith(('.', '/'))
        or os.path.exists(from_pretrained)
    )
    is_hf_hub = '/' in from_pretrained and not looks_local
    if is_hf_hub:
        from huggingface_hub import snapshot_download
        print(f"📥 Downloading model from HuggingFace Hub: {from_pretrained}")
        try:
            local_model_path = snapshot_download(
                repo_id=from_pretrained,
                cache_dir=cache_dir,
                token=hf_token,
                repo_type="model",
            )
        except Exception as e:
            print(f"❌ Failed to download model: {e}")
            raise
        print(f"✅ Model downloaded to: {local_model_path}")
        # Load from the downloaded snapshot rather than the repo id.
        from_pretrained = local_model_path

    with st.spinner("🔄 Loading model... This may take a moment on first launch."):
        model = ComputeStat.from_pretrained(
            from_pretrained,
            base_model,
            device=device,
            cache_dir=cache_dir,
        )
    model.set_criterion_fn('mean')
    return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from feedback import FeedbackManager |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Dataset repository that receives feedback; the environment variable of the
# same name overrides the default.
FEEDBACK_DATASET_ID = os.environ.get('FEEDBACK_DATASET_ID', 'mamba413/user-feedback')

# local_backup is off when SPACE_ID is set and on otherwise.
feedback_manager = FeedbackManager(
    dataset_repo_id=FEEDBACK_DATASET_ID,
    hf_token=os.environ.get('HF_TOKEN'),
    local_backup=not os.environ.get('SPACE_ID'),
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Arguments for loading the model. The device is 'cpu' when SPACE_ID is set
# and 'mps' otherwise.
MODEL_CONFIG = dict(
    from_pretrained='./src/FineTune/ckpt/',
    base_model='gemma-1b',
    cache_dir='../cache',
    device='mps' if not os.environ.get('SPACE_ID') else 'cpu',
)

# Domain choices offered in the UI; "General" comes first.
DOMAINS = [
    "General",
    "Academia", "Finance", "Government", "Knowledge",
    "Legislation", "Medicine", "News", "UserReview",
]
|
|
|
|
|
|
|
|
# Load the model with the configured arguments, passed positionally in the
# order load_model expects. A failure is recorded rather than raised.
try:
    model = load_model(
        *(
            MODEL_CONFIG[name]
            for name in ('from_pretrained', 'base_model', 'cache_dir', 'device')
        )
    )
except Exception as e:
    model_loaded = False
    error_message = str(e)
else:
    model_loaded = True
|
|
|
|
|
|
|
|
# Seed session state with defaults; keys that already exist are left alone.
for _state_key, _state_default in (
    ('last_detection', None),
    ('feedback_given', False),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Title goes in the wide middle column of a 1:5:1 split.
col0 = st.columns((1, 5, 1))[1]
with col0:
    st.markdown(
        "<h1 style='text-align: center; color: #0072C3;'>AdaDetectGPT: Adaptive LLM's Texts Detection</h1>",
        unsafe_allow_html=True,
    )

# Short usage instructions.
st.markdown(
    """Pasted the text to be detected below and click the 'Detect' button to get the p-value. Use a better option may improve detection."""
)

# Without a model there is nothing to run: report the failure and halt here.
if not model_loaded:
    st.error(f"❌ Failed to load model: {error_message}")
    st.stop()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Two equal-width columns: inputs on the left, read-only results on the right.
col1, col2 = st.columns((1, 1))

with col1:
    # Streamlit discourages empty widget labels (it logs a warning and may
    # reject them in future), so the text area carries a real label.
    text_input = st.text_area(
        label="Text to detect",
        placeholder="Paste your text to be detected here",
        help="Typically, providing text with a longer content would get a more reliable result.",
        height=200,
    )

    detect_clicked = st.button("Detect", type="primary", use_container_width=True)

    selected_domain = st.selectbox(
        label="⚙️ Domain (Optional)",
        options=DOMAINS,
        index=0,
        help="💡 **Tip:** Select the domain that best matches your text for improving detection accuracy. Default is 'General' that means consider all domains.",
    )

with col2:
    # Each result field lives in an st.empty() placeholder so that it can be
    # overwritten in place; until then it shows an empty, disabled input.
    statistics_ph = st.empty()
    statistics_ph.text_input(
        label="Statistics",
        value="",
        disabled=True,
        help="Statistics will appear here after clicking the Detect button.",
    )

    pvalue_ph = st.empty()
    pvalue_ph.text_input(
        label="p-value",
        value="",
        disabled=True,
        help="p-value will appear here after clicking the Detect button.",
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _show_feedback_buttons(detection):
    """
    Render the two feedback buttons for *detection* and save a click.

    Args:
        detection: The dict stored in ``st.session_state.last_detection``;
            its 'text', 'domain', 'statistics' and 'p_value' entries are
            passed to ``feedback_manager.save_feedback``.

    A successful save sets ``st.session_state.feedback_given``, which
    disables both buttons on later reruns so the same result is not
    submitted twice.
    """
    already_given = st.session_state.feedback_given
    key_suffix = hash(detection['text'][:50])
    options = (
        # (button label, stored verdict, message widget, confirmation text)
        ("✅ Expected", 'expected', st.success,
         "✅ Thank you for your feedback!"),
        ("❌ Unexpected", 'unexpected', st.warning,
         "❌ Feedback recorded! This will help us improve."),
    )

    for column, (label, verdict, notify, confirmation) in zip(st.columns(2), options):
        with column:
            clicked = st.button(
                label,
                use_container_width=True,
                type="secondary",
                key=f"{verdict}_btn_{key_suffix}",
                disabled=already_given,
            )
            # Ignore a click that arrives after feedback was already saved.
            if not clicked or already_given:
                continue
            try:
                success, message = feedback_manager.save_feedback(
                    detection['text'],
                    detection['domain'],
                    detection['statistics'],
                    detection['p_value'],
                    verdict,
                )
            except Exception as e:
                st.error(f"Failed to save feedback: {str(e)}")
                import traceback
                st.code(traceback.format_exc())
                continue
            if success:
                st.session_state.feedback_given = True
                notify(confirmation)
                st.caption(f"💾 {message}")
            else:
                st.error(f"Failed to save feedback: {message}")

    if already_given:
        st.success("✅ Feedback submitted successfully!")


# Step 1: run a detection when the Detect button was clicked in this run and
# store the outcome in session state.
if detect_clicked:
    if not text_input.strip():
        st.warning("⚠️ Please enter some text before detecting.")
    else:
        # A new detection starts a new feedback cycle.
        st.session_state.feedback_given = False

        status_placeholder = st.empty()
        start_time = time.time()
        try:
            with status_placeholder, st.spinner(f"🔍 Analyzing text in {selected_domain} domain..."):
                crit, p_value = model.compute_p_value(text_input, selected_domain)
            elapsed_time = time.time() - start_time

            # Unwrap objects exposing .item() into plain Python numbers
            # (presumably tensor or NumPy scalars — confirm against the model).
            if hasattr(crit, 'item'):
                crit = crit.item()
            if hasattr(p_value, 'item'):
                p_value = p_value.item()
        except Exception as e:
            # Drop any earlier result so stale numbers are not shown next to
            # the error.
            st.session_state.last_detection = None
            st.error(f"❌ Error during detection: {str(e)}")
            st.exception(e)
        else:
            st.session_state.last_detection = {
                'text': text_input,
                'domain': selected_domain,
                'statistics': crit,
                'p_value': p_value,
                'elapsed_time': elapsed_time,
            }
        finally:
            status_placeholder.empty()

# Step 2: render the stored result on every run. This must not live under
# `if detect_clicked:` — clicking a feedback button reruns the script with
# detect_clicked False, so buttons created inside that branch would vanish
# before their click could be handled.
detection = st.session_state.last_detection
if detection is not None:
    # The placeholders already sit in the results column, so writing to them
    # replaces the empty fields in place.
    statistics_ph.text_input(
        label="Statistics",
        value=f"{detection['statistics']:.6f}",
        disabled=True,
        help="Detection statistics will appear here after clicking Detect.",
    )
    pvalue_ph.text_input(
        label="p-value",
        value=f"{detection['p_value']:.6f}",
        disabled=True,
        help="p-value will appear here after clicking Detect.",
    )

    st.caption(f"⏱️ Processing time: {detection['elapsed_time']:.2f} seconds")

    st.info(
        """
**📊 p-value:**
- **Lower p-value** (closer to 0) indicates text is **more likely AI-generated**
- **Higher p-value** (closer to 1) indicates text is **more likely human-written**
- Generally, p-value < 0.05 suggests the text may be LLM-generated
""",
        icon="💡"
    )

    st.markdown("**📝 Result Feedback**: Does this detection result meet your expectations?")
    _show_feedback_buttons(detection)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Footer pinned to the bottom of the viewport. The stylesheet also pads the
# main container so page content is not hidden behind the fixed footer.
_FOOTER_HTML = """
<style>
.footer {
position: fixed;
left: 0;
bottom: 0;
width: 100%;
background-color: white;
color: gray;
text-align: center;
padding: 10px;
border-top: 1px solid #e0e0e0;
z-index: 999;
}

/* Add padding to main content to prevent overlap with fixed footer */
.main .block-container {
padding-bottom: 60px;
}
</style>
<div class='footer'>
<small>Powered by Adaptive LLM Text Detection | For research purposes only</small>
</div>
"""

# Raw HTML/CSS is only rendered when unsafe_allow_html is enabled.
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)