Spaces: Running on Zero
import os
import sys
# Silence/disable HF Hub symlink behavior before any huggingface-based
# library (docling, transformers, ...) is imported.
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
os.environ["HF_HUB_DISABLE_SYMLINKS"] = "1"
# Load environment variables from .env file
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    pass  # dotenv not installed, skip
# Optional CLI flag: `--device cpu|cuda`, exported for the PDF-conversion
# backend via the DOCLING_DEVICE environment variable.
if "--device" in sys.argv:
    try:
        device_idx = sys.argv.index("--device")
        device_value = sys.argv[device_idx + 1]
        os.environ["DOCLING_DEVICE"] = device_value
        print(f"[INFO] DOCLING_DEVICE set to: {device_value}")
    except (IndexError, ValueError):
        # IndexError: --device was the last argument (no value followed it).
        print("[WARNING] --device requires a value (cpu or cuda)")
os.environ["GRADIO_LANGUAGE"] = "en"
import uuid
import time
import shutil
import threading
from typing import Optional, Tuple, List, Dict, Any
import gradio as gr
from fastapi import FastAPI
def _noop(self, app: FastAPI):
    # Intentionally empty: replacement for Blocks._add_health_routes below.
    pass
# Monkeypatch: stop Gradio from registering its health-check routes
# (presumably to avoid a route conflict on the hosting platform — TODO confirm).
gr.blocks.Blocks._add_health_routes = _noop
| from rebuttal_service import ( | |
| rebuttal_service, | |
| ProcessStatus, | |
| SessionState, | |
| QuestionState, | |
| init_llm_client, | |
| LogCollector, | |
| ) | |
# Per-session upload storage lives next to this file under gradio_uploads/.
_CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
SAVE_DIR = os.path.join(_CURRENT_DIR, "gradio_uploads")
os.makedirs(SAVE_DIR, exist_ok=True)
def read_gradio_file(file_obj) -> Tuple[Optional[str], Optional[Any]]:
    """Normalize the value produced by a gr.File component.

    Returns a ``(kind, payload)`` pair:
      - ``(None, None)``        nothing was uploaded
      - ``("path", str)``       a filesystem path string
      - ``("bytes", data)``     raw data from a ``{"data": ...}`` dict
      - ``("fileobj", data)``   the result of calling ``.read()``

    Raises:
        ValueError: for any other object shape.
    """
    if file_obj is None:
        return None, None
    if isinstance(file_obj, str):
        return "path", file_obj
    if isinstance(file_obj, dict) and "data" in file_obj:
        return "bytes", file_obj["data"]
    if hasattr(file_obj, "read"):
        return "fileobj", file_obj.read()
    raise ValueError(f"Unknown gr.File object format: {type(file_obj)}")
def save_uploaded_files(pdf_file, review_file, session_id: str) -> Tuple[str, str, str]:
    """Persist the uploaded paper PDF and review file under a per-session dir.

    Args:
        pdf_file: gr.File value for the paper (path str, dict, or file-like).
        review_file: gr.File value for the review text (.txt / .md).
        session_id: short unique id used as the subdirectory name.

    Returns:
        ``(pdf_save_path, review_save_path, review_text)``

    Raises:
        ValueError: if either upload is missing or has an unknown format.
    """
    session_dir = os.path.join(SAVE_DIR, session_id)
    os.makedirs(session_dir, exist_ok=True)
    pdf_save_path = os.path.join(session_dir, "paper.pdf")
    review_save_path = os.path.join(session_dir, "review.txt")
    pdf_type, pdf_data = read_gradio_file(pdf_file)
    if pdf_type is None:
        raise ValueError("PDF file upload failed or incorrect format")
    if pdf_type == "path":
        shutil.copy(pdf_data, pdf_save_path)
    elif pdf_type in ("bytes", "fileobj"):
        with open(pdf_save_path, "wb") as f:
            # BUGFIX: the original ternary was a no-op (both branches wrote
            # pdf_data); a text-mode file object yields str, which cannot be
            # written to a binary file. Encode str payloads first.
            f.write(pdf_data if isinstance(pdf_data, bytes) else pdf_data.encode("utf-8"))
    rev_type, rev_data = read_gradio_file(review_file)
    if rev_type is None:
        raise ValueError("Review file upload failed or incorrect format")
    review_text = ""
    if rev_type == "path":
        with open(rev_data, "r", encoding="utf-8") as f:
            review_text = f.read()
    elif rev_type in ("bytes", "fileobj"):
        # BUGFIX: the original called .decode("utf-8") in BOTH ternary
        # branches, raising AttributeError whenever rev_data was already str.
        review_text = rev_data.decode("utf-8") if isinstance(rev_data, bytes) else rev_data
    with open(review_save_path, "w", encoding="utf-8") as f:
        f.write(review_text)
    return pdf_save_path, review_save_path, review_text
