Spaces:
Running
on
Zero
Running
on
Zero
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import os | |
| import json | |
| from core.model import DiscrepancyEstimator | |
| import re | |
| import docx | |
| import spaces | |
| from datasets import load_dataset | |
def read_file_content(file):
    """Return the text content of an uploaded ``.txt`` or ``.docx`` file.

    Args:
        file: ``None``, a filesystem path string, or an object whose ``name``
            attribute holds the path of the uploaded file.

    Returns:
        The decoded text. ``""`` is returned when ``file`` is ``None`` or the
        extension is neither ``.txt`` nor ``.docx``. For ``.docx`` files the
        paragraph texts are joined with newlines.
    """
    if file is None:
        return ""
    # Accept both upload objects exposing ``.name`` and plain path strings.
    path = getattr(file, 'name', file)
    # Compare the extension case-insensitively so ``NOTES.TXT`` is not ignored.
    lowered = str(path).lower()
    if lowered.endswith('.txt'):
        # ``utf-8-sig`` reads plain UTF-8 unchanged and drops a leading BOM,
        # which would otherwise end up as an invisible first character.
        with open(path, 'r', encoding='utf-8-sig') as f:
            return f.read()
    if lowered.endswith('.docx'):
        doc = docx.Document(path)
        return '\n'.join(para.text for para in doc.paragraphs)
    return ""
def split_sentences(text):
    """Split text into sentences at ``。`` and ``.``, keeping each delimiter.

    The delimiter stays attached to the end of the sentence it terminates.
    Whatever follows the last delimiter becomes a final sentence of its own.
    Every piece is whitespace-stripped and empty pieces are dropped.
    """
    pieces = re.split(r'([。.])', text)
    merged = []
    # Odd positions hold the captured delimiters; glue each one onto the text
    # that precedes it.
    for idx in range(1, len(pieces), 2):
        merged.append(pieces[idx - 1] + pieces[idx])
    if len(pieces) % 2 == 1:
        merged.append(pieces[-1])
    stripped = (piece.strip() for piece in merged)
    return [piece for piece in stripped if piece]
def count_words(sentence, language='Chinese'):
    """Return the length of ``sentence`` in the unit used for ``language``.

    Args:
        sentence: Text to measure.
        language: ``'Chinese'`` counts characters (line breaks excluded); any
            other value counts whitespace-separated tokens.

    Returns:
        The character count or the token count as an ``int``.
    """
    if language == 'Chinese':
        return len(sentence.replace('\n', '').replace('\r', ''))
    # ``str.split()`` already treats line breaks as separators. Deleting them
    # first would fuse the words on either side of a line break into one.
    return len(sentence.split())
def segment_text(sentences, language='Chinese'):
    """Group sentences into segments of roughly 100 words.

    Lengths are measured with ``count_words``. Sentences are joined with a
    space, or with nothing when ``language`` is ``'Chinese'``.

    Rules:
        * Sentences of at most 100 words are accumulated until adding the next
          one would push the running total above 100; the accumulated segment
          is then emitted and a new one is started.
        * A sentence longer than 100 words first flushes the accumulated
          segment. It is then emitted together with the following sentence
          when the pair has at most 200 words, otherwise on its own.
        * A trailing segment shorter than 100 words is merged into the
          previously emitted segment when the merged length is at most 200.

    Args:
        sentences: Sequence of sentence strings.
        language: ``'Chinese'`` or any other value (treated as
            space-separated text).

    Returns:
        A list of segment strings.
    """
    joiner = '' if language == 'Chinese' else ' '
    result = []
    current_segment = []
    current_length = 0

    # An explicit index is required here: rebinding the loop variable of a
    # ``for ... in enumerate(...)`` loop does not skip an item, which made the
    # sentence merged into a long predecessor appear a second time.
    index = 0
    total = len(sentences)
    while index < total:
        sentence = sentences[index]
        word_count = count_words(sentence, language)
        index += 1

        if word_count > 100:
            # Flush whatever was accumulated before the long sentence.
            if current_segment:
                result.append(joiner.join(current_segment))
                current_segment = []
                current_length = 0
            if index < total and word_count + count_words(sentences[index], language) <= 200:
                result.append(sentence + joiner + sentences[index])
                index += 1  # the follower has been consumed
            else:
                result.append(sentence)
            continue

        if current_length + word_count > 100:
            if current_segment:
                result.append(joiner.join(current_segment))
            current_segment = [sentence]
            current_length = word_count
        else:
            current_segment.append(sentence)
            current_length += word_count

    # Handle the trailing segment.
    if current_segment:
        if current_length < 100 and result and current_length + count_words(result[-1], language) <= 200:
            # Too short to stand alone: fold it into the previous segment.
            previous = result.pop()
            head = list(previous) if language == 'Chinese' else previous.split()
            current_segment = head + current_segment
        result.append(joiner.join(current_segment))
    return result
