Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import math | |
| import re | |
| from collections import Counter | |
| from dataclasses import dataclass | |
| from tqdm import tqdm | |
| from .config import Settings | |
| from .ollama_client import OllamaClient | |
| from .pdf_processing import ChunkRecord, discover_pdfs, extract_chunks_from_pdf | |
| from .vector_store import ChromaMathStore, RetrievedChunk | |
# Detects plain-text fractions such as `y/x`, `dy/dx` or `3 / 4`: a run of
# letters or a run of digits on each side of a slash, with optional spaces.
PLAIN_FRACTION_RE = re.compile(
    r"\b(?:[A-Za-z]+|\d+)\s*/\s*(?:[A-Za-z]+|\d+)\b"
)
@dataclass
class SupplementalChunk:
    """An extra (non-store) chunk paired with its embedding vector.

    The class is built with keyword arguments
    (``SupplementalChunk(record=..., embedding=...)``), so it must be a
    dataclass: without the decorator the bare annotations generate no
    ``__init__`` and construction raises ``TypeError``.
    """

    # The extracted chunk (text plus chapter/topic/page metadata).
    record: ChunkRecord
    # Embedding vector computed for ``record.text``.
    embedding: list[float]
def build_messages(question: str, chunks: list[RetrievedChunk]) -> list[dict[str, str]]:
    """Assemble the system and user chat messages for answering a question.

    Args:
        question: The question text to be answered.
        chunks: Retrieved chunks rendered into the user message as numbered
            sources (chapter, topic, page and text).

    Returns:
        A two-element list: the system message (solver rules plus the required
        output layout, pre-filled with the most common chapter and topic among
        ``chunks``) followed by the user message (question + rendered sources).
    """
    best_chapter = infer_label([item.chapter_name for item in chunks])
    best_topic = infer_label([item.topic for item in chunks])

    # One text block per retrieved chunk, numbered from 1.
    source_blocks = [
        f"[Source {position}]\n"
        f"Chapter: {item.chapter_name}\n"
        f"Topic: {item.topic}\n"
        f"Page: {item.page_number}\n"
        f"Text:\n{item.text}"
        for position, item in enumerate(chunks, start=1)
    ]
    context_text = "\n\n".join(source_blocks)

    system_prompt = (
        "You are an advanced AI Mathematics Tutor and Solver specialized in accurate, verified, step-by-step mathematical reasoning.\n"
        "Your PRIMARY GOAL is MATHEMATICAL CORRECTNESS — not sounding academic or verbose.\n\n"
        "Use only the retrieved context to answer the question. If the context is not enough, say that clearly instead of guessing.\n\n"
        "━━━━━━━━━━━━━━━━━━\n"
        "CORE RULES\n"
        "━━━━━━━━━━━━━━━━━━\n"
        "1. NEVER hallucinate formulas, theorems, chapter names, topic numbers, labels, or references.\n"
        "2. NEVER invent: Formula IDs, “P4”, “Theorem 2”, fake verification statements, or fake derivations.\n"
        "3. ONLY use formulas that are mathematically relevant to the problem.\n"
        "4. EVERY equation must logically follow from the previous equation.\n"
        "5. NEVER generate meaningless expressions such as: nested unnecessary integrals, invalid substitutions, or symbolic nonsense.\n"
        "6. If the mathematical expression is unclear, corrupted, or OCR extraction seems wrong, respond with: 'The mathematical expression is unclear. Please upload a clearer image or rewrite the question.'\n"
        "7. NEVER pretend verification succeeded unless it is actually correct.\n"
        "8. Accuracy is MORE important than speed or style.\n\n"
        "━━━━━━━━━━━━━━━━━━\n"
        "OCR / IMAGE UNDERSTANDING RULES\n"
        "━━━━━━━━━━━━━━━━━━\n"
        "When solving questions extracted from images:\n"
        "1. Carefully identify: powers (², ³), limits of integration, fractions, derivatives, Greek symbols like π, and superscripts/subscripts.\n"
        "2. Distinguish carefully between: sin²x and sin'x, 0 and O, n and π, 1 and l, x² and x.\n"
        "3. Before solving, internally verify that the interpreted equation is mathematically sensible.\n"
        "4. If OCR output appears inconsistent or nonsensical: DO NOT SOLVE. Instead ask for a clearer image.\n\n"
        "━━━━━━━━━━━━━━━━━━\n"
        "MATHEMATICAL SOLVING RULES\n"
        "━━━━━━━━━━━━━━━━━━\n"
        "For EVERY problem:\n"
        "1. Identify the problem type first.\n"
        "2. Choose ONLY the correct mathematical method.\n"
        "3. Show COMPLETE step-by-step derivation.\n"
        "4. Clearly show: Given expression, Formula used, Substitution (if any), Simplification, Computation, Final Answer.\n"
        "5. Use proper LaTeX formatting. Use `$$ ... $$` for standalone equations and `$ ... $` for inline maths.\n"
        "6. Final answers MUST always be enclosed in `\\boxed{...}`.\n\n"
        "━━━━━━━━━━━━━━━━━━\n"
        "INTEGRATION & DEFINITE INTEGRAL RULES\n"
        "━━━━━━━━━━━━━━━━━━\n"
        "1. Check carefully for: substitution, integration by parts, trigonometric identities, partial fractions.\n"
        "2. NEVER confuse: sin²x with sin'(x) or derivative notation with exponents.\n"
        "3. For logarithmic integrals, use correct formulas.\n"
        "4. Verify signs (+/-) carefully.\n"
        "5. Correctly apply limits for definite integrals.\n"
        "6. Check convergence/divergence for improper integrals.\n"
        "7. Never add arbitrary constants to definite integrals.\n"
        "8. DIFFERENTIATE the final answer to verify correctness.\n\n"
        "━━━━━━━━━━━━━━━━━━\n"
        "VERIFICATION LAYER\n"
        "━━━━━━━━━━━━━━━━━━\n"
        "Before giving the final response:\n"
        "1. Recheck every algebraic manipulation, substitutions, and trigonometric identities.\n"
        "2. Recompute the final answer mentally/symbolically.\n"
        "3. Verify the final result: differentiate antiderivatives, substitute solutions back, check arithmetic consistency.\n"
        "4. If any mistake is found: AUTOMATICALLY CORRECT the full solution before responding.\n\n"
        "━━━━━━━━━━━━━━━━━━\n"
        "STRICTLY FORBIDDEN\n"
        "━━━━━━━━━━━━━━━━━━\n"
        "DO NOT:\n"
        "- hallucinate steps\n"
        "- fabricate verification\n"
        "- invent formulas\n"
        "- generate symbolic nonsense\n"
        "- output contradictory equations\n"
        "- force a solution when OCR is unclear\n"
        "- use irrelevant methods\n"
        "- claim confidence without verification\n\n"
        "━━━━━━━━━━━━━━━━━━\n"
        "OUTPUT FORMAT\n"
        "━━━━━━━━━━━━━━━━━━\n"
        "Use EXACTLY this structure:\n\n"
        f"Chapter: {best_chapter}\n"
        f"Topic: {best_topic}\n\n"
        "## Given\n\n"
        "## Formula Used\n\n"
        "## Step-by-Step Solution\n\n"
        "## Verification\n\n"
        "## Final Answer\n\n"
        "━━━━━━━━━━━━━━━━━━\n"
        "FINAL PRIORITY\n"
        "━━━━━━━━━━━━━━━━━━\n"
        "Correctness > Clarity > Formatting > Speed\n\n"
        "If uncertain: re-solve carefully, verify again, do NOT guess."
    )
    user_prompt = f"Question:\n{question}\n\nRetrieved context:\n{context_text}"

    system_message = {"role": "system", "content": system_prompt}
    user_message = {"role": "user", "content": user_prompt}
    return [system_message, user_message]