# Background analysis worker threads, keyed by session id.
# NOTE(review): not referenced anywhere in the visible part of this file —
# confirm it is still needed before removing.
processing_threads: Dict[str, threading.Thread] = {}
# Provider configurations.
# Maps each provider's UI display name to:
#   provider_key: internal key forwarded to init_llm_client
#   env_var:      environment variable consulted for a pre-set API key
#   label:        label shown on the API-key textbox
#   placeholder:  example key format shown in the textbox placeholder
PROVIDER_CONFIGS = {
    "OpenRouter": {
        "provider_key": "openrouter",
        "env_var": "OPENROUTER_API_KEY",
        "label": "OpenRouter API Key",
        "placeholder": "sk-or-v1-...",
    },
    "Qwen (DashScope)": {
        "provider_key": "qwen",
        "env_var": "QWEN_API_KEY",
        "label": "Qwen API Key",
        "placeholder": "sk-...",
    },
    "DeepSeek": {
        "provider_key": "deepseek",
        "env_var": "DEEPSEEK_API_KEY",
        "label": "DeepSeek API Key",
        "placeholder": "sk-...",
    },
    "OpenAI": {
        "provider_key": "openai",
        "env_var": "OPENAI_API_KEY",
        "label": "OpenAI API Key",
        "placeholder": "sk-...",
    },
    "Gemini": {
        "provider_key": "gemini",
        "env_var": "GEMINI_API_KEY",
        "label": "Gemini API Key",
        "placeholder": "AIza...",
    },
    "ZhiPu (GLM)": {
        "provider_key": "zhipu",
        "env_var": "ZHIPUAI_API_KEY",
        "label": "ZhiPu API Key",
        "placeholder": "...",
    },
}
# Model choices per provider.
# Display name -> model identifier sent to the provider's API. The special
# entry "Other models" maps to the sentinel "custom", which makes the UI show
# a free-text model-name input instead of using a preset id.
MODEL_CHOICES_BY_PROVIDER = {
    "OpenRouter": {
        "Gemini 3 Flash": "google/gemini-3-flash-preview",
        "Grok 4.1 Fast": "x-ai/grok-4.1-fast",
        "GPT-5 Mini": "openai/gpt-5-mini",
        "DeepSeek V3.2": "deepseek/deepseek-chat-v3.2",
        "Other models": "custom",
    },
    "Qwen (DashScope)": {
        "Qwen-Turbo": "qwen-turbo",
        "Qwen-Plus": "qwen-plus",
        "Qwen-Max": "qwen-max",
        "Other models": "custom",
    },
    "DeepSeek": {
        "DeepSeek Chat": "deepseek-chat",
        "DeepSeek Reasoner": "deepseek-reasoner",
        "Other models": "custom",
    },
    "OpenAI": {
        "GPT-5.2": "gpt-5.2",
        "GPT-5 Mini": "gpt-5-mini",
        "Other models": "custom",
    },
    "Gemini": {
        "Gemini-3-Pro": "gemini-3-pro-preview",
        "Gemini-3-Flash": "models/gemini-3-flash-preview",
        "Other models": "custom",
    },
    "ZhiPu (GLM)": {
        "GLM-4.7": "glm-4.7",
        "Other models": "custom",
    },
}
def get_api_key_for_provider(provider: str) -> str:
    """Return the API key set in the environment for *provider*.

    Unknown providers fall back to the OpenRouter configuration; an unset
    environment variable yields "".
    """
    cfg = PROVIDER_CONFIGS.get(provider, PROVIDER_CONFIGS["OpenRouter"])
    return os.environ.get(cfg["env_var"], "")
def get_default_model_for_provider(provider: str) -> str:
    """Return the display name of the default model for *provider*.

    Picks the first entry that is not the "Other models" sentinel; if every
    entry is the sentinel, falls back to the first key. Unknown providers use
    the OpenRouter model list.
    """
    models = MODEL_CHOICES_BY_PROVIDER.get(provider, MODEL_CHOICES_BY_PROVIDER["OpenRouter"])
    try:
        return next(label for label in models if label != "Other models")
    except StopIteration:
        return list(models)[0]
def start_analysis(pdf_file, review_file, provider_choice, api_key, model_choice, custom_model):
    """Validate inputs, save the uploads, and create a new analysis session.

    Returns a 7-tuple for the Gradio outputs — the first four are the page
    columns upload_col/loading_col/interact_col/result_col (per the
    start_btn.click wiring), then session state, a status message, and the
    log-poll timer.
    """
    # Both files are required before anything else happens.
    if not pdf_file or not review_file:
        return (
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            None,
            "⚠️ Please upload paper PDF and review file!",
            gr.Timer(active=False),
        )
    if not api_key or not api_key.strip():
        return (
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            None,
            "⚠️ Please enter API Key!",
            gr.Timer(active=False),
        )
    # Get provider key from config (unknown providers fall back to OpenRouter).
    provider_config = PROVIDER_CONFIGS.get(provider_choice, PROVIDER_CONFIGS["OpenRouter"])
    provider_key = provider_config["provider_key"]
    # Get model choices for this provider
    model_choices = MODEL_CHOICES_BY_PROVIDER.get(provider_choice, MODEL_CHOICES_BY_PROVIDER["OpenRouter"])
    if model_choice == "Other models":
        # "Other models" requires a non-empty free-text model name.
        if not custom_model or not custom_model.strip():
            return (
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
                None,
                "⚠️ Please enter custom model name!",
                gr.Timer(active=False),
            )
        selected_model = custom_model.strip()
    else:
        selected_model = model_choices.get(model_choice, list(model_choices.values())[0])
    # Short random id doubles as the upload subdirectory name.
    session_id = str(uuid.uuid4())[:8]
    try:
        init_llm_client(api_key=api_key.strip(), provider=provider_key, model=selected_model)
        pdf_path, review_path, _ = save_uploaded_files(pdf_file, review_file, session_id)
        session = rebuttal_service.create_session(session_id, pdf_path, review_path)
        # Switch to the loading page and activate the log-poll timer.
        return (
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
            {"session_id": session_id, "current_idx": 0},
            "📤 Files uploaded successfully, initializing analysis...",
            gr.Timer(active=True),
        )
    except Exception as e:
        return (
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            None,
            f"❌ Processing failed: {str(e)}",
            gr.Timer(active=False),
        )