def extract_latex_text(latex_source):
    """Reduce LaTeX source to approximate plain text.

    Processing order:
        1. Keep only the body of the ``document`` environment when present.
        2. Drop ``%`` comments (an escaped ``\\%`` is kept).
        3. Drop figure/table/equation/align (plain or starred), verbatim and
           lstlisting environments.
        4. Drop citation, label and reference commands with their arguments,
           and inline ``\\table{...}`` / ``\\figure{...}`` commands.
        5. Turn ``\\\\`` line breaks into newlines.
        6. Replace commands that take a braced argument by that argument,
           repeating up to 10 times to peel nested commands.
        7. Drop the remaining argument-less commands.
        8. Unescape special characters, normalise quotes, flatten newlines to
           spaces and collapse runs of spaces/tabs.

    Args:
        latex_source: LaTeX source text.

    Returns:
        The extracted text on a single line, stripped of surrounding
        whitespace.
    """
    # Keep only the document body when the source has one.
    match = re.search(r'\\begin\{document\}(.*?)\\end\{document\}', latex_source, re.DOTALL)
    content = match.group(1) if match else latex_source

    # Remove comments, but not an escaped percent sign.
    content = re.sub(r'(?<!\\)%.*', '', content, flags=re.MULTILINE)

    # Remove common non-text environments. Raw strings avoid the invalid
    # ``\*`` escape; the starred variants are matched as well so that e.g.
    # ``figure*`` captions do not leak into the text.
    excluded_envs = [
        r'figure\*?', r'table\*?', r'equation\*?', r'align\*?',
        'verbatim', 'lstlisting',
    ]
    env_pattern = re.compile(
        r'\\begin\{(' + '|'.join(excluded_envs) + r')\}.*?\\end\{\1\}',
        re.DOTALL,
    )
    content = env_pattern.sub('', content)

    # Remove citation / label / reference commands together with their keys,
    # which are identifiers rather than prose.
    content = re.sub(
        r'\\(?:cite[a-zA-Z]*|label|ref|eqref|pageref|autoref|[cC]ref)\*?'
        r'(?:\[[^\]]*\])*\{[^}]*\}',
        '',
        content,
    )
    # Remove inline table/figure commands and their content.
    content = re.sub(r'\\(table|figure)\*?(\[[^\]]*\])?\{[^}]*\}', '', content)

    # Resolve ``\\`` line breaks before any command handling. Otherwise
    # ``\\next`` is read as the command ``\next`` and the word is lost.
    content = content.replace('\\\\', '\n')

    # Unwrap commands with a braced argument BEFORE stripping bare commands.
    # In the opposite order the command names are already gone when this pass
    # runs, so it never matches and the braces stay in the output. The content
    # pattern is written in unrolled form, so an unbalanced brace fails in
    # linear time instead of backtracking exponentially.
    arg_command = re.compile(
        r'\\[a-zA-Z]+\*?(?:\[[^\]]*\])*\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}'
    )
    for _ in range(10):  # bounded so pathological input cannot loop forever
        new_content = arg_command.sub(lambda m: m.group(1), content)
        if new_content == content:
            break
        content = new_content

    # Remove the remaining simple (argument-less) commands.
    content = re.sub(r'\\([a-zA-Z]+)\*?\b', '', content)

    # Unescape special characters and normalise typographic quotes. Order
    # matters: newlines (including those produced from ``\\``) are flattened
    # to spaces here.
    replacements = (
        ('~', ' '), ('\\&', '&'), ('\\$', '$'), ('\\%', '%'),
        ('\\_', '_'), ('\\#', '#'), ('\n', ' '),
        ('“', '"'), ('”', '"'), ('‘', "'"), ('’', "'"),
    )
    for old, new in replacements:
        content = content.replace(old, new)

    # Collapse whitespace. No newline survives the flattening above, so no
    # blank-line normalisation is needed.
    content = re.sub(r'[ \t]+', ' ', content)
    return content.strip()