def build_reformat_messages(
    question: str,
    answer: str,
    chapter: str,
    topic: str,
) -> list[dict[str, str]]:
    """Build the chat messages that ask the model to tidy an existing answer.

    Args:
        question: The original question, included for reference.
        answer: The answer text that should be rewritten as clean Markdown.
        chapter: Chapter label placed in the required output layout.
        topic: Topic label placed in the required output layout.

    Returns:
        A two-element list: the system message with the clean-up rules and
        required section layout, then the user message carrying the question
        and the answer to clean up.
    """
    cleanup_rules = (
        "You are an advanced AI Mathematics Tutor cleaning up a Class 12 Mathematics answer.\n"
        "Rewrite the answer into clean Markdown without changing the maths or adding new facts.\n"
        "Convert all mathematics into proper LaTeX.\n"
        "Use `$$ ... $$` for standalone equations and `$ ... $` for short inline maths.\n"
        "Do not leave plain-text fractions like `y/x` or `dy/dx` if they should be proper fractions.\n"
        "Final answers must always be enclosed in a box using `\\boxed{}`.\n"
        "Use EXACTLY this structure:\n"
        f"Chapter: {chapter}\n"
        f"Topic: {topic}\n\n"
        "## Given\n"
        "## Formula Used\n"
        "## Step-by-Step Solution\n"
        "## Verification\n"
        "## Final Answer\n"
        "Put important equations on separate lines and do not use tables.\n"
    )
    cleanup_request = f"Original question:\n{question}\n\nAnswer to clean up:\n{answer}"

    messages: list[dict[str, str]] = []
    messages.append({"role": "system", "content": cleanup_rules})
    messages.append({"role": "user", "content": cleanup_request})
    return messages