def run_initial_analysis(session_state):
    """Run the heavy first-pass analysis while the loading page is shown.

    Runs the initial analysis, processes all extracted questions in parallel,
    then switches the UI to the interactive page with question 1 loaded.

    Returns a 16-tuple for the Gradio outputs (the four page columns,
    session state, status, then the question-page fields, the log timer and
    the confirm-button label — presumed order; confirm against the wiring).
    """
    if not session_state:
        # BUGFIX: this early return originally had only 14 items while the
        # success and error paths return 16; Gradio raises a "returned too
        # few output values" error on such a mismatch. Padded to 16.
        return (
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            session_state,
            "❌ Session state lost",
            gr.update(), gr.update(), gr.update(), gr.update(), gr.update(),
            gr.update(), gr.update(), gr.update(),
            gr.Timer(active=False),
            gr.update(),
        )
    session_id = session_state.get("session_id")
    try:
        session = rebuttal_service.run_initial_analysis(session_id)
        # Process every question concurrently so review can start sooner.
        rebuttal_service.process_all_questions_parallel(session_id, max_workers=3)
        session = rebuttal_service.get_session(session_id)
        session_state["current_idx"] = 0
        q_state = session.questions[0]
        history_text = format_feedback_history(q_state.feedback_history)
        strategy_content = q_state.agent7_output or ""
        # Determine button text based on whether this is the last question
        is_last_question = len(session.questions) == 1
        btn_text = "📝 Generate Final Rebuttal" if is_last_question else "✅ Satisfied, Next Question"
        return (
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=False),
            session_state,
            "",
            f"### Question 1 / {len(session.questions)}",
            q_state.question_text,
            strategy_content,
            strategy_content,
            "",
            f"📝 Revisions have been revised {q_state.revision_count} times",
            gr.update(interactive=True),
            history_text,
            gr.Timer(active=False),
            gr.update(value=btn_text),
        )
    except Exception as e:
        import traceback
        traceback.print_exc()
        return (
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            session_state,
            f"❌ Analysis failed : {str(e)}",
            gr.update(), gr.update(), gr.update(), gr.update(), gr.update(),
            gr.update(), gr.update(), gr.update(),
            gr.Timer(active=False),
            gr.update(),
        )
def regenerate_strategy(feedback_text, session_state):
    """Revise the current question's strategy using the user's feedback.

    Returns a 6-tuple for the Gradio outputs (presumed: strategy preview,
    strategy editor, feedback box, revision info, feedback history, session
    state — confirm against the regenerate_btn wiring).
    """
    if not session_state:
        # BUGFIX: this early return originally had only 5 items while the
        # success/error paths return 6; Gradio errors on the mismatch.
        return (
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            session_state,
        )
    if not feedback_text or not feedback_text.strip():
        # BUGFIX: padded from 5 to 6 items (see above).
        return (
            gr.update(),
            gr.update(),
            "⚠️ Please enter feedback ",
            gr.update(),
            gr.update(),
            session_state,
        )
    session_id = session_state.get("session_id")
    current_idx = session_state.get("current_idx", 0)
    try:
        q_state = rebuttal_service.revise_with_feedback(
            session_id,
            current_idx,
            feedback_text.strip()
        )
        history_text = format_feedback_history(q_state.feedback_history)
        strategy_content = q_state.agent7_output or ""
        return (
            strategy_content,
            strategy_content,
            "",  # clear the feedback box after a successful revision
            f"📝 Revisions have been revised {q_state.revision_count} times ✓ Latest revision applied",
            history_text,
            session_state,
        )
    except Exception as e:
        return (
            gr.update(),
            gr.update(),
            gr.update(),
            f"❌ Revision failed: {str(e)}",
            gr.update(),
            session_state,
        )
def format_feedback_history(history: list) -> str:
    """Render the feedback history as Markdown.

    Each record's "feedback" text is numbered and truncated to 100 characters
    (with an ellipsis); entries are separated by blank lines. An empty or
    missing history yields a placeholder string.
    """
    if not history:
        return "*No revisions yet*"

    def _clip(text: str) -> str:
        # Keep the displayed entry short; full text lives in the record.
        return text if len(text) <= 100 else text[:100] + "..."

    entries = [
        f"**#{idx}** {_clip(record.get('feedback', ''))}"
        for idx, record in enumerate(history, start=1)
    ]
    return "\n\n".join(entries)
def generate_strategy_summary(session) -> str:
    """Build a Markdown digest of every question's strategy and To-Do list.

    Question titles are truncated to 100 characters; questions with no
    generated strategy are marked "**Not generated**", and a revision note is
    appended for questions revised at least once.
    """
    parts = [
        " This document contains all questions' rebuttal strategies and To-Do List\n",
        "=" * 60 + "\n",
    ]
    for q in session.questions:
        title = q.question_text[:100]
        if len(q.question_text) > 100:
            title += "..."
        parts.extend([
            f"## Q{q.question_id}: {title}",
            "",
            "### Rebuttal strategy & To-Do List",
            "",
            q.agent7_output if q.agent7_output else "**Not generated**",
            "",
        ])
        if q.revision_count > 0:
            parts.append(f"> 📝 Revisions have been revised {q.revision_count} times")
            parts.append("")
        parts.append("-" * 40)
        parts.append("")
    return "\n".join(parts)