class ProbEstimator:
    """Turn a discrepancy score into per-task probabilities using reference scores.

    For each task in ``tasks`` the constructor loads ``{task}.json`` from
    ``ref_file_dir`` through ``load_dataset`` and stores the
    ``original_discrepancy`` column of its ``train`` split in ``real_crits``
    and the ``rewritten_discrepancy`` column in ``fake_crits``.
    """

    def __init__(self, ref_file_dir):
        """Load the reference scores for every task.

        Args:
            ref_file_dir: Dataset location passed to ``load_dataset``.
        """
        self.tasks = ["polish", "generate", "rewrite"]
        self.real_crits = {task: [] for task in self.tasks}
        self.fake_crits = {task: [] for task in self.tasks}
        for task in self.tasks:
            task_ref_data = load_dataset(ref_file_dir, data_files=f'{task}.json')['train']
            self.real_crits[task].extend(task_ref_data['original_discrepancy'])
            self.fake_crits[task].extend(task_ref_data['rewritten_discrepancy'])
        print(f'ProbEstimator: total {sum([len(self.real_crits[task]) for task in self.tasks]) * 2} samples.')

    def crit_to_prob(self, crit):
        """Estimate, per task, the share of "fake" reference scores near ``crit``.

        The neighbourhood half-width is the ``int(0.1 * N)``-th smallest
        absolute distance between ``crit`` and the ``N`` reference scores of
        the task. Reference scores strictly inside that neighbourhood are
        counted and the fraction that are "fake" is returned.

        Args:
            crit: The discrepancy score to convert.

        Returns:
            A dict mapping each task name to a probability. ``0.5`` is used
            when no reference score falls inside the neighbourhood or when the
            task has no reference scores at all.
        """
        probs = {}
        for task in self.tasks:
            # Convert once per task instead of rebuilding the arrays for every
            # comparison below.
            real = np.asarray(self.real_crits[task], dtype=float)
            fake = np.asarray(self.fake_crits[task], dtype=float)
            total_len = real.size + fake.size
            if total_len == 0:
                # Nothing to compare against; indexing the empty distance
                # array would raise IndexError.
                probs[task] = 0.5
                continue
            distances = np.sort(np.abs(np.concatenate((real, fake)) - crit))
            offset = distances[int(0.1 * total_len)]
            cnt_real = np.sum((real > crit - offset) & (real < crit + offset))
            cnt_fake = np.sum((fake > crit - offset) & (fake < crit + offset))
            in_window = cnt_real + cnt_fake
            probs[task] = (cnt_fake / in_window) if in_window > 0 else 0.5
        return probs
# Device the detector model and its input tensors are moved to.
device = 'cuda'

# Reference-score estimators, built once at import time (one per language).
zh_prob_estimator = ProbEstimator(
    ref_file_dir="JiachenFu/Qwen2-0.5B-detectanyllm-detector-ref-zh",
)
en_prob_estimator = ProbEstimator(
    ref_file_dir="JiachenFu/Qwen2-0.5B-detectanyllm-detector-ref-en",
)
def _classify_segment(probs):
    """Map one segment's task probabilities to ``(color, generate, revise)`` flags."""
    ai_generate_prob = probs['generate']
    ai_revise_prob = max(probs['polish'], probs['rewrite'])
    if max(ai_generate_prob, ai_revise_prob) < 0.75:
        return "black", 0, 0
    if ai_generate_prob >= ai_revise_prob:
        return "red", 1, 0
    return "orange", 0, 1


