Spaces:
Runtime error
Runtime error
| import argparse | |
| import asyncio | |
| import hashlib | |
| import json | |
| import logging | |
| import os | |
| import re | |
| import sys | |
| import time | |
| from typing import Any, TYPE_CHECKING | |
| def ensure_project_venv() -> None: | |
| project_root = os.path.dirname(os.path.abspath(__file__)) | |
| venv_root = os.path.join(project_root, ".venv") | |
| venv_python = os.path.join(project_root, ".venv", "bin", "python") | |
| if not os.path.exists(venv_python): | |
| return | |
| current_prefix = os.path.realpath(sys.prefix) | |
| expected_prefix = os.path.realpath(venv_root) | |
| if current_prefix != expected_prefix: | |
| os.execv(venv_python, [venv_python, *sys.argv]) | |
| ensure_project_venv() | |
| import numpy as np | |
| from datasets import load_dataset | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| def configure_logging() -> None: | |
| """Configure app and dependency logging level. | |
| Defaults to WARNING to keep CLI output concise. Override with | |
| ASKCHOMSKY_LOG_LEVEL (e.g., INFO, DEBUG) when troubleshooting. | |
| """ | |
| level_name = os.getenv("ASKCHOMSKY_LOG_LEVEL", "WARNING").upper() | |
| level = getattr(logging, level_name, logging.WARNING) | |
| # Keep root and noisy dependencies aligned with the selected verbosity. | |
| logging.getLogger().setLevel(level) | |
| for logger_name in ( | |
| "lightrag", | |
| "nano-vectordb", | |
| "sentence_transformers", | |
| "sentence_transformers.SentenceTransformer", | |
| "httpx", | |
| "openai", | |
| ): | |
| logging.getLogger(logger_name).setLevel(level) | |
| configure_logging() | |
| if TYPE_CHECKING: | |
| # These imports are heavy (transitively pull in torch, CUDA, etc.). | |
| # Import them only for type checking; at runtime we import lazily. | |
| from lightrag import LightRAG, QueryParam | |
| from lightrag.llm.openai import openai_complete_if_cache | |
| from lightrag.utils import EmbeddingFunc | |
| # LightRAG configures its own logger during import, so apply our level again | |
| # once we actually import it lazily at runtime (see initialize_rag). | |
| configure_logging() | |
| OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1" | |
| LLM_MODEL = os.getenv("ASKCHOMSKY_LLM_MODEL", "openai/gpt-4o-mini") | |
| EMBED_MODEL = os.getenv("ASKCHOMSKY_EMBED_MODEL", "openai/text-embedding-3-small") | |
| EMBED_DIM = 1536 | |
| DEFAULT_WORKING_DIR = "./lightrag_store" | |
| LLM_TIMEOUT_SECONDS = int(os.getenv("LLM_TIMEOUT", "600")) | |
| MAX_ASYNC_LLM_CALLS = int(os.getenv("MAX_ASYNC", "2")) | |
| MAX_PARALLEL_INSERT = int(os.getenv("MAX_PARALLEL_INSERT", "2")) | |
| REWRITE_QUERY_ENABLED = os.getenv("REWRITE_QUERY", "true").lower() == "true" | |
| VERIFY_CLAIMS_ENABLED = os.getenv("VERIFY_CLAIMS", "true").lower() == "true" | |
| QUERY_CACHE_TTL_SECONDS = int(os.getenv("QUERY_CACHE_TTL", "86400")) | |
| QUERY_CACHE_PATH = os.path.join(DEFAULT_WORKING_DIR, "query_cache.json") | |
| CITATION_SYSTEM_PROMPT = """You are a retrieval-grounded assistant. | |
| Use only the provided context data. | |
| If context is insufficient, say: I do not have enough information to answer from the retrieved corpus. | |
| Citation rules: | |
| 1) Every factual claim must include at least one citation marker like [1]. | |
| 2) Do not invent citation IDs. | |
| 3) Keep citation IDs consistent with the provided references. | |
| Response style: {response_type} | |
| User preference: {user_prompt} | |
| Context: | |
| {context_data} | |
| """ | |
| def get_langfuse_client(): | |
| """Return a configured Langfuse client or None if unavailable/invalid.""" | |
| public_key = os.getenv("LANGFUSE_PUBLIC_KEY", "").strip().strip('"').strip("'") | |
| secret_key = os.getenv("LANGFUSE_SECRET_KEY", "").strip().strip('"').strip("'") | |
| base_url = ( | |
| os.getenv("LANGFUSE_BASE_URL", "").strip().strip('"').strip("'") | |
| or os.getenv("LANGFUSE_HOST", "").strip().strip('"').strip("'") | |
| or "https://cloud.langfuse.com" | |
| ) | |
| if ( | |
| not public_key | |
| or public_key.startswith("pk-lf-...") | |
| or not secret_key | |
| or secret_key.startswith("sk-lf-...") | |
| ): | |
| return None | |
| try: | |
| from langfuse import Langfuse | |
| client = Langfuse( | |
| public_key=public_key, | |
| secret_key=secret_key, | |
| base_url=base_url, | |
| debug=os.getenv("LANGFUSE_DEBUG", "false").lower() == "true", | |
| ) | |
| if not client.auth_check(): | |
| print("Langfuse auth check failed. Verify keys and LANGFUSE_BASE_URL.") | |
| return None | |
| return client | |
| except Exception as exc: | |
| print(f"Langfuse disabled: {exc}") | |
| return None | |
| def configure_langfuse() -> bool: | |
| """Backward-compatible bool helper used by older call sites.""" | |
| return get_langfuse_client() is not None | |
| # --------------------------------------------------------------------------- | |
| # API-based embeddings (OpenRouter / OpenAI-compatible) | |
| # --------------------------------------------------------------------------- | |
| def _get_api_key() -> str: | |
| api_key = os.getenv("openrouter_key") or os.getenv("OPENAI_API_KEY", "") | |
| if not api_key: | |
| raise ValueError("Missing openrouter_key or OPENAI_API_KEY in .env") | |
| return api_key | |
| def _api_embed_single(text: str) -> list[float]: | |
| import httpx | |
| api_key = _get_api_key() | |
| payload = {"input": text, "model": EMBED_MODEL} | |
| headers = { | |
| "Authorization": f"Bearer {api_key}", | |
| "Content-Type": "application/json", | |
| } | |
| with httpx.Client(timeout=30.0) as client: | |
| resp = client.post( | |
| OPENROUTER_BASE_URL + "/embeddings", json=payload, headers=headers | |
| ) | |
| resp.raise_for_status() | |
| data = resp.json() | |
| return data["data"][0]["embedding"] | |
| def embed_texts(texts: list[str]) -> np.ndarray: | |
| embeddings = [_api_embed_single(t) for t in texts] | |
| arr = np.array(embeddings, dtype=np.float32) | |
| norms = np.linalg.norm(arr, axis=1, keepdims=True) | |
| norms[norms == 0] = 1.0 | |
| return arr / norms | |
| async def embedding_func(texts: list[str]) -> np.ndarray: | |
| return await asyncio.to_thread(embed_texts, texts) | |
| # --------------------------------------------------------------------------- | |
| # Query result cache (disk-based, TTL-evicted) | |
| # --------------------------------------------------------------------------- | |
| def _load_query_cache() -> dict[str, dict[str, Any]]: | |
| if not os.path.exists(QUERY_CACHE_PATH): | |
| return {} | |
| try: | |
| with open(QUERY_CACHE_PATH, "r") as f: | |
| return json.load(f) | |
| except (json.JSONDecodeError, OSError): | |
| return {} | |
| def _save_query_cache(cache: dict[str, dict[str, Any]]) -> None: | |
| os.makedirs(os.path.dirname(QUERY_CACHE_PATH), exist_ok=True) | |
| with open(QUERY_CACHE_PATH, "w") as f: | |
| json.dump(cache, f) | |
| def _cache_key(question: str, mode: str) -> str: | |
| raw = f"{question.strip().lower()}|{mode}" | |
| return hashlib.sha256(raw.encode()).hexdigest() | |
| def get_cached_answer(question: str, mode: str) -> str | None: | |
| if QUERY_CACHE_TTL_SECONDS <= 0: | |
| return None | |
| key = _cache_key(question, mode) | |
| cache = _load_query_cache() | |
| entry = cache.get(key) | |
| if not entry: | |
| return None | |
| if time.time() - entry.get("ts", 0) > QUERY_CACHE_TTL_SECONDS: | |
| del cache[key] | |
| _save_query_cache(cache) | |
| return None | |
| return entry.get("answer") | |
| def cache_answer(question: str, mode: str, answer: str) -> None: | |
| if QUERY_CACHE_TTL_SECONDS <= 0: | |
| return | |
| key = _cache_key(question, mode) | |
| cache = _load_query_cache() | |
| cache[key] = {"answer": answer, "ts": time.time()} | |
| _save_query_cache(cache) | |
| async def llm_model_func( | |
| prompt, | |
| system_prompt=None, | |
| history_messages=None, | |
| keyword_extraction=False, | |
| **kwargs, | |
| ) -> str: | |
| # Import here to avoid pulling in heavy dependencies during module import. | |
| from lightrag.llm.openai import openai_complete_if_cache | |
| api_key = os.getenv("openrouter_key") | |
| if not api_key: | |
| raise ValueError("Missing openrouter_key in .env") | |
| if history_messages is None: | |
| history_messages = [] | |
| return await openai_complete_if_cache( | |
| model=LLM_MODEL, | |
| prompt=prompt, | |
| system_prompt=system_prompt, | |
| history_messages=history_messages, | |
| api_key=api_key, | |
| base_url=OPENROUTER_BASE_URL, | |
| timeout=LLM_TIMEOUT_SECONDS, | |
| keyword_extraction=keyword_extraction, | |
| **kwargs, | |
| ) | |
| async def initialize_rag(working_dir: str = DEFAULT_WORKING_DIR) -> "LightRAG": | |
| # Lazy imports keep startup fast and avoid loading the full | |
| # LightRAG/torch stack until we actually need RAG functionality. | |
| from lightrag import LightRAG | |
| from lightrag.utils import EmbeddingFunc | |
| os.makedirs(working_dir, exist_ok=True) | |
| rag = LightRAG( | |
| working_dir=working_dir, | |
| llm_model_func=llm_model_func, | |
| llm_model_name=LLM_MODEL, | |
| default_llm_timeout=LLM_TIMEOUT_SECONDS, | |
| llm_model_max_async=MAX_ASYNC_LLM_CALLS, | |
| max_parallel_insert=MAX_PARALLEL_INSERT, | |
| embedding_func=EmbeddingFunc( | |
| embedding_dim=EMBED_DIM, | |
| max_token_size=8192, | |
| model_name=EMBED_MODEL, | |
| func=embedding_func, | |
| ), | |
| ) | |
| await rag.initialize_storages() | |
| return rag | |
| def load_corpus_texts(limit: int) -> list[str]: | |
| ds = load_dataset("mmoise00/chomsky-corpus", split="train") | |
| count = min(limit, len(ds)) | |
| texts = [] | |
| for row in ds.select(range(count)): | |
| title = row.get("article_title") or "Untitled" | |
| date = row.get("article_date") or "" | |
| content = row.get("content") or "" | |
| texts.append(f"Title: {title}\nDate: {date}\n\n{content}") | |
| return texts | |
| async def ingest_corpus( | |
| doc_limit: int = 200, working_dir: str = DEFAULT_WORKING_DIR | |
| ) -> int: | |
| rag = None | |
| try: | |
| rag = await initialize_rag(working_dir) | |
| docs = load_corpus_texts(doc_limit) | |
| await rag.ainsert(docs) | |
| return len(docs) | |
| finally: | |
| if rag is not None: | |
| await rag.finalize_storages() | |
| async def query_rag( | |
| question: str, | |
| mode: str = "hybrid", | |
| working_dir: str = DEFAULT_WORKING_DIR, | |
| ) -> str: | |
| def _looks_like_no_answer(answer: str) -> bool: | |
| text = answer.lower() | |
| return ( | |
| "[no-context]" in text | |
| or "i do not have enough information to answer" in text | |
| or "sorry, i'm not able to provide an answer" in text | |
| ) | |
| def _response_to_text(response: object) -> str: | |
| if isinstance(response, str): | |
| return response | |
| content = getattr(response, "content", None) | |
| if content is not None: | |
| return str(content) | |
| return str(response) | |
| def _has_citation_marker(text: str) -> bool: | |
| return bool(re.search(r"\[\d+\]", text)) | |
| def _extract_json_object(text: str) -> dict[str, Any] | None: | |
| match = re.search(r"\{[\s\S]*\}", text) | |
| if not match: | |
| return None | |
| try: | |
| return json.loads(match.group(0)) | |
| except json.JSONDecodeError: | |
| return None | |
| def _extract_references(raw_result: dict[str, Any]) -> list[dict[str, str]]: | |
| data = raw_result.get("data", {}) | |
| references = data.get("references", []) | |
| if isinstance(references, list): | |
| return [r for r in references if isinstance(r, dict)] | |
| return [] | |
| def _extract_chunks(raw_result: dict[str, Any]) -> list[dict[str, Any]]: | |
| data = raw_result.get("data", {}) | |
| chunks = data.get("chunks", []) | |
| if isinstance(chunks, list): | |
| return [c for c in chunks if isinstance(c, dict)] | |
| return [] | |
| def _extract_llm_text(raw_result: dict[str, Any]) -> str: | |
| llm_response = raw_result.get("llm_response", {}) | |
| content = llm_response.get("content") | |
| if content is None: | |
| return "" | |
| return str(content) | |
| def _render_references(references: list[dict[str, str]]) -> str: | |
| if not references: | |
| return "" | |
| lines: list[str] = ["Sources:"] | |
| for ref in references: | |
| ref_id = str(ref.get("reference_id", "")).strip() | |
| file_path = str(ref.get("file_path", "")).strip() or "unknown" | |
| if ref_id: | |
| lines.append(f"[{ref_id}] {file_path}") | |
| return "\n".join(lines) | |
| def _enforce_citation_answer(answer: str, references: list[dict[str, str]]) -> str: | |
| if not references: | |
| return answer | |
| rendered_references = _render_references(references) | |
| safe_answer = answer.strip() | |
| if not _has_citation_marker(safe_answer): | |
| first_ref = str(references[0].get("reference_id", "1")).strip() or "1" | |
| safe_answer = f"{safe_answer}\n\nPrimary support [{first_ref}]." | |
| if rendered_references and "Sources:" not in safe_answer: | |
| safe_answer = f"{safe_answer}\n\n{rendered_references}" | |
| return safe_answer | |
| async def _rewrite_query_for_retrieval(original_question: str) -> str: | |
| if not REWRITE_QUERY_ENABLED: | |
| return original_question | |
| rewrite_prompt = ( | |
| "Rewrite this question for retrieval over a Noam Chomsky corpus. " | |
| "Preserve intent and named entities. Return one line only, no extra text.\n\n" | |
| f"Question: {original_question}" | |
| ) | |
| try: | |
| rewritten = await llm_model_func( | |
| rewrite_prompt, | |
| system_prompt="You are a retrieval query rewriter.", | |
| history_messages=[], | |
| ) | |
| candidate = _response_to_text(rewritten).strip().splitlines()[0].strip() | |
| if not candidate: | |
| return original_question | |
| return candidate[:600] | |
| except Exception: | |
| return original_question | |
| def _dynamic_query_param( | |
| selected_mode: str, | |
| original_question: str, | |
| rewritten_question: str, | |
| *, | |
| retry_level: int = 0, | |
| ) -> "QueryParam": | |
| base_top_k = int(os.getenv("TOP_K", "40")) | |
| base_chunk_top_k = int(os.getenv("CHUNK_TOP_K", "20")) | |
| text = f"{original_question} {rewritten_question}".lower() | |
| token_count = len(re.findall(r"\w+", rewritten_question)) | |
| top_k = base_top_k | |
| chunk_top_k = base_chunk_top_k | |
| if token_count > 18: | |
| top_k += 15 | |
| chunk_top_k += 15 | |
| if any(k in text for k in ("compare", "versus", "difference", "contrast")): | |
| top_k += 20 | |
| chunk_top_k += 20 | |
| if any(k in text for k in ("timeline", "history", "evolution", "over time")): | |
| top_k += 20 | |
| chunk_top_k += 20 | |
| if any(k in text for k in ("why", "how", "explain", "analyze")): | |
| top_k += 10 | |
| chunk_top_k += 10 | |
| if retry_level == 1: | |
| top_k = max(top_k, 80) | |
| chunk_top_k = max(chunk_top_k, 80) | |
| elif retry_level >= 2: | |
| top_k = max(top_k, 100) | |
| chunk_top_k = max(chunk_top_k, 100) | |
| rerank_default = os.getenv("RERANK_BY_DEFAULT", "false").lower() == "true" | |
| enable_rerank = rerank_default and retry_level == 0 | |
| return QueryParam( | |
| mode=selected_mode, | |
| top_k=top_k, | |
| chunk_top_k=chunk_top_k, | |
| enable_rerank=enable_rerank, | |
| include_references=True, | |
| response_type="Multiple Paragraphs", | |
| ) | |
| async def _verify_claims( | |
| answer_text: str, | |
| chunks: list[dict[str, Any]], | |
| ) -> str: | |
| if not VERIFY_CLAIMS_ENABLED or not answer_text.strip() or not chunks: | |
| return "" | |
| evidence_lines: list[str] = [] | |
| for chunk in chunks[:8]: | |
| ref_id = str(chunk.get("reference_id", "?")).strip() or "?" | |
| content = str(chunk.get("content", "")).strip().replace("\n", " ") | |
| if content: | |
| evidence_lines.append(f"[{ref_id}] {content[:700]}") | |
| if not evidence_lines: | |
| return "" | |
| verifier_prompt = ( | |
| "Verify claims in the answer using ONLY the provided evidence snippets. " | |
| "Return strict JSON with keys: verdict, unsupported_claims, notes. " | |
| "verdict must be one of supported|partially_supported|unsupported.\n\n" | |
| f"Answer:\n{answer_text}\n\n" | |
| f"Evidence:\n{os.linesep.join(evidence_lines)}" | |
| ) | |
| try: | |
| verifier_response = await llm_model_func( | |
| verifier_prompt, | |
| system_prompt="You are a strict factual verifier.", | |
| history_messages=[], | |
| ) | |
| verifier_text = _response_to_text(verifier_response) | |
| verifier_json = _extract_json_object(verifier_text) | |
| if not verifier_json: | |
| return "" | |
| verdict = str(verifier_json.get("verdict", "")).strip().lower() | |
| unsupported_claims = verifier_json.get("unsupported_claims", []) | |
| if verdict in {"supported", ""} or not isinstance(unsupported_claims, list): | |
| return "" | |
| cleaned_claims = [ | |
| str(c).strip() for c in unsupported_claims if str(c).strip() | |
| ][:5] | |
| if not cleaned_claims: | |
| return "" | |
| joined = "\n".join(f"- {claim}" for claim in cleaned_claims) | |
| return ( | |
| "\n\nClaim verification: some claims could not be fully supported by retrieved evidence." | |
| f"\n{joined}" | |
| ) | |
| except Exception: | |
| return "" | |
| cached = get_cached_answer(question, mode) | |
| if cached is not None: | |
| return cached | |
| rag = None | |
| try: | |
| rag = await initialize_rag(working_dir) | |
| rewritten_question = await _rewrite_query_for_retrieval(question) | |
| selected_result: dict[str, Any] | None = None | |
| attempt_modes = [mode, mode, "mix"] if mode != "mix" else ["mix", "mix"] | |
| for retry_level, attempt_mode in enumerate(attempt_modes): | |
| param = _dynamic_query_param( | |
| attempt_mode, | |
| question, | |
| rewritten_question, | |
| retry_level=retry_level, | |
| ) | |
| result = await rag.aquery_llm( | |
| rewritten_question, | |
| param=param, | |
| system_prompt=CITATION_SYSTEM_PROMPT, | |
| ) | |
| answer_text = _extract_llm_text(result) | |
| selected_result = result | |
| if answer_text and not _looks_like_no_answer(answer_text): | |
| break | |
| if selected_result is None: | |
| return ( | |
| "I do not have enough information to answer from the retrieved corpus." | |
| ) | |
| answer_text = _extract_llm_text(selected_result) | |
| references = _extract_references(selected_result) | |
| chunks = _extract_chunks(selected_result) | |
| answer_with_citations = _enforce_citation_answer(answer_text, references) | |
| verification_summary = await _verify_claims(answer_with_citations, chunks) | |
| final_answer = f"{answer_with_citations}{verification_summary}".strip() | |
| cache_answer(question, mode, final_answer) | |
| return final_answer | |
| finally: | |
| if rag is not None: | |
| await rag.finalize_storages() | |
| def parse_args() -> argparse.Namespace: | |
| parser = argparse.ArgumentParser(description="LightRAG over the Chomsky corpus") | |
| parser.add_argument( | |
| "--ingest", action="store_true", help="Index dataset into LightRAG" | |
| ) | |
| parser.add_argument("--query", type=str, help="Question to ask") | |
| parser.add_argument( | |
| "--mode", | |
| type=str, | |
| default="hybrid", | |
| choices=["naive", "local", "global", "hybrid", "mix"], | |
| help="LightRAG query mode", | |
| ) | |
| parser.add_argument( | |
| "--doc-limit", type=int, default=200, help="How many docs to index" | |
| ) | |
| parser.add_argument( | |
| "--working-dir", | |
| type=str, | |
| default=DEFAULT_WORKING_DIR, | |
| help="Directory where LightRAG stores graph/vector state", | |
| ) | |
| return parser.parse_args() | |
| async def run_cli(args: argparse.Namespace) -> None: | |
| if args.ingest: | |
| count = await ingest_corpus( | |
| doc_limit=args.doc_limit, working_dir=args.working_dir | |
| ) | |
| print(f"Indexed {count} documents into LightRAG store: {args.working_dir}") | |
| if args.query: | |
| answer = await query_rag( | |
| args.query, mode=args.mode, working_dir=args.working_dir | |
| ) | |
| print(f"\nQ: {args.query}") | |
| print(f"\nA: {answer}") | |
| if not args.ingest and not args.query: | |
| print("Nothing to do. Use --ingest and/or --query.") | |
| if __name__ == "__main__": | |
| asyncio.run(run_cli(parse_args())) | |