def needs_reformatting(answer: str) -> bool:
    """Decide whether a model answer should be sent through the clean-up pass.

    The answer needs reformatting when any of these hold:
      * it lacks the ``## Final Answer`` heading, or lacks all of the other
        section headings checked below;
      * it contains a plain-text fraction and no ``\\frac{`` at all;
      * it contains ``\\frac{`` but no display-math delimiter (``$$`` or ``\\[``).

    Args:
        answer: The raw answer text returned by the model.

    Returns:
        ``True`` if the answer should be reformatted, ``False`` otherwise.
    """
    section_markers = ("## Given", "## Step-by-Step Solution", "## Formula Used")

    # Structural checks first: a missing skeleton always triggers a rewrite.
    if "## Final Answer" not in answer:
        return True
    if not any(marker in answer for marker in section_markers):
        return True

    # The two maths checks are mutually exclusive on the presence of \frac{.
    if "\\frac{" not in answer:
        return PLAIN_FRACTION_RE.search(answer) is not None
    return "$$" not in answer and "\\[" not in answer
def infer_label(values: list[str]) -> str:
    """Return the most frequent value, or a placeholder for an empty list.

    Args:
        values: Candidate labels (for example chapter names or topics).

    Returns:
        The value with the highest count as reported by
        ``Counter.most_common(1)``, or ``"Not identified"`` when ``values``
        is empty.
    """
    if values:
        ((label, _occurrences),) = Counter(values).most_common(1)
        return label
    return "Not identified"
def batched(items: list[ChunkRecord], size: int) -> list[list[ChunkRecord]]:
    """Split ``items`` into consecutive slices of at most ``size`` elements.

    Args:
        items: The list to split.
        size: Maximum number of elements per slice; used as the ``range`` step.

    Returns:
        The slices in order; the last one may be shorter. An empty input
        yields an empty list.
    """
    groups: list[list[ChunkRecord]] = []
    for start in range(0, len(items), size):
        stop = start + size
        groups.append(items[start:stop])
    return groups
def cosine_similarity(left: list[float], right: list[float]) -> float:
    """Compute the cosine similarity of two vectors.

    Args:
        left: First vector.
        right: Second vector. The dot product pairs elements with ``zip``,
            so it only covers the shorter of the two lengths.

    Returns:
        The dot product divided by the product of the Euclidean norms, or
        ``0.0`` when either norm is zero (including empty vectors).
    """
    left_norm = math.sqrt(sum(component * component for component in left))
    right_norm = math.sqrt(sum(component * component for component in right))

    # A zero-length vector has no direction; report no similarity instead of
    # dividing by zero.
    if left_norm == 0 or right_norm == 0:
        return 0.0

    dot_product = sum(a * b for a, b in zip(left, right))
    return dot_product / (left_norm * right_norm)
def build_supplemental_chunks(
    chunks: list[ChunkRecord],
    *,
    ollama: OllamaClient,
    model: str,
) -> list[SupplementalChunk]:
    """Embed chunk texts in batches and pair each chunk with its embedding.

    Args:
        chunks: Chunk records whose ``text`` is embedded.
        ollama: Client whose ``embed_texts`` produces the embeddings.
        model: Embedding model name passed to ``embed_texts``.

    Returns:
        One ``SupplementalChunk`` per (chunk, embedding) pair, in input order.
    """
    # Texts are sent 16 at a time; results are flattened back into one list.
    vectors: list[list[float]] = [
        vector
        for group in batched(chunks, 16)
        for vector in ollama.embed_texts(
            [record.text for record in group],
            model=model,
        )
    ]

    supplemental: list[SupplementalChunk] = []
    for record, vector in zip(chunks, vectors):
        supplemental.append(SupplementalChunk(record=record, embedding=vector))
    return supplemental