def skip_question(session_state):
    """Mark the current question satisfied without edits and advance.

    Moves to the next question when one remains; otherwise generates the
    strategy summary plus final rebuttal and switches to the results page.
    Returns a 16-tuple for the Gradio outputs (presumed order — confirm
    against the skip_btn wiring).
    """
    if not session_state:
        return (
            gr.update(),
            gr.update(),
            session_state,
            gr.update(), gr.update(), gr.update(), gr.update(), gr.update(),
            gr.update(), gr.update(), gr.update(),
            gr.update(), gr.update(), gr.update(), gr.update(),
            gr.update(),
        )
    session_id = session_state.get("session_id")
    current_idx = session_state.get("current_idx", 0)
    try:
        session = rebuttal_service.get_session(session_id)
        if not session:
            raise ValueError(f"Session {session_id} not found")
        # Skipping still counts the question as handled.
        rebuttal_service.mark_question_satisfied(session_id, current_idx)
        next_idx = current_idx + 1
        if next_idx < len(session.questions):
            q_state = session.questions[next_idx]
            session_state["current_idx"] = next_idx
            history_text = format_feedback_history(q_state.feedback_history)
            strategy_content = q_state.agent7_output or ""
            # Determine button text based on whether this is the last question
            is_last_question = (next_idx + 1) == len(session.questions)
            btn_text = "📝 Generate Final Rebuttal" if is_last_question else "✅ Satisfied, Next Question"
            return (
                gr.update(visible=True),
                gr.update(visible=False),
                session_state,
                f"### Question {next_idx + 1} / {len(session.questions)}",
                q_state.question_text,
                strategy_content,
                strategy_content,
                "",
                f"📝 Revisions have been revised {q_state.revision_count} times",
                gr.update(interactive=True),
                history_text,
                gr.update(), gr.update(), gr.update(), gr.update(),
                gr.update(value=btn_text),
            )
        else:
            # Last question handled: build the final outputs and show results.
            strategy_summary = generate_strategy_summary(session)
            final_text = rebuttal_service.generate_final_rebuttal(session_id)
            return (
                gr.update(visible=False),
                gr.update(visible=True),
                session_state,
                gr.update(), gr.update(), gr.update(), gr.update(), gr.update(),
                gr.update(), gr.update(), gr.update(),
                strategy_summary, strategy_summary, final_text, final_text,
                gr.update(),
            )
    except Exception as e:
        import traceback
        traceback.print_exc()
        return (
            gr.update(),
            gr.update(),
            session_state,
            gr.update(), gr.update(), gr.update(), gr.update(), gr.update(),
            gr.update(), gr.update(), gr.update(),
            gr.update(), gr.update(), gr.update(), f"❌ Processing failed: {str(e)}",
            gr.update(),
        )
def confirm_and_next(strategy_text, session_state):
    """Save the (possibly user-edited) strategy text, mark the question
    satisfied, and advance — mirrors skip_question but persists edits first.

    Returns a 16-tuple for the Gradio outputs (presumed order — confirm
    against the confirm_btn wiring).
    """
    if not session_state:
        return (
            gr.update(),
            gr.update(),
            session_state,
            gr.update(), gr.update(), gr.update(), gr.update(), gr.update(),
            gr.update(), gr.update(), gr.update(),
            gr.update(), gr.update(), gr.update(), gr.update(),
            gr.update(),
        )
    session_id = session_state.get("session_id")
    current_idx = session_state.get("current_idx", 0)
    try:
        session = rebuttal_service.get_session(session_id)
        if not session:
            raise ValueError(f"Session {session_id} not found")
        # Persist whatever is currently in the strategy editor before advancing.
        session.questions[current_idx].agent7_output = strategy_text
        rebuttal_service.mark_question_satisfied(session_id, current_idx)
        next_idx = current_idx + 1
        if next_idx < len(session.questions):
            q_state = session.questions[next_idx]
            session_state["current_idx"] = next_idx
            history_text = format_feedback_history(q_state.feedback_history)
            strategy_content = q_state.agent7_output or ""
            # Determine button text based on whether this is the last question
            is_last_question = (next_idx + 1) == len(session.questions)
            btn_text = "📝 Generate Final Rebuttal" if is_last_question else "✅ Satisfied, Next Question"
            return (
                gr.update(visible=True),
                gr.update(visible=False),
                session_state,
                f"### Question {next_idx + 1} / {len(session.questions)}",
                q_state.question_text,
                strategy_content,
                strategy_content,
                "",
                f"📝 Revisions have been revised {q_state.revision_count} times",
                gr.update(interactive=True),
                history_text,
                gr.update(), gr.update(), gr.update(), gr.update(),
                gr.update(value=btn_text),
            )
        else:
            # Last question handled: build the final outputs and show results.
            strategy_summary = generate_strategy_summary(session)
            final_text = rebuttal_service.generate_final_rebuttal(session_id)
            return (
                gr.update(visible=False),
                gr.update(visible=True),
                session_state,
                gr.update(), gr.update(), gr.update(), gr.update(), gr.update(),
                gr.update(), gr.update(), gr.update(),
                strategy_summary, strategy_summary, final_text, final_text,
                gr.update(),
            )
    except Exception as e:
        import traceback
        traceback.print_exc()
        return (
            gr.update(),
            gr.update(),
            session_state,
            gr.update(), gr.update(), gr.update(), gr.update(), gr.update(),
            gr.update(), gr.update(), gr.update(),
            gr.update(), gr.update(), gr.update(), f"❌ Processing failed: {str(e)}",
            gr.update(),
        )
def restart_session():
    """Reset the UI back to the upload page for a fresh run.

    Returns an 8-tuple: the four page columns' visibility (upload visible,
    the rest hidden), a cleared session state, an empty status message, and
    None for the two download-file components — presumed output order;
    confirm against the restart_btn wiring.
    """
    return (
        gr.update(visible=True),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
        None,
        "",
        None, None,
    )
def poll_logs(session_state):
    """Poll logs for real-time updates on the loading page.

    Returns ``(log_text, session_state)``. Emits ``gr.update()`` (no change)
    unless there is a live session whose collected logs differ from the last
    poll; the last emitted text is cached in ``session_state["_prev_logs"]``
    to avoid redundant UI refreshes.
    """
    if session_state:
        session_id = session_state.get("session_id")
        session = rebuttal_service.get_session(session_id) if session_id else None
        if session and session.log_collector:
            logs = session.log_collector.get_recent(30)
            if logs and logs != session_state.get("_prev_logs", ""):
                session_state["_prev_logs"] = logs
                return logs, session_state
    return gr.update(), session_state