def _render_detection_html(detected, separator=''):
    """Render scored segments as colour-coded HTML followed by a summary box."""
    from html import escape  # local import; the file's import block is untouched

    # Collect fragments in a list and join once: one <span> is emitted per
    # character, so repeated string concatenation would be quadratic.
    parts = ['''
<style>
@keyframes reveal {
from { opacity: 0; }
to { opacity: 1; }
}
.reveal-char {
opacity: 0;
animation: reveal 0.2s forwards;
white-space: pre-wrap;
}
</style>
<div style="position: relative; padding-bottom: 60px; min-height: 120px;">
''']
    current_delay = 0.0  # animation delay of the next character
    char_duration = 0.001  # delay increment per character

    for position, item in enumerate(detected):
        color, item["generate"], item["revise"] = _classify_segment(item['probs'])
        # Segments carry no surrounding whitespace, so space-separated text
        # needs an explicit separator between consecutive segments.
        text = item['text'] if position == 0 else separator + item['text']
        for char in text:
            # Escape the user's text: a raw ``<`` or ``&`` would otherwise be
            # interpreted as markup.
            parts.append(
                f'<span class="reveal-char" style="color: {color}; '
                f'animation-delay: {current_delay:.2f}s;">{escape(char)}</span>'
            )
            current_delay += char_duration

    # Word-weighted share of segments flagged as generated / revised.
    total_length = sum(item['words_count'] for item in detected)
    if total_length > 0:
        generate_prob = sum(item["generate"] * item["words_count"] for item in detected) / total_length
        revise_prob = sum(item["revise"] * item["words_count"] for item in detected) / total_length
    else:
        generate_prob = revise_prob = 0

    # Absolutely positioned summary of the overall rates.
    parts.append(f'''
<div style="
position: absolute;
bottom: 0;
right: 0;
background-color: rgba(255, 255, 255, 0.9);
padding: 8px 12px;
border-radius: 4px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
border: 1px solid #e0e0e0;
font-size: 14px;
">
🤖 AI Generated Rate: <strong>{generate_prob:.2%}</strong><br>
✍️ AI Revised Rate: <strong>{revise_prob:.2%}</strong>
</div>
''')
    parts.append('</div>')
    return ''.join(parts)


def _score_segments(sub_texts, language):
    """Score each segment with the detector model for ``language``."""
    if language == "Chinese":
        pretrained_ckpt = "JiachenFu/Qwen2-0.5B-detectanyllm-detector-zh"
        prob_estimator = zh_prob_estimator
    else:
        pretrained_ckpt = "JiachenFu/Qwen2-0.5B-detectanyllm-detector-en"
        prob_estimator = en_prob_estimator
    # NOTE(review): the model is rebuilt on every request; consider caching it
    # if the deployment allows.
    model = DiscrepancyEstimator(pretrained_ckpt=pretrained_ckpt).to(device)
    model.eval()

    detected = []
    for i, sub_text in enumerate(sub_texts):
        text_content = sub_text
        print(f'processing {sub_text}')
        tokens = model.scoring_tokenizer(
            text_content, return_tensors='pt', padding=True, truncation=True, return_token_type_ids=False
        )
        print('tokenized')
        input_ids = tokens['input_ids'].to(device)
        attention_mask = tokens['attention_mask'].to(device)
        with torch.no_grad():
            output = model.get_discrepancy_of_scoring_and_reference_models(
                input_ids_for_scoring_model=input_ids,
                attention_mask_for_scoring_model=attention_mask,
                input_ids_for_reference_model=None,
                attention_mask_for_reference_model=None,
            )
        discrepancy = output['scoring_discrepancy']
        discrepancy = discrepancy.cpu().numpy().item()
        print(f'discrepancy: {discrepancy}')
        probs = prob_estimator.crit_to_prob(discrepancy)
        # Scores below 15 are forced to probability 0 for every task.
        if discrepancy < 15:
            for task in probs:
                probs[task] = 0.0
        detected.append({
            'order': i,
            'text': text_content,
            'words_count': len(text_content) if language == "Chinese" else len(text_content.split()),
            'probs': probs
        })
    return detected


def greet(mode, language, input_text):
    """Run detection on ``input_text`` and return the result as an HTML string.

    NOTE(review): the file imports ``spaces`` but never uses it. If this app
    is deployed on ZeroGPU hardware the handler presumably needs the
    ``@spaces.GPU`` decorator — confirm against the deployment target.

    Args:
        mode: ``"LaTex"`` to strip LaTeX markup first; any other value uses
            the text as is.
        language: ``"Chinese"`` selects the Chinese detector and
            character-based lengths; any other value selects the English
            detector and word-based lengths.
        input_text: The text to analyse; ``None`` is treated as empty.

    Returns:
        HTML in which every character is coloured red (flagged as generated),
        orange (flagged as revised) or black, followed by a box with the
        word-weighted generated and revised rates.
    """
    input_text = input_text or ""
    if mode == "LaTex":
        input_text = extract_latex_text(input_text)
    sub_texts = segment_text(split_sentences(input_text), language=language)

    # Skip loading the model when there is nothing to score.
    detected = _score_segments(sub_texts, language) if sub_texts else []
    separator = '' if language == "Chinese" else ' '
    return _render_detection_html(detected, separator=separator)
