"""Main Streamlit app for the SciCoQA Discrepancy Detection Demo."""

import logging
import os
import time
from pathlib import Path

import streamlit as st
from dotenv import load_dotenv

from core.arxiv2md_demo import Arxiv2MD
from core.code_loader_demo import CodeLoader
from core.llm_demo import LLM
from core.model_config import (
    PROVIDER_PRESETS,
    create_provider_model_config,
    get_api_key_env_name,
    get_provider_from_model,
)
from core.openrouter_models import fetch_free_models, get_model_config
from core.prompt_demo import Prompt
from core.token_counter_demo import TokenCounter
from parsing import parse_discrepancies

load_dotenv()

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Keep a 10% safety margin of the model's context window for the response.
CONTEXT_BUFFER_FACTOR = 0.9
MAX_CONTEXT_SIZE = 131072

st.set_page_config(
    page_title="SciCoQA Paper-Code Discrepancy Detection",
    page_icon="🔬",
    layout="wide",
    initial_sidebar_state="expanded",
)


def _redact_secrets(text: str, secrets: list[str | None]) -> str:
    """Best-effort redaction for secrets that may appear in exception strings/logs."""
    redacted = text
    for secret in secrets:
        if secret and secret in redacted:
            redacted = redacted.replace(secret, "***REDACTED***")
    return redacted
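
# Illustrative usage (comments only, so nothing runs at import time); values are
# hypothetical:
#   _redact_secrets("401 Unauthorized for key sk-123", ["sk-123", None])
#   -> "401 Unauthorized for key ***REDACTED***"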


def _safe_model_config_for_session(model_config: dict | None) -> dict | None:
    """Store model config in session state WITHOUT sensitive fields like API keys."""
    if not model_config:
        return model_config

    safe = dict(model_config)
    safe.pop("api_key", None)
    safe.pop("apiKey", None)
    return safe
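
# Sketch of the intended behaviour with hypothetical values (comments only):
#   _safe_model_config_for_session({"model": "openai/gpt-4o", "api_key": "sk-123"})
#   -> {"model": "openai/gpt-4o"}    # the API key never reaches st.session_state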


def _is_context_length_error(error_msg: str) -> bool:
    """
    Check if an error message indicates a context length error.

    Args:
        error_msg: The error message string

    Returns:
        True if it's a context length error, False otherwise
    """
    error_lower = error_msg.lower()
    return (
        "maximum context length" in error_lower
        or "requested about" in error_lower
        or ("context length is" in error_lower and "you requested" in error_lower)
        or "context window" in error_lower
    )
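
# Example of the kind of provider message this matches; the exact wording varies
# by API, which is why the check is substring-based (comments only):
#   _is_context_length_error(
#       "This model's maximum context length is 128000 tokens, you requested 150000"
#   )  # -> True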


def _build_prompt(
    paper_text: str,
    code_loader: CodeLoader | None,
    code_text: str | None,
    model_config: dict,
    token_counter: TokenCounter,
    code_reduction_factor: float = 1.0,
) -> tuple[str, str, int, bool]:
    """
    Build the prompt by counting tokens and truncating code until
    template + paper + code fits within CONTEXT_BUFFER_FACTOR * model context length.

    Args:
        paper_text: The paper text
        code_loader: CodeLoader instance (if using a GitHub repo)
        code_text: Raw code text (if using an uploaded file)
        model_config: Model configuration dictionary
        token_counter: TokenCounter instance
        code_reduction_factor: Factor to reduce code tokens (1.0 = no reduction, 0.9 = 10% reduction, etc.)

    Returns:
        Tuple of (final_prompt, code_prompt, final_tokens, code_was_truncated)
    """
    max_context = model_config["max_context"]
    max_total_tokens = int(max_context * CONTEXT_BUFFER_FACTOR)

    prompt_template = Prompt("discrepancy_generation")

    # Tokens used by the template plus the paper alone (empty code slot).
    template_with_paper = prompt_template(paper=paper_text, code="")
    tokens_template_and_paper = token_counter(template_with_paper)

    remaining_code_tokens = int((max_total_tokens - tokens_template_and_paper) * code_reduction_factor)

    if remaining_code_tokens <= 0:
        raise ValueError(
            f"Paper text too long: {tokens_template_and_paper} tokens exceeds "
            f"{int(CONTEXT_BUFFER_FACTOR * 100)}% of context limit ({max_total_tokens} tokens)"
        )

    logger.info(
        f"Template + paper tokens: {tokens_template_and_paper}, "
        f"Remaining for code (with {code_reduction_factor:.1%} factor): {remaining_code_tokens}"
    )

    original_code_size = 0
    if code_loader:
        # Size is unknown until the loader assembles the code prompt.
        original_code_size = -1
    elif code_text:
        original_code_size = len(code_text)

    code_was_truncated = False
    if code_loader:
        # Let the loader assemble as much code as fits into the remaining budget.
        code_prompt = code_loader.get_code_prompt(
            token_counter=token_counter,
            max_tokens=remaining_code_tokens,
        )
        code_tokens_used = token_counter(code_prompt)
        code_was_truncated = code_tokens_used >= remaining_code_tokens * 0.95
    else:
        # Uploaded code text: append line by line until the budget is exhausted.
        code_prompt = ""
        code_tokens = 0
        if code_text and remaining_code_tokens > 0:
            code_lines = code_text.split('\n')
            for line in code_lines:
                line_with_newline = line + '\n'
                line_tokens = token_counter(line_with_newline)
                if code_tokens + line_tokens > remaining_code_tokens:
                    logger.warning(f"Truncating code at {code_tokens} tokens (limit: {remaining_code_tokens})")
                    code_was_truncated = True
                    break
                code_prompt += line_with_newline
                code_tokens += line_tokens

        if len(code_prompt) < original_code_size:
            code_was_truncated = True

    final_prompt = prompt_template(paper=paper_text, code=code_prompt)
    final_tokens = token_counter(final_prompt)

    if final_tokens > max_total_tokens:
        raise ValueError(
            f"Final prompt too long: {final_tokens} tokens exceeds "
            f"{int(CONTEXT_BUFFER_FACTOR * 100)}% of context limit ({max_total_tokens} tokens)"
        )

    logger.info(f"Final prompt tokens: {final_tokens} (limit: {max_total_tokens})")

    return final_prompt, code_prompt, final_tokens, code_was_truncated
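
# Worked example of the token budget with illustrative numbers: for
# max_context = 131072 and CONTEXT_BUFFER_FACTOR = 0.9, max_total_tokens is
# int(131072 * 0.9) = 117964; if the template plus paper uses 40000 tokens, the
# code section gets at most 77964 tokens (scaled further down by
# code_reduction_factor on retries).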


def validate_urls(arxiv_url: str, github_url: str) -> tuple[bool, str]:
    """Validate input URLs."""
    if not arxiv_url:
        return False, "Please provide an arXiv URL"
    if not github_url:
        return False, "Please provide a GitHub URL"

    if "arxiv.org" not in arxiv_url and not arxiv_url.startswith("http"):
        # A bare identifier such as "2006.12834" is accepted; the normalised URL
        # is only used for validation here and is not returned to the caller.
        if arxiv_url.replace(".", "").replace("v", "").isdigit():
            arxiv_url = f"https://arxiv.org/abs/{arxiv_url}"
        else:
            return False, "Invalid arXiv URL format"

    if "github.com" not in github_url:
        return False, "Please provide a valid GitHub URL"

    return True, ""
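
# A couple of illustrative calls with hypothetical URLs (comments only):
#   validate_urls("2006.12834", "https://github.com/user/repo")   # -> (True, "")
#   validate_urls("not-a-paper", "https://github.com/user/repo")  # -> (False, "Invalid arXiv URL format")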


def validate_files(paper_file, code_file) -> tuple[bool, str]:
    """Validate uploaded files."""
    if paper_file is None:
        return False, "Please upload a paper markdown file"
    if code_file is None:
        return False, "Please upload a repository text file"

    if paper_file.name and not paper_file.name.endswith(('.md', '.markdown', '.txt')):
        return False, "Paper file should be a markdown (.md) or text (.txt) file"
    if code_file.name and not code_file.name.endswith('.txt'):
        return False, "Repository file should be a text (.txt) file"

    return True, ""


def process_discrepancy_detection(
    paper_text: str | None = None,
    code_text: str | None = None,
    arxiv_url: str | None = None,
    github_url: str | None = None,
    model_config: dict | None = None,
):
    """Main processing pipeline for discrepancy detection."""
    results = {
        "paper_text": None,
        "code_prompt": None,
        "prompt": None,
        "llm_response": None,
        "discrepancies": None,
        "error": None,
        "step_timings": None,
    }

    step_timings = {}

    try:
        with st.status("🔍 Processing...", expanded=False) as status:
            try:
                # Step 1: obtain the paper text.
                step_start = time.time()
                if arxiv_url:
                    status.update(label="📄 Fetching paper from arXiv...", state="running")
                    try:
                        arxiv2md = Arxiv2MD(output_dir=Path("data/papers"))
                        paper_text = arxiv2md(arxiv_url)
                        results["paper_text"] = paper_text
                        step_time = time.time() - step_start
                        step_timings["Paper Fetch"] = step_time
                        st.write(f"✅ Paper fetched: {step_time:.1f}s")
                        status.update(
                            label=f"✅ Paper fetched ({step_time:.1f}s)",
                            state="running",
                        )
                    except Exception as e:
                        error_msg = f"Error fetching paper: {str(e)}"
                        logger.error(error_msg)
                        results["error"] = error_msg
                        status.update(label="❌ Error fetching paper", state="error")
                        return results
                else:
                    status.update(label="📄 Processing paper...", state="running")
                    try:
                        results["paper_text"] = paper_text
                        step_time = time.time() - step_start
                        step_timings["Paper Processing"] = step_time
                        st.write(f"✅ Paper processed: {step_time:.1f}s")
                        status.update(
                            label=f"✅ Paper processed ({step_time:.1f}s)",
                            state="running",
                        )
                    except Exception as e:
                        error_msg = f"Error processing paper: {str(e)}"
                        logger.error(error_msg)
                        results["error"] = error_msg
                        status.update(label="❌ Error processing paper", state="error")
                        return results

                # Step 2: obtain the code.
                step_start = time.time()
                code_loader = None
                if github_url:
                    status.update(label="📦 Fetching code from GitHub...", state="running")
                    try:
                        code_loader = CodeLoader(
                            github_url=github_url,
                            max_file_size_mb=1.0,
                            raw_repo_dir=Path("data/repos-raw"),
                        )
                        step_time = time.time() - step_start
                        step_timings["Repository Clone"] = step_time
                        st.write(f"✅ Repository cloned: {step_time:.1f}s")
                        status.update(
                            label=f"✅ Repository cloned ({step_time:.1f}s)",
                            state="running",
                        )
                    except Exception as e:
                        error_msg = f"Error cloning repository: {str(e)}"
                        logger.error(error_msg)
                        results["error"] = error_msg
                        status.update(label="❌ Error cloning repository", state="error")
                        return results
                else:
                    status.update(label="📦 Processing repository...", state="running")
                    step_time = time.time() - step_start
                    step_timings["Code Processing"] = step_time
                    st.write(f"✅ Repository processed: {step_time:.1f}s")
                    status.update(
                        label=f"✅ Repository processed ({step_time:.1f}s)",
                        state="running",
                    )

                # Step 3: build the prompt within the model's token budget.
                step_start = time.time()
                status.update(label="📝 Preparing prompt...", state="running")

                tokenizer_name = model_config["tokenizer"]
                token_counter = TokenCounter(model=tokenizer_name)

                try:
                    final_prompt, code_prompt, final_tokens, code_was_truncated = _build_prompt(
                        paper_text=paper_text,
                        code_loader=code_loader,
                        code_text=code_text,
                        model_config=model_config,
                        token_counter=token_counter,
                    )

                    results["code_prompt"] = code_prompt
                    results["prompt"] = final_prompt

                    step_time = time.time() - step_start
                    step_timings["Prompt Preparation"] = step_time
                    st.write(f"✅ Prompt prepared: {step_time:.1f}s ({final_tokens:,} tokens)")
                    status.update(
                        label=f"✅ Prompt prepared ({step_time:.1f}s, {final_tokens:,} tokens)",
                        state="running",
                    )
                except Exception as e:
                    error_msg = f"Error preparing prompt: {str(e)}"
                    logger.error(error_msg)
                    results["error"] = error_msg
                    status.update(label="❌ Error preparing prompt", state="error")
                    return results

                # Step 4: run the LLM, retrying with progressively less code on context errors.
                step_start = time.time()
                status.update(label="🤖\uFE0F Detecting discrepancies (this may take a while)...", state="running")

                code_reduction_factor = 1.0
                reduction_step = 0.1
                max_retries = 10
                retry_count = 0
                success = False
                current_final_prompt = final_prompt
                current_code_was_truncated = code_was_truncated

                while not success and retry_count < max_retries:
                    try:
                        if retry_count > 0:
                            logger.info(
                                f"Retrying with code reduction factor: {code_reduction_factor:.1%} "
                                f"(attempt {retry_count}/{max_retries})"
                            )
                            status.update(
                                label=f"🔄 Retrying with reduced code ({code_reduction_factor:.0%})...",
                                state="running"
                            )
                            st.write(f"🔄 Retrying with reduced code ({code_reduction_factor:.0%})...")

                            # Rebuild the prompt with a smaller code budget.
                            current_final_prompt, code_prompt, final_tokens, current_code_was_truncated = _build_prompt(
                                paper_text=paper_text,
                                code_loader=code_loader,
                                code_text=code_text,
                                model_config=model_config,
                                token_counter=token_counter,
                                code_reduction_factor=code_reduction_factor,
                            )
                            results["code_prompt"] = code_prompt
                            results["prompt"] = current_final_prompt

                        model = model_config["model"]
                        api_key = model_config.get("api_key")
                        api_base = model_config.get("api_base")
                        max_context = model_config.get("max_context")

                        llm = LLM(
                            model=model,
                            api_key=api_key,
                            api_base=api_base,
                            temperature=1.0,
                            top_p=1.0,
                            reasoning_effort="high",
                            max_context=max_context,
                        )

                        response = llm(current_final_prompt)
                        results["llm_response"] = response

                        choices = response.get("choices", [])
                        if not choices:
                            raise ValueError("No choices in LLM response")

                        content = (
                            choices[0]
                            .get("message", {})
                            .get("content", "")
                        )

                        if not content:
                            raise ValueError("Empty content in LLM response")

                        discrepancies = parse_discrepancies(content)
                        results["discrepancies"] = discrepancies

                        step_time = time.time() - step_start
                        step_timings["LLM Inference"] = step_time
                        total_time = sum(step_timings.values())

                        st.write(f"✅ LLM inference: {step_time:.1f}s")

                        if current_code_was_truncated:
                            st.warning("⚠️ **Note**: Some code was truncated from the prompt due to context length limitations.")

                        st.write("---")
                        st.write(f"**Total time: {total_time:.1f}s**")

                        if discrepancies:
                            count = len(discrepancies)
                            discrepancy_text = "discrepancy" if count == 1 else "discrepancies"
                            status.update(
                                label=f"✅ Complete! Found {count} {discrepancy_text} ({total_time:.1f}s total)",
                                state="complete",
                            )
                        else:
                            status.update(
                                label=f"✅ Complete! No discrepancies found ({total_time:.1f}s total)",
                                state="complete",
                            )

                        success = True

                    except Exception as e:
                        error_msg = str(e)
                        api_key = model_config.get("api_key") if isinstance(model_config, dict) else None
                        redacted_error = _redact_secrets(error_msg, [api_key])

                        if _is_context_length_error(error_msg):
                            retry_count += 1

                            if code_reduction_factor <= 0.1:
                                # Nothing left to cut: the paper alone exceeds the context window.
                                error_msg = (
                                    "The paper text is too long for the model's context window. "
                                    "Even with all code removed, the paper alone exceeds the context limit. "
                                    "Please use a model with a larger context window or provide a shorter paper."
                                )
                                logger.error(error_msg)
                                results["error"] = error_msg
                                status.update(label="❌ Paper too long for model", state="error")
                                return results

                            code_reduction_factor = max(0.1, code_reduction_factor - reduction_step)

                            logger.warning(
                                f"Context length error detected: {redacted_error}. "
                                f"Retrying with reduced code ({code_reduction_factor:.0%}) (attempt {retry_count}/{max_retries})"
                            )
                            continue
                        else:
                            logger.error(f"Error during LLM inference: {redacted_error}")
                            results["error"] = f"Error during LLM inference: {redacted_error}"
                            status.update(label="❌ Error during inference", state="error")
                            return results

                if not success:
                    error_msg = (
                        f"Could not fit prompt within context limits after {retry_count} retries. "
                        f"The paper text may be too long for this model's context window."
                    )
                    logger.error(error_msg)
                    results["error"] = error_msg
                    status.update(label="❌ Prompt too large for model", state="error")
                    return results

            except Exception as e:
                api_key = model_config.get("api_key") if isinstance(model_config, dict) else None
                error_msg = f"Unexpected error: {_redact_secrets(str(e), [api_key])}"
                logger.error(error_msg, exc_info=True)
                results["error"] = error_msg
                status.update(label="❌ Unexpected error", state="error")
                return results

            results["step_timings"] = step_timings
            return results

    except Exception as e:
        api_key = model_config.get("api_key") if isinstance(model_config, dict) else None
        error_msg = f"Unexpected error: {_redact_secrets(str(e), [api_key])}"
        logger.error(error_msg, exc_info=True)
        results["error"] = error_msg
        return results


def main():
    """Main Streamlit app."""
    st.title("🔬 :rainbow[SciCoQA] Paper-Code Discrepancy Detection")
    st.markdown(
        """
        _Detect discrepancies between scientific papers and their code implementations._
        """
    )

    with st.expander("ℹ️ About", expanded=False):
        st.markdown(
            """
            This tool is a demo of our research paper on detecting discrepancies between scientific papers and their
            code implementations. You can read our paper here: [arXiv:2601.XXXX](https://arxiv.org/pdf/2601.XXXX).

            This tool helps researchers and developers identify inconsistencies between scientific papers and their
            corresponding code implementations. Such discrepancies can lead to reproducibility issues, incorrect
            implementations, or misunderstandings of the research. By using advanced LLMs to analyze both the paper
            text and code, this app automatically detects mismatches in algorithms, parameters, data processing steps,
            and other implementation details.

            **⚠️ Important Limitations:**
            Our research found that **recall is still low**, meaning the tool may miss some discrepancies.
            **All outputs should be used with human verification** and should not be relied upon as the sole method
            for discrepancy detection.

            **LLM Provider Recommendations:**
            - **Free Models (OpenRouter)**: Best for quick checks of already public paper+code combinations
            - **Provider Models (OpenAI, Anthropic, etc.)**: Best for high precision and best recall

            **Features:**
            - Support for multiple LLM providers (free or premium models)
            - Automatic content fetching from arXiv and GitHub
            - File upload support for custom papers and repositories
            - Secure API key handling (keys never stored or logged)

            **Resources:**
            - 📦 **Code**: [GitHub Repository](https://github.com/UKPLab/scicoqa)
            - 📊 **Dataset**: [Hugging Face Dataset](https://huggingface.co/datasets/ukplab/scicoqa)
            - 🌐 **Project Website**: [ukplab.github.io/scicoqa](https://ukplab.github.io/scicoqa)

            **Citation:**
            If you find this tool useful, please cite our paper:
            ```bibtex
            @article{scicoqa2026,
                title = {SciCoQA: Quality Assurance for Scientific Paper-Code Alignment},
                author = {Baumgärtner, Tim and Gurevych, Iryna},
                journal = {arXiv preprint arXiv:XXXX.XXXXX},
                year = {2026},
                url = {https://github.com/UKPLab/scicoqa}
            }
            ```
            """
        )

    with st.sidebar:
        st.header("🤖\uFE0F Model Configuration")

        model_config = None
        model_name = None
        display_model_name = None

        # Show the model selected on a previous run, if any.
        if "model_config" in st.session_state and st.session_state.model_config:
            existing_config = st.session_state.model_config
            display_model_name = existing_config.get("name") or existing_config.get("model", "Unknown")

            if display_model_name:
                st.caption(f"Current: {display_model_name}")

        model_type = st.radio(
            "Model Type",
            options=["Free Models (OpenRouter)", "Provider (OpenAI, Anthropic, Gemini, etc.)"],
            help="Select free models (no API key) or provider models (requires API key)",
            key="model_type_radio",
            index=0,
        )

        st.session_state.model_type = model_type

        st.divider()

        if model_type == "Free Models (OpenRouter)":
            # Cache the free model list for the session to avoid repeated API calls.
            if "free_models_cache" not in st.session_state:
                with st.spinner("Loading free models from OpenRouter..."):
                    free_models_raw = fetch_free_models()
                    st.session_state.free_models_cache = free_models_raw

            free_models_raw = st.session_state.free_models_cache

            if not free_models_raw:
                st.error("⚠️ Could not fetch free models from OpenRouter. Please try again later or use a different model type.")
                model_config = None
            else:
                st.warning(
                    "⚠️ **Privacy Notice**: Free models are provided via [OpenRouter](https://openrouter.ai). "
                    "The model provider may log your prompts and outputs. For enhanced privacy, consider using Provider models with your own API keys."
                )

                model_options = {get_model_config(m)["name"]: get_model_config(m) for m in free_models_raw}

                if model_options:
                    # Default to a known-good free model if it is available.
                    model_names = list(model_options.keys())
                    default_index = 0
                    for idx, name in enumerate(model_names):
                        if "nemotron 3 nano 30b" in name.lower():
                            default_index = idx
                            break

                    model_name = st.selectbox(
                        "Select Free Model",
                        options=model_names,
                        help="Free models via OpenRouter (no API key required)",
                        key="free_model_select",
                        index=default_index,
                    )
                    model_config = model_options[model_name]

                else:
                    st.error("⚠️ No free models available. Please try again later or use a different model type.")
                    model_config = None
        else:
            st.info("🔑 **Provider Model**: Use your own API keys to access premium models. Your keys are never stored, logged, or displayed.")

            provider_subtype = st.radio(
                "Model Selection",
                options=["Preset", "Custom"],
                help="Select from preset models or enter a custom model",
                key="provider_subtype",
            )

            if provider_subtype == "Preset":
                model_name = st.selectbox(
                    "Select Model",
                    options=list(PROVIDER_PRESETS.keys()),
                    help="Select a preset model (API key required)",
                    key="preset_model_select",
                )
                preset_config = PROVIDER_PRESETS[model_name]
                api_key_env = preset_config["api_key_env"]
                api_key_label = api_key_env.replace("_", " ").title()

                api_key = st.text_input(
                    f"{api_key_label}",
                    type="password",
                    help=f"Enter your {api_key_label}. Your key is never stored, logged, or displayed.",
                    placeholder="sk-..." if "OPENAI" in api_key_env else "Enter API key",
                    key="preset_api_key",
                )

                if api_key:
                    model_config = create_provider_model_config(
                        model=preset_config["model"],
                        api_key=api_key,
                        max_context=preset_config["max_context"],
                        tokenizer=preset_config["tokenizer"],
                    )
            else:
                custom_model_name = st.text_input(
                    "Model Name (litellm format)",
                    placeholder="e.g., openai/gpt-5.2, anthropic/claude-sonnet-4-5, gemini/gemini-3-pro-preview",
                    help="Enter the model name in litellm format. See [litellm documentation](https://docs.litellm.ai/docs/providers) for supported formats.",
                    key="custom_model_name",
                )
                custom_max_context = st.number_input(
                    "Max Context (tokens)",
                    min_value=1000,
                    max_value=10000000,
                    value=128000,
                    step=1000,
                    help="Maximum context window size in tokens",
                    key="custom_max_context",
                )

                if custom_model_name:
                    provider = get_provider_from_model(custom_model_name)
                    api_key_env = get_api_key_env_name(provider)
                    api_key_label = api_key_env.replace("_", " ").title()

                    api_key = st.text_input(
                        f"{api_key_label}",
                        type="password",
                        help=f"Enter your {api_key_label}. Your key is never stored, logged, or displayed.",
                        placeholder="sk-..." if "OPENAI" in api_key_env else "Enter API key",
                        key="custom_api_key",
                    )

                    if api_key:
                        model_name = custom_model_name
                        model_config = create_provider_model_config(
                            model=custom_model_name,
                            api_key=api_key,
                            max_context=custom_max_context,
                        )

                st.markdown(
                    "📚 **Need help with model format?** See the [litellm documentation](https://docs.litellm.ai/docs/providers) "
                    "for supported providers and model naming conventions."
                )

            st.caption("🔒 Your API key is secure: never stored, logged, or displayed")

        if model_config:
            display_name = model_config.get("name") or model_config.get("model", model_name or "Unknown")
            st.caption(f"📊 Max Context: {model_config['max_context']:,} tokens")

        # Persist a sanitised copy (no API key) for later display in the results view.
        if model_config:
            st.session_state.model_config = _safe_model_config_for_session(model_config)
            st.session_state.model_name = model_config.get("name") or model_config.get("model", model_name or "Unknown")

    with st.form("discrepancy_form"):
        tab_links, tab_files = st.tabs(["arXiv and GitHub Links", "Upload Paper and Code Files"])

        arxiv_url = None
        github_url = None
        paper_file = None
        code_file = None
        input_method = None

        with tab_links:
            col1, col2 = st.columns(2)

            with col1:
                arxiv_url = st.text_input(
                    "arXiv Paper",
                    value=st.session_state.get("example_arxiv_url", ""),
                    placeholder="https://arxiv.org/abs/2006.12834 or 2006.12834",
                    help="Enter the arXiv paper URL or just the paper ID",
                    label_visibility="visible",
                )

            with col2:
                github_url = st.text_input(
                    "GitHub Code",
                    value=st.session_state.get("example_github_url", ""),
                    placeholder="https://github.com/username/repo",
                    help="Enter the full GitHub repository URL",
                    label_visibility="visible",
                )

            if arxiv_url or github_url:
                input_method = "arXiv and GitHub Links"

        with tab_files:
            with st.expander("📄 How to prepare files", expanded=False):
                st.markdown("""
                    <h3>Converting LaTeX to Markdown with Pandoc</h3>

                    1. Install pandoc:
                    ```
                    brew install pandoc
                    ```
                    For installing pandoc on Windows or Linux, see the [pandoc documentation](https://pandoc.org/installing.html).

                    2. Convert your LaTeX source to markdown:
                    ```bash
                    pandoc main.tex -f latex -t markdown -s --wrap=none -o paper.md
                    ```

                    <h3>Converting a Repository to Text with Gitingest</h3>

                    1. Install gitingest:
                    ```bash
                    pip install gitingest
                    ```

                    2. Generate the repository text file:
                    ```bash
                    gitingest https://github.com/your-username/your-repo \\
                        -i "*.c,*.cc,*.cpp,*.cu,*.h,*.hpp,*.java,*.jl,*.m,*.matlab,Makefile,*.md,*.pl,*.ps1,*.py,*.r,*.sh,config.txt,*.rs,readme.txt,requirements_dev.txt,requirements-dev.txt,requirements.dev.txt,requirements.txt,*.scala,*.yaml,*.yml" \\
                        -o repo.txt \\
                        --token YOUR_GITHUB_TOKEN
                    ```

                    **Note**: Modify the file extension list to include the files you want in the repository text file. For private repositories, you'll need a GitHub token. For public repositories, you can omit the `--token` parameter.
                """, unsafe_allow_html=True)

            col1, col2 = st.columns(2)

            with col1:
                paper_file = st.file_uploader(
                    "Paper Markdown File",
                    type=["md", "markdown", "txt"],
                    help="Upload the paper as a markdown file",
                    label_visibility="visible",
                    accept_multiple_files=False,
                )

            with col2:
                code_file = st.file_uploader(
                    "Repository Text File",
                    type=["txt"],
                    help="Upload the repository as a text file (generated using gitingest)",
                    label_visibility="visible",
                    accept_multiple_files=False,
                )

            if paper_file or code_file:
                input_method = "Upload Paper and Code Files"

        submitted = st.form_submit_button("Detect Discrepancies", type="primary", use_container_width=True)

    st.session_state.model_config = _safe_model_config_for_session(model_config)

    if submitted:
        # Decide which input path was used and validate it.
        if paper_file is not None or code_file is not None:
            is_valid, error_msg = validate_files(paper_file, code_file)
            if not is_valid:
                st.error(error_msg)
                return

            try:
                paper_text = paper_file.read().decode("utf-8") if paper_file else None
                code_text = code_file.read().decode("utf-8") if code_file else None
            except Exception as e:
                st.error(f"Error reading files: {str(e)}")
                return

            arxiv_url = None
            github_url = None

        elif arxiv_url or github_url:
            is_valid, error_msg = validate_urls(arxiv_url, github_url)
            if not is_valid:
                st.error(error_msg)
                return

            paper_text = None
            code_text = None
        else:
            st.error("Please provide either arXiv and GitHub links, or upload paper and code files.")
            return

        # Clear any pre-filled example URLs once a run has started.
        if "example_arxiv_url" in st.session_state:
            del st.session_state["example_arxiv_url"]
        if "example_github_url" in st.session_state:
            del st.session_state["example_github_url"]

        if model_config is None:
            st.error("Please select a valid model.")
            return

        # Provider models need an API key before the LLM can be called.
        model_type = st.session_state.get("model_type", "Provider (OpenAI, Anthropic, Gemini, etc.)")
        if model_type == "Provider (OpenAI, Anthropic, Gemini, etc.)":
            if "api_key" not in model_config or not model_config.get("api_key"):
                st.error("⚠️ API key required for provider models. Please enter your API key.")
                return

        with st.spinner("Processing..."):
            results = process_discrepancy_detection(
                paper_text=paper_text,
                code_text=code_text,
                arxiv_url=arxiv_url,
                github_url=github_url,
                model_config=model_config,
            )

        if results["error"]:
            st.error(f"❌ Error: {results['error']}")
            return

        st.divider()
        st.header("Results")

        if results["discrepancies"]:
            count = len(results["discrepancies"])
            discrepancy_text = "discrepancy" if count == 1 else "discrepancies"
            st.success(f"Found {count} {discrepancy_text}")

            tab_labels = [f"Discrepancy {idx}" for idx in range(1, count + 1)]
            tabs = st.tabs(tab_labels)

            for tab, discrepancy in zip(tabs, results["discrepancies"]):
                with tab:
                    st.markdown(discrepancy)
                    st.divider()
        else:
            st.info("✅ No discrepancies found between the paper and code.")
            st.divider()

        with st.expander("🔧 Technical Details", expanded=False):
            if results["prompt"]:
                st.subheader("📋 Raw Prompt")
                st.markdown("**Final prompt sent to the LLM (after truncation):**")
                model_config = st.session_state.get("model_config")
                if model_config:
                    tokenizer_name = model_config["tokenizer"]
                    token_counter = TokenCounter(model=tokenizer_name)
                    prompt_tokens = token_counter(results["prompt"])
                    st.caption(f"Prompt tokens: {prompt_tokens:,}")

                with st.container(height=500):
                    st.code(results["prompt"], language="text")
                st.divider()

            if results["llm_response"]:
                st.subheader("📤 Raw LLM Output")
                content = (
                    results["llm_response"]
                    .get("choices", [{}])[0]
                    .get("message", {})
                    .get("content", "")
                )

                model_config = st.session_state.get("model_config")
                if model_config:
                    tokenizer_name = model_config["tokenizer"]
                    token_counter = TokenCounter(model=tokenizer_name)
                    output_tokens = token_counter(content)
                    st.caption(f"Output tokens: {output_tokens:,}")
                st.code(content, language="yaml")
                st.divider()

            if results.get("step_timings"):
                st.subheader("⏱️ Step Timing")
                step_timings = results["step_timings"]
                total_time = sum(step_timings.values())

                for step_name, step_time in step_timings.items():
                    percentage = (step_time / total_time * 100) if total_time > 0 else 0
                    st.write(f"**{step_name}**: {step_time:.2f}s ({percentage:.1f}%)")

                st.metric("**Total Time**", f"{total_time:.2f}s")
                st.divider()

            st.subheader("🔍 Debug Information")
            col1, col2, col3 = st.columns(3)
            with col1:
                model_config = st.session_state.get("model_config")
                if model_config:
                    tokenizer_name = model_config["tokenizer"]
                    token_counter = TokenCounter(model=tokenizer_name)

                    if results["paper_text"]:
                        paper_tokens = token_counter(results["paper_text"])
                        st.metric("Paper Tokens", f"{paper_tokens:,}")
                    if results["code_prompt"]:
                        code_tokens = token_counter(results["code_prompt"])
                        st.metric("Code Tokens", f"{code_tokens:,}")
            with col2:
                if results["llm_response"]:
                    usage = results["llm_response"].get("usage", {})
                    if usage:
                        input_tokens = usage.get("prompt_tokens", "N/A")
                        output_tokens = usage.get("completion_tokens", "N/A")
                        st.metric("Input Tokens", f"{input_tokens:,}" if input_tokens != "N/A" else "N/A")
                        st.metric("Output Tokens", f"{output_tokens:,}" if output_tokens != "N/A" else "N/A")
            with col3:
                if results["llm_response"]:
                    usage = results["llm_response"].get("usage", {})
                    if usage:
                        total_tokens = usage.get("total_tokens", "N/A")
                        st.metric("Total Tokens", f"{total_tokens:,}" if total_tokens != "N/A" else "N/A")

                    cost = results["llm_response"].get("metadata", {}).get("cost", 0.0)
                    if cost > 0:
                        st.metric("Cost", f"${cost:.4f}")
                    else:
                        st.metric("Cost", "Free")


if __name__ == "__main__":
    main()