# CSS for the application: global fonts, the strategy preview/edit panels,
# the loading-page log box, download-tip animation, the warning banner, and
# the bright download buttons. Passed to gr.Blocks(css=APP_CSS).
APP_CSS = """
@import url('https://fonts.googleapis.com/css2?family=Noto+Sans+SC:wght@400;500;700&display=swap');
/* Global fonts: Georgia for English, Noto Sans SC for Chinese */
* {
font-family: Georgia, 'Noto Sans SC', 'PingFang SC', 'Hiragino Sans GB', sans-serif !important;
}
.prose, .prose * {
font-family: Georgia, 'Noto Sans SC', 'PingFang SC', 'Hiragino Sans GB', sans-serif !important;
}
/* Code blocks keep monospace font */
code, pre, .code, pre *, code * {
font-family: 'Consolas', 'Monaco', 'Courier New', monospace !important;
}
.strategy-preview {
background: #f8fafc;
border: 1px solid #e2e8f0;
border-radius: 12px;
padding: 24px;
line-height: 1.8;
max-height: 600px;
overflow-y: auto;
}
.strategy-preview h3 {
color: #1e40af;
border-bottom: 2px solid #3b82f6;
padding-bottom: 8px;
margin-top: 20px;
}
.strategy-preview h4 {
color: #7c3aed;
margin-top: 16px;
}
.strategy-preview strong {
color: #1e293b;
border-radius: 4px; /* Optional: adds subtle highlight connection */
}
.strategy-preview table {
width: 100%;
border-collapse: collapse;
margin: 12px 0;
}
.strategy-preview th, .strategy-preview td {
border: 1px solid #e2e8f0;
padding: 8px 12px;
text-align: left;
}
.strategy-preview th {
background: #f1f5f9;
}
.strategy-edit textarea {
font-family: 'Consolas', 'Monaco', monospace !important;
font-size: 13px !important;
line-height: 1.5 !important;
background: #1e293b !important;
color: #e2e8f0 !important;
border-radius: 8px !important;
}
.question-box {
background: linear-gradient(135deg, #fef3c7, #fef9c3);
border-left: 4px solid #f59e0b;
border-radius: 8px;
padding: 16px;
}
.feedback-box textarea {
border: 2px solid #4CAF50;
}
#log-display {
background: #f8fafc;
color: #334155;
border: 1px solid #e2e8f0;
border-radius: 8px;
padding: 16px;
font-family: 'Consolas', 'Monaco', 'Courier New', monospace;
font-size: 12px;
line-height: 1.6;
max-height: 300px;
overflow-y: auto;
}
/* Download tip animation */
@keyframes pulse-glow {
0%, 100% {
opacity: 1;
transform: scale(1);
}
50% {
opacity: 0.7;
transform: scale(1.02);
}
}
@keyframes arrow-bounce {
0%, 100% { transform: translateX(0); }
50% { transform: translateX(5px); }
}
.download-tip {
background: linear-gradient(135deg, #fef3c7, #fde68a);
border: 2px solid #f59e0b;
border-radius: 12px;
padding: 12px 16px;
margin-top: 12px;
animation: pulse-glow 2s ease-in-out infinite;
}
.download-tip em {
font-style: normal;
display: flex;
align-items: center;
gap: 8px;
}
/* Important warning notice - single layer only */
.important-warning {
background: linear-gradient(135deg, #fef2cd, #fff3cd) !important;
border: 2px solid #ff9800 !important;
border-left: 6px solid #ff5722 !important;
border-radius: 8px !important;
padding: 16px 20px !important;
margin: 16px 0 !important;
box-shadow: 0 4px 12px rgba(255, 152, 0, 0.25) !important;
}
.important-warning * {
background: transparent !important;
border: none !important;
box-shadow: none !important;
padding: 0 !important;
margin: 0 !important;
}
.important-warning p {
margin: 4px 0 !important;
color: #5d4037 !important;
font-weight: 500 !important;
}
/* Bright download buttons */
#download-strategy-btn, #download-rebuttal-btn {
background: linear-gradient(135deg, #22c55e, #16a34a) !important;
border: none !important;
color: white !important;
font-weight: bold !important;
font-size: 16px !important;
padding: 16px 24px !important;
box-shadow: 0 4px 15px rgba(34, 197, 94, 0.4) !important;
transition: all 0.3s ease !important;
}
#download-strategy-btn:hover, #download-rebuttal-btn:hover {
background: linear-gradient(135deg, #16a34a, #15803d) !important;
box-shadow: 0 6px 20px rgba(34, 197, 94, 0.6) !important;
transform: translateY(-2px) !important;
}
"""
| with gr.Blocks(title="AI Rebuttal Assistant", theme=gr.themes.Soft(), css=APP_CSS) as demo: | |
| session_state = gr.State(None) | |
| gr.Markdown( | |
| """ | |
| # AI Rebuttal Assistant | |
| **We recommend to deploy our work locally, following the instructions in [GitHub](https://github.com/AutoLab-SAI-SJTU/Paper2Rebuttal) , running on huggingface may be very slow, taking up to an hour to finish** | |
| **Workflow usage process:** | |
| - **Upload** - Upload your paper PDF and the review file (Supports .txt or .md format, please input the original document containing all reviewers' comments, including reviewers' IDs ; there is no need to split them manually) | |
| - **Analysis** - The system will automatically analyze your paper and extract questions from the review | |
| - **Review Strategy** - For each question, view the AI-generated rebuttal strategy and to-do list and referenced response snippets | |
| - **Refinement** - Enter your feedback and click "Regenerate" to refine the strategy | |
| - **Generate Rebuttal** - After all questions are processed, generate the final rebuttal document | |
| """ | |
| ) | |
| with gr.Column(visible=True) as upload_col: | |
| gr.Markdown("## 📤 Configure & Upload Files") | |
| with gr.Group(): | |
| gr.Markdown("### 🔑 API Configuration") | |
| # Provider selection | |
| provider_choice = gr.Dropdown( | |
| label="LLM Provider", | |
| choices=list(PROVIDER_CONFIGS.keys()), | |
| value="OpenRouter", | |
| info="Select your LLM provider", | |
| ) | |
| # Pre-fill API key from environment variable based on provider | |
| default_provider = "OpenRouter" | |
| env_api_key = get_api_key_for_provider(default_provider) | |
| api_key_input = gr.Textbox( | |
| label=PROVIDER_CONFIGS[default_provider]["label"], | |
| placeholder=f"Please enter your API Key ({PROVIDER_CONFIGS[default_provider]['placeholder']})", | |
| value=env_api_key, | |
| type="password", | |
| info="Your API key will not be stored, only used for this session." + (" (Loaded from .env)" if env_api_key else "") | |
| ) | |
| def on_provider_change(provider): | |
| """Update API key field and model choices when provider changes""" | |
| config = PROVIDER_CONFIGS.get(provider, PROVIDER_CONFIGS["OpenRouter"]) | |
| env_key = get_api_key_for_provider(provider) | |
| model_choices = MODEL_CHOICES_BY_PROVIDER.get(provider, MODEL_CHOICES_BY_PROVIDER["OpenRouter"]) | |
| default_model = get_default_model_for_provider(provider) | |
| return ( | |
| gr.update( | |
| label=config["label"], | |
| placeholder=f"Please enter your API Key ({config['placeholder']})", | |
| value=env_key, | |
| info="Your API key will not be stored, only used for this session." + (" (Loaded from .env)" if env_key else "") | |
| ), | |
| gr.update( | |
| choices=list(model_choices.keys()), | |
| value=default_model, | |
| ), | |
| ) | |
| gr.Markdown("---") | |
| with gr.Group(): | |
| gr.Markdown("### 🤖 Model Selection") | |
| with gr.Row(): | |
| model_choice = gr.Dropdown( | |
| label="Select Model", | |
| choices=list(MODEL_CHOICES_BY_PROVIDER["OpenRouter"].keys()), | |
| value="Gemini 3 Flash", | |
| info="Choose the LLM model to use", | |
| scale=2, | |
| ) | |
| custom_model_input = gr.Textbox( | |
| label="Custom Model Name", | |
| placeholder="Enter model name", | |
| visible=False, | |
| scale=3, | |
| ) | |
| def toggle_custom_model(choice): | |
| return gr.update(visible=(choice == "Other models")) | |
| model_choice.change( | |
| fn=toggle_custom_model, | |
| inputs=[model_choice], | |
| outputs=[custom_model_input], | |
| ) | |
| # Connect provider change to update API key and model choices | |
| provider_choice.change( | |
| fn=on_provider_change, | |
| inputs=[provider_choice], | |
| outputs=[api_key_input, model_choice], | |
| ) | |
| gr.Markdown("---") | |
| gr.Markdown("### 📄 Upload Files") | |
| with gr.Row(): | |
| pdf_input = gr.File( | |
| label="📄 Paper PDF", | |
| file_types=[".pdf"], | |
| file_count="single", | |
| ) | |
| review_input = gr.File( | |
| label="📝 Review File(.md / .txt)", | |
| file_types=[".md", ".txt"], | |
| file_count="single", | |
| ) | |
| upload_status = gr.Markdown("") | |
| start_btn = gr.Button( | |
| "🚀 Submit & Start Analysis", | |
| variant="primary", | |
| size="lg", | |
| ) | |
| with gr.Column(visible=False) as loading_col: | |
| gr.Markdown("## ⏳ Analyzing...") | |
| loading_status = gr.Markdown("Initializing...") | |
| gr.Markdown( | |
| """ | |
| > 📊 **Analysis Process:** | |
| > 1. Convert PDF to Markdown | |
| > 2. AI reads and summarizes the paper | |
| > 3. AI extracts questions from the review | |
| > 4. Process all questions in parallel | |
| > 5. Present results for your review | |
| 🚀 **All questions will be processed in parallel**, so you can quickly review and refine each one after completion! | |
| After all questions are processed, you can generate the referenced final rebuttal document. | |
| This may take about 15 minutes (use gpu) or up to 1 hour (cpu), please be patient... | |
| """ | |
| ) | |
| gr.Markdown("### 📋 Live Logs") | |
| log_display = gr.Textbox( | |
| value="Waiting to start...", | |
| label="", | |
| lines=10, | |
| max_lines=15, | |
| interactive=False, | |
| elem_id="log-display", | |
| ) | |
| log_timer = gr.Timer(value=1.5, active=False) | |
| with gr.Column(visible=False) as interact_col: | |
| with gr.Row(): | |
| progress_info = gr.Markdown("### Question 1 / N") | |
| processing_status = gr.Markdown("", elem_id="processing-status") | |
| gr.Markdown( | |
| """ | |
| > 📘 **Quick Reference:** | |
| > - **Strategy** — High-level approach and key arguments to address this reviewer question | |
| > - **To-Do List** — Concrete action items (experiments, analysis, writing) to implement the strategy | |
| > - **Response Draft** — Snippets you can refer to when writing your rebuttal | |
| """ | |
| ) | |
| with gr.Group(): | |
| gr.Markdown("#### 🔍 Reviewer's Question") | |
| question_display = gr.Markdown(elem_classes=["question-box"]) | |
| with gr.Row(): | |
| with gr.Column(scale=3): | |
| gr.Markdown("#### 💡 Rebuttal Strategy & To-Do List Rebuttal") | |
| with gr.Tabs(): | |
| with gr.TabItem("📖 Preview"): | |
| gr.Markdown("*Rendered strategy content below:*") | |
| strategy_preview = gr.Markdown(elem_classes=["strategy-preview"]) | |
| with gr.TabItem("✏️ Edit"): | |
| gr.Markdown("*Edit raw Markdown, switch back to Preview to see results:*") | |
| strategy_edit = gr.Textbox( | |
| label="", | |
| lines=20, | |
| max_lines=40, | |
| elem_classes=["strategy-edit"], | |
| ) | |
| revision_info = gr.Markdown("📝 Revisions have been modified 0 times") | |
| with gr.Column(scale=1): | |
| gr.Markdown("#### 📜 Revision History") | |
| feedback_history_display = gr.Markdown( | |
| "*No revisions yet*", | |
| elem_id="feedback-history", | |
| ) | |
| gr.Markdown("---") | |
| gr.Markdown("#### 💬 Human Feedback") | |
| gr.Markdown("*Enter your feedback, and the AI will adjust the strategy accordingly. You can revise multiple times until satisfied.*") | |
| with gr.Row(): | |
| feedback_input = gr.Textbox( | |
| label="Feedback", | |
| placeholder="For example:\n• Please add more discussion on experimental data\n• This strategy is too general, needs more specific suggestions\n• Please include a comparison with baseline methods", | |
| lines=4, | |
| scale=4, | |
| elem_classes=["feedback-box"], | |
| ) | |
| with gr.Column(scale=1): | |
| regenerate_btn = gr.Button( | |
| "🔄 Regenerate", | |
| variant="secondary", | |
| size="lg", | |
| ) | |
| gr.Markdown("*AI will revise strategy\nbased on your feedback*", elem_id="regen-hint") | |
| gr.Markdown("---") | |
| with gr.Row(): | |
| skip_btn = gr.Button( | |
| "⏭️ Skip This Question", | |
| variant="secondary", | |
| size="lg", | |
| ) | |
| confirm_btn = gr.Button( | |
| "✅ Satisfied, Next Question", | |
| variant="primary", | |
| size="lg", | |
| ) | |
# --- Final results screen --------------------------------------------------
# Hidden until all questions are processed; shows the aggregated strategy
# summary and the generated reference rebuttal, plus download/restart actions.
with gr.Column(visible=False) as result_col:
    gr.Markdown("## 🎉 Complete!")
    gr.Markdown(
        """
        All questions have been processed. Here are the generated results:
        This page contains **two modules**:
        1. **Strategy Summary** - Contains rebuttal strategies, To-Do Lists, and draft response snippets for all questions.
        2. **Final Reference Rebuttal** - A complete rebuttal document for your reference.
        """
    )
    # Prominent warning: the LLM inserts placeholder numbers marked with "*".
    gr.Markdown(
        """
        ⚠️ **IMPORTANT NOTICE:** The Final Reference Rebuttal contains **LLM-estimated numerical values marked with asterisks (*)**.
        These estimated values are placeholders and **MUST be replaced with actual experimental results**.
        Please carefully review and verify all data.
        """,
        elem_classes=["important-warning"]
    )
    with gr.Tabs():
        with gr.TabItem("📝 Strategy Summary"):
            gr.Markdown("*Contains rebuttal strategies, To-Do Lists, and draft response snippets for all questions.*")
            # Each module gets a rendered preview plus an editable raw-text view.
            with gr.Tabs():
                with gr.TabItem("📖 Preview"):
                    strategy_summary_preview = gr.Markdown(elem_classes=["strategy-preview"])
                with gr.TabItem("✏️ Raw Text"):
                    strategy_summary_output = gr.Textbox(
                        label="Rebuttal Strategy & To-Do List",
                        lines=20,
                        max_lines=40,
                    )
        with gr.TabItem("📄 Final Reference Rebuttal"):
            gr.Markdown(
                """
                *The complete reference rebuttal document.*
                > ⚠️ **Note:** Numerical values marked with **asterisks (*)** are LLM-estimated placeholders.
                > You **MUST supplement these with actual experimental data** .
                """
            )
            with gr.Tabs():
                with gr.TabItem("📖 Preview"):
                    final_preview = gr.Markdown(elem_classes=["strategy-preview"])
                with gr.TabItem("✏️ Raw Text"):
                    final_output = gr.Textbox(
                        label="Final Reference Rebuttal",
                        lines=20,
                        max_lines=40,
                    )
    gr.Markdown("---")
    gr.Markdown("### 📥 Download Files")
    with gr.Row():
        download_strategy_btn = gr.Button(
            "📥 Download Strategy Summary",
            variant="primary",
            size="lg",
            elem_id="download-strategy-btn",
        )
        download_rebuttal_btn = gr.Button(
            "📥 Download Reference Rebuttal",
            variant="primary",
            size="lg",
            elem_id="download-rebuttal-btn",
        )
        restart_btn = gr.Button(
            "🔄 Start Over",
            variant="secondary",
            size="lg",
        )
    # Hidden File components: the download handlers point them at a freshly
    # written temp file and flip them visible.
    download_strategy_file = gr.File(label="Strategy File", visible=False)
    download_rebuttal_file = gr.File(label="Rebuttal File", visible=False)
    gr.Markdown(
        "💡 **Tip:** After clicking the download button, click the **file size link on the right** ➡️ of the file component to start download.",
        elem_classes=["download-tip"]
    )