# The UI is built with Blocks rather than Interface for finer layout control.

# SVG pattern shared by the header, the control groups and the cards. It is
# defined once and substituted into the stylesheet below.
_CARD_PATTERN = '''url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='40' height='40' viewBox='0 0 40 40'%3E%3Cg fill-rule='evenodd'%3E%3Cg fill='%23e5e7eb' fill-opacity='0.3'%3E%3Cpath d='M0 38.59l2.83-2.83 1.41 1.41L1.41 40H0v-1.41zM0 1.4l2.83 2.83 1.41-1.41L1.41 0H0v1.41zM38.59 40l-2.83-2.83 1.41-1.41L40 38.59V40h-1.41zM40 1.41l-2.83 2.83-1.41-1.41L38.59 0H40v1.41zM20 18.6l2.83-2.83 1.41 1.41L21.41 20l2.83 2.83-1.41 1.41L20 21.41l-2.83 2.83-1.41-1.41L18.59 20l-2.83-2.83 1.41-1.41L20 18.59z'/%3E%3C/g%3E%3C/g%3E%3C/svg%3E")'''

# Stylesheet for the whole app.
APP_CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700;800&display=swap');
:root {
--accent-color: #6366f1;
--text-color: #374151;
--border-color: #e5e7eb;
--background-light: #f9fafb;
--background-card: #ffffff;
}
body, .gradio-container {
background: var(--background-light);
font-family: 'Inter', sans-serif;
color: var(--text-color);
}
#header {
text-align: center;
padding: 2rem;
margin: 0 auto; /* Use gap for spacing, remove margin-bottom */
background-color: var(--background-card);
background-image: __CARD_PATTERN__;
border: 1px solid var(--border-color);
border-radius: 16px;
box-shadow: 0 4px 12px rgba(0,0,0,0.05);
}
#title {
font-weight: 800;
font-size: 2.5em;
letter-spacing: -0.02em;
color: var(--text-color);
margin-bottom: 0.25em;
}
.detect-grad {
background: -webkit-linear-gradient(left, #ff8c8c, #ffc89e);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
font-weight: 800;
}
.anyllm-grad {
background: -webkit-linear-gradient(left, #a0e6ff, #aaffd4);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
font-weight: 800;
}
#authors {
font-size: 1.1em;
color: #6b7280;
margin: 0;
}
#main-container {
max-width: 1200px;
margin: 0 auto;
padding: 0 1rem;
gap: 2rem; /* Add gap for consistent spacing */
}
#controls-row {
justify-content: center;
gap: 2rem;
}
/* Custom styles for Radio Button Groups */
#controls-row > div {
background-color: var(--background-card);
background-image: __CARD_PATTERN__;
border: 1px solid var(--border-color);
border-radius: 16px;
padding: 1rem;
box-shadow: 0 4px 12px rgba(0,0,0,0.05);
}
#controls-row .gradio-button {
border-radius: 10px !important;
transition: background-color 0.2s ease, color 0.2s ease;
}
#controls-row .gradio-button.selected {
background: var(--accent-color) !important;
color: white !important;
border-color: var(--accent-color) !important;
}
#content-row {
gap: 1.5rem;
}
.card {
background-color: var(--background-card);
background-image: __CARD_PATTERN__;
border: 1px solid var(--border-color);
border-radius: 16px;
padding: 1.5rem;
box-shadow: 0 4px 12px rgba(0,0,0,0.05);
height: 100%;
display: flex;
flex-direction: column;
gap: 1rem;
}
.card-title {
font-weight: 600;
font-size: 1.2rem;
color: var(--text-color);
padding-bottom: 0.75rem;
border-bottom: 1px solid var(--border-color);
}
#input-text textarea {
flex-grow: 1;
border: none !important;
box-shadow: none !important;
padding: 0 !important;
font-size: 1.1em;
line-height: 1.7;
}
#result-html {
flex-grow: 1;
font-size: 1.1em;
line-height: 1.7;
overflow-y: auto;
height: 520px;
}
#input-footer {
display: flex;
justify-content: space-between;
align-items: center;
margin-top: auto; /* Push to bottom */
}
#char-counter {
font-size: 0.9em;
color: #9ca3af;
}
#char-counter.error {
color: #ef4444;
}
#submit-btn {
flex-grow: 1;
max-width: 200px;
font-size: 1.05em;
font-weight: 600;
background: var(--accent-color);
color: white;
border-radius: 10px;
}
#submit-btn:hover {
background: #4f46e5;
}
.disclaimer {
text-align: center;
margin: 0 auto; /* Remove vertical margins */
color: #64748b;
font-size: 1.1em;
max-width: 800px;
}
/* Reveal 动画更丝滑 */
@keyframes reveal {
from { opacity: 0; }
to { opacity: 1; }
}
.reveal-char {
opacity: 0;
animation: reveal 0.2s forwards;
white-space: pre-wrap;
}
""".replace("__CARD_PATTERN__", _CARD_PATTERN)

# NOTE(review): the original indentation of this layout was not available;
# the nesting below is reconstructed from the element ids and the stylesheet.
# Verify it against the deployed layout.
with gr.Blocks(css=APP_CSS) as demo:
    with gr.Column(elem_id="main-container"):
        gr.Markdown("""