def rank_supplemental_chunks(
    *,
    query_embedding: list[float],
    extra_chunks: list[SupplementalChunk],
    top_k: int,
) -> list[RetrievedChunk]:
    """Rank supplemental chunks by cosine similarity to the query.

    Args:
        query_embedding: Embedding vector of the question.
        extra_chunks: Supplemental chunks carrying their own embeddings.
        top_k: Slice bound applied to the ranked list (``ranked[:top_k]``).

    Returns:
        ``RetrievedChunk`` objects for the best-scoring chunks, most similar
        first. ``distance`` is ``1 - similarity`` clamped at ``0.0``.
    """
    # Score every chunk exactly once and reuse the score for both ordering
    # and the reported distance, instead of recomputing it per result.
    scored = [
        (cosine_similarity(query_embedding, chunk.embedding), chunk)
        for chunk in extra_chunks
    ]
    # Sort on the score only so chunk objects are never compared. The sort is
    # stable (also with reverse=True), so ties keep their input order.
    scored.sort(key=lambda pair: pair[0], reverse=True)

    return [
        RetrievedChunk(
            chunk_id=chunk.record.chunk_id,
            text=chunk.record.text,
            chapter_number=chunk.record.chapter_number,
            chapter_name=chunk.record.chapter_name,
            topic=chunk.record.topic,
            page_number=chunk.record.page_number,
            source_file=chunk.record.source_file,
            distance=max(0.0, 1.0 - similarity),
        )
        for similarity, chunk in scored[:top_k]
    ]