# Two-stage start flow: first swap the UI into the loading state and create
# the session, then chain into the long-running initial analysis which
# ultimately reveals the interactive column.
start_btn.click(
    fn=start_analysis,
    inputs=[pdf_input, review_input, provider_choice, api_key_input, model_choice, custom_model_input],
    outputs=[
        upload_col, loading_col, interact_col, result_col,
        session_state, upload_status, log_timer,
    ],
).then(
    fn=run_initial_analysis,
    inputs=[session_state],
    outputs=[
        upload_col, loading_col, interact_col, result_col,
        session_state, loading_status,
        progress_info, question_display, strategy_preview, strategy_edit, feedback_input,
        revision_info, regenerate_btn, feedback_history_display, log_timer,
        confirm_btn,
    ],
)
# Periodically pull backend log lines into the UI while processing runs.
log_timer.tick(
    fn=poll_logs,
    inputs=[session_state],
    outputs=[log_display, session_state],
)
# Regenerate the current question's strategy using the user's feedback.
regenerate_btn.click(
    fn=regenerate_strategy,
    inputs=[feedback_input, session_state],
    outputs=[strategy_preview, strategy_edit, feedback_input, revision_info, feedback_history_display, session_state],
)
def sync_preview(text):
    """Mirror the raw strategy edits into the rendered markdown preview.

    Pure passthrough: Gradio routes the editor textbox value straight
    into the preview component, which re-renders it as markdown.
    """
    mirrored = text
    return mirrored