<div id="header">
<h1 id="title"><span class="detect-grad">Detect</span><span class="anyllm-grad">AnyLLM</span>: Towards Generalizable and Robust Detection of Machine-Generated Text Across Domains and Models</h1>
<p id="authors">Jiachen Fu, Chun-Le Guo, Chongyi Li</p>
</div>
""")
        with gr.Row(elem_id="controls-row"):
            language_radio = gr.Radio(
                choices=["English", "Chinese"],
                value="English",
                label="🌐 Language",
                interactive=True
            )
            mode_radio = gr.Radio(
                choices=["Text-Only", "LaTex"],
                value="Text-Only",
                label="✍️ Input Type",
                interactive=True
            )
        with gr.Row(equal_height=True, elem_id="content-row"):
            with gr.Column(scale=1, min_width=500):
                with gr.Column(elem_classes="card"):
                    gr.HTML('<div class="card-title">📝 Input</div>')
                    upload_btn = gr.File(
                        label="Upload File (txt, docx)",
                        file_types=['.txt', '.docx'],
                        elem_id="upload-btn"
                    )
                    input_text = gr.Textbox(
                        show_label=False,
                        placeholder="Enter text to detect or upload a file...",
                        lines=15,
                        elem_id="input-text",
                        max_length=100000,
                    )
                    with gr.Row(elem_id="input-footer"):
                        counter_html = gr.HTML("<div id='char-counter'>0/100000</div>")
                        submit_btn = gr.Button("✨ Detect", variant="primary", elem_id="submit-btn")
            with gr.Column(scale=1, min_width=500):
                with gr.Column(elem_classes="card"):
                    gr.HTML('<div class="card-title">🔍 Result</div>')
                    result = gr.HTML(elem_id="result-html")
        gr.HTML("""
<div class="disclaimer">
💡 <i><b style="color: red;">Red fonts</b> indicate a high probability of AI generation. <b style="color: orange;">Orange fonts</b> indicate a high probability of AI revision or polishing. The detection results are for reference only.</i>
</div>
""")

    # Fill the textbox with the content of an uploaded file.
    upload_btn.upload(
        read_file_content,
        inputs=upload_btn,
        outputs=input_text
    )
    # Keep the character counter in sync. ``change`` is used rather than
    # ``input`` because ``input`` only fires on typing, which left the counter
    # at 0 after a file upload filled the textbox.
    input_text.change(
        None,
        [input_text],
        None,
        js="""
(text) => {
setTimeout(() => {
const counter = document.getElementById("char-counter");
if (counter) {
const length = text.length;
counter.innerHTML = `${length}/100000`;
counter.classList.toggle("error", length > 100000);
}
}, 0);
return text;
}
"""
    )
    # Run detection when the button is clicked.
    submit_btn.click(
        greet,
        inputs=[mode_radio, language_radio, input_text],
        outputs=result
    )

demo.launch(share=True)