def merge_retrieved_chunks(
    store_chunks: list[RetrievedChunk],
    extra_chunks: list[RetrievedChunk],
    *,
    top_k: int,
) -> list[RetrievedChunk]:
    """Combine store results and supplemental results into one list.

    When both sources have results, the store gets roughly the first half of
    the ``top_k`` slots (rounded up, at least one) and the supplemental
    chunks get the rest (at least one). Slots a source cannot fill are
    topped up from the other source's leftovers, store leftovers first.

    Args:
        store_chunks: Results from the vector store, best first.
        extra_chunks: Results from supplemental chunks, best first.
        top_k: Slice bound applied to the merged list.

    Returns:
        Store picks followed by supplemental picks and any top-up leftovers,
        truncated with ``[:top_k]``.
    """
    # With only one source available there is nothing to balance.
    if not store_chunks or not extra_chunks:
        return (store_chunks or extra_chunks)[:top_k]

    store_quota = min(len(store_chunks), max(1, (top_k + 1) // 2))
    extra_quota = min(len(extra_chunks), max(1, top_k - store_quota))

    combined = [*store_chunks[:store_quota], *extra_chunks[:extra_quota]]
    shortfall = top_k - len(combined)
    if shortfall > 0:
        spare = [*store_chunks[store_quota:], *extra_chunks[extra_quota:]]
        combined.extend(spare[:shortfall])
    return combined[:top_k]
def run_ingestion(settings: Settings, reset: bool = False) -> dict[str, int]:
    """Extract chunks from the PDFs, embed them and store them in ChromaDB.

    Args:
        settings: Project settings (directories, chunking parameters, Ollama
            URL/timeout and embedding model are read from it).
        reset: When true, ``store.reset()`` is called before adding chunks.

    Returns:
        A summary with ``pdf_count``, ``chunks_created`` and ``vector_count``;
        the same summary is written to ``ingestion_summary.json`` in the
        processed directory.

    Raises:
        FileNotFoundError: If no PDF files are discovered in the data directory.
        ValueError: If no chunks could be extracted from the PDFs.
    """
    data_dir = settings.root_dir / "data"
    for directory in (settings.chroma_dir, settings.processed_dir, data_dir):
        directory.mkdir(parents=True, exist_ok=True)

    pdf_files = discover_pdfs(data_dir)
    if not pdf_files:
        raise FileNotFoundError(f"No PDF files were found in the data directory ({data_dir}).")

    all_chunks: list[ChunkRecord] = [
        record
        for pdf_path in pdf_files
        for record in extract_chunks_from_pdf(
            file_path=pdf_path,
            chunk_size=settings.chunk_size,
            chunk_overlap=settings.chunk_overlap,
        )
    ]
    if not all_chunks:
        raise ValueError("No text could be extracted from the PDFs.")

    store = ChromaMathStore(settings)
    if reset:
        store.reset()

    ollama = OllamaClient(settings.ollama_base_url, timeout=settings.request_timeout)
    # Embed and store 16 chunks at a time, with a progress bar over the batches.
    for group in tqdm(batched(all_chunks, 16), desc="Embedding chunks"):
        group_texts = [record.text for record in group]
        vectors = ollama.embed_texts(group_texts, model=settings.embed_model)
        store.add_chunks(group, vectors)

    summary = {
        "pdf_count": len(pdf_files),
        "chunks_created": len(all_chunks),
        "vector_count": store.count(),
    }
    report_path = settings.processed_dir / "ingestion_summary.json"
    report_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    return summary
class MathRAGAssistant:
    """Answers maths questions using retrieved chunks and an Ollama chat model."""

    def __init__(self, settings: Settings) -> None:
        """Create the vector store and Ollama client from ``settings``."""
        self.settings = settings
        self.store = ChromaMathStore(settings)
        self.ollama = OllamaClient(settings.ollama_base_url, timeout=settings.request_timeout)

    def answer(
        self,
        question: str,
        extra_chunks: list[SupplementalChunk] | None = None,
    ) -> dict[str, object]:
        """Retrieve context for ``question`` and generate an answer.

        Args:
            question: The question to answer.
            extra_chunks: Optional supplemental chunks (with embeddings) that
                are ranked against the question and merged with store results.

        Returns:
            A dict with ``"answer"`` (the model output, reformatted when
            ``needs_reformatting`` says so) and ``"sources"`` (chapter, topic,
            page and file metadata of each chunk used as context).

        Raises:
            ValueError: If the store is empty and no extra chunks were given,
                or if retrieval produced no chunks.
        """
        indexed_total = self.store.count()
        if indexed_total == 0 and not extra_chunks:
            raise ValueError(
                "The ChromaDB collection is empty. Run `python ingest.py --reset` first."
            )

        question_vector = self.ollama.embed_texts(
            [question],
            model=self.settings.embed_model,
        )[0]

        # Query each available source independently, then merge the results.
        from_store: list[RetrievedChunk] = (
            self.store.query(query_embedding=question_vector, top_k=self.settings.top_k)
            if indexed_total > 0
            else []
        )
        from_extras: list[RetrievedChunk] = (
            rank_supplemental_chunks(
                query_embedding=question_vector,
                extra_chunks=extra_chunks,
                top_k=self.settings.top_k,
            )
            if extra_chunks
            else []
        )
        retrieved = merge_retrieved_chunks(
            from_store,
            from_extras,
            top_k=self.settings.top_k,
        )
        if not retrieved:
            raise ValueError("No retrievable context is available for this question.")

        answer = self.ollama.chat(
            model=self.settings.llm_model,
            messages=build_messages(question, retrieved),
        )

        best_chapter = infer_label([item.chapter_name for item in retrieved])
        best_topic = infer_label([item.topic for item in retrieved])
        if needs_reformatting(answer):
            # Second pass: ask the model to rewrite its own answer cleanly.
            cleanup_messages = build_reformat_messages(
                question=question,
                answer=answer,
                chapter=best_chapter,
                topic=best_topic,
            )
            answer = self.ollama.chat(
                model=self.settings.llm_model,
                messages=cleanup_messages,
            )

        sources = []
        for item in retrieved:
            sources.append(
                {
                    "chapter_number": item.chapter_number,
                    "chapter_name": item.chapter_name,
                    "topic": item.topic,
                    "page_number": item.page_number,
                    "source_file": item.source_file,
                }
            )
        return {"answer": answer, "sources": sources}
def parse_args_and_run_ingestion(settings: Settings) -> None:
    """Parse the command line, run ingestion and print the summary as JSON.

    Args:
        settings: Project settings forwarded to ``run_ingestion``.

    The only option is ``--reset``, which is passed through as ``reset``.
    """
    cli = argparse.ArgumentParser(description="Ingest Class 12 Maths PDFs into ChromaDB.")
    cli.add_argument(
        "--reset",
        action="store_true",
        help="Delete the existing collection before re-ingesting the PDFs.",
    )
    reset_requested = cli.parse_args().reset

    report = run_ingestion(settings=settings, reset=reset_requested)
    print(json.dumps(report, indent=2))