# Refresh the markdown preview whenever the user finishes editing the raw text.
strategy_edit.blur(
    fn=sync_preview,
    inputs=[strategy_edit],
    outputs=[strategy_preview],
)
# Accept the (possibly hand-edited) strategy and advance; on the last
# question this populates the final results column.
confirm_btn.click(
    fn=confirm_and_next,
    inputs=[strategy_edit, session_state],
    outputs=[
        interact_col, result_col, session_state,
        progress_info, question_display, strategy_preview, strategy_edit, feedback_input,
        revision_info, regenerate_btn, feedback_history_display,
        strategy_summary_preview, strategy_summary_output, final_preview, final_output,
        confirm_btn,
    ],
)
# Skip the current question without recording a strategy for it; shares the
# same output set as confirm since it can also finish the session.
skip_btn.click(
    fn=skip_question,
    inputs=[session_state],
    outputs=[
        interact_col, result_col, session_state,
        progress_info, question_display, strategy_preview, strategy_edit, feedback_input,
        revision_info, regenerate_btn, feedback_history_display,
        strategy_summary_preview, strategy_summary_output, final_preview, final_output,
        confirm_btn,
    ],
)
# Reset everything back to the initial upload screen.
restart_btn.click(
    fn=restart_session,
    inputs=[],
    outputs=[
        upload_col, loading_col, interact_col, result_col,
        session_state, upload_status,
        pdf_input, review_input,
    ],
)
def download_strategy(strategy_text):
    """Write the strategy summary to a temp markdown file and expose it.

    Returns a Gradio update pointing the hidden File component at the
    newly written file (and making it visible), or a no-op update when
    there is nothing to export.
    """
    if not strategy_text:
        return gr.update()
    import tempfile
    # delete=False so the file outlives this handler and Gradio can serve it.
    tmp = tempfile.NamedTemporaryFile(
        mode='w', suffix='_strategy.md', delete=False, encoding='utf-8'
    )
    with tmp as handle:
        handle.write("# Rebuttal Strategy & To-Do List \n\n")
        handle.write(strategy_text)
    return gr.update(value=tmp.name, visible=True)
# Materialize the strategy summary as a file when the download button is hit.
download_strategy_btn.click(
    fn=download_strategy,
    inputs=[strategy_summary_output],
    outputs=[download_strategy_file],
)
def download_rebuttal(final_text):
    """Write the final reference rebuttal to a temp markdown file.

    Mirrors download_strategy: returns a Gradio update that points the
    hidden File component at the written file, or a no-op update when
    the rebuttal text is empty.
    """
    if not final_text:
        return gr.update()
    import tempfile
    # delete=False so the file outlives this handler and Gradio can serve it.
    tmp = tempfile.NamedTemporaryFile(
        mode='w', suffix='_rebuttal.md', delete=False, encoding='utf-8'
    )
    with tmp as handle:
        handle.write("# Final Rebuttal\n\n")
        handle.write(final_text)
    return gr.update(value=tmp.name, visible=True)
# Materialize the final rebuttal as a file when its download button is hit.
download_rebuttal_btn.click(
    fn=download_rebuttal,
    inputs=[final_output],
    outputs=[download_rebuttal_file],
)
if __name__ == "__main__":
    import argparse

    # CLI for local runs; on Spaces the module-level demo is served directly.
    parser = argparse.ArgumentParser(description="AI Rebuttal Assistant")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Server host address")
    parser.add_argument("--port", type=int, default=7860, help="Server port")
    parser.add_argument("--share", action="store_true", help="Create public link")
    # NOTE: --device is already consumed at import time (top of this file) to
    # set DOCLING_DEVICE before heavy imports run. It is declared here only so
    # argparse accepts and documents the flag instead of erroring on it.
    parser.add_argument("--device", type=str, default="cpu", choices=["cpu", "cuda"],
                        help="Device for docling PDF processing (cpu or cuda)")
    args = parser.parse_args()

    # Report the device actually in effect (set from sys.argv at import time,
    # which may differ from argparse's default).
    device_used = os.environ.get("DOCLING_DEVICE", "cpu")
    print("\n🚀 Starting AI Rebuttal Assistant")
    # The banner shows localhost for convenience; the server itself binds args.host.
    print(f" URL: http://localhost:{args.port}")
    print(f" Device: {device_used.upper()}")
    print(f" Share: {'Yes' if args.share else 'No'}\n")

    demo.launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
    )