feat: Implement Graph RAG pipeline with chunking, vector storage, and graph building
- Added `rag` module with core components:
- `chunker.py`: Implements semantic chunking of documents.
- `vector_store.py`: Integrates ChromaDB for storing and retrieving document chunks.
- `graph_builder.py`: Constructs a knowledge graph from document chunks, establishing relationships based on similarity and section headings.
- `groq_chat.py`: Facilitates chat interactions using Groq API with context from the knowledge graph.
- `rag_pipeline.py`: Orchestrates the entire RAG process, from ingestion to querying.
- Introduced `PipelineState` to manage the state of the RAG pipeline.
- Enhanced document processing with robust text extraction and chunking strategies.
- Added support for entity linking and cross-document similarity in the graph.
- Integrated debug utilities for inspecting raw document attributes.
- requirements.txt +9 -0
- src/rag/__init__.py +12 -0
- src/rag/chunker.py +231 -0
- src/rag/graph_builder.py +245 -0
- src/rag/groq_chat.py +130 -0
- src/rag/rag_pipeline.py +149 -0
- src/rag/vector_store.py +130 -0
- src/streamlit_app.py +249 -316
- src/utils/__pycache__/visualization.cpython-313.pyc +0 -0
|
@@ -25,6 +25,15 @@ numpy>=1.26.0
|
|
| 25 |
pandas>=2.2.0
|
| 26 |
Pillow>=10.2.0
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
# Utilities
|
| 29 |
python-dotenv>=1.0.0
|
| 30 |
|
|
|
|
| 25 |
pandas>=2.2.0
|
| 26 |
Pillow>=10.2.0
|
| 27 |
|
| 28 |
+
# Vector DB (Graph RAG)
|
| 29 |
+
chromadb>=0.5.0
|
| 30 |
+
|
| 31 |
+
# Knowledge Graph (Graph RAG)
|
| 32 |
+
networkx>=3.2.0
|
| 33 |
+
|
| 34 |
+
# Groq API (Graph RAG Chat)
|
| 35 |
+
groq>=0.9.0
|
| 36 |
+
|
| 37 |
# Utilities
|
| 38 |
python-dotenv>=1.0.0
|
| 39 |
|
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .chunker import Chunk, chunk_document, chunk_text
|
| 2 |
+
from .vector_store import VectorStore
|
| 3 |
+
from .graph_builder import GraphBuilder
|
| 4 |
+
from .groq_chat import GroqGraphChat
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"Chunk", "chunk_document", "chunk_text",
|
| 8 |
+
"VectorStore",
|
| 9 |
+
"GraphBuilder",
|
| 10 |
+
"GroqGraphChat",
|
| 11 |
+
]
|
| 12 |
+
|
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Smart Semantic Chunker
|
| 3 |
+
Chunks documents efficiently using sentence boundaries + structural signals.
|
| 4 |
+
"""
|
| 5 |
+
import re
|
| 6 |
+
from typing import List, Dict, Any
|
| 7 |
+
from dataclasses import dataclass, field
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
class Chunk:
    """A contiguous piece of a source document, ready for embedding/indexing."""
    chunk_id: str     # globally unique: f"{doc_id}_chunk_{chunk_index}"
    doc_id: str       # "doc1" or "doc2"
    text: str         # the chunk's full text
    chunk_index: int  # 0-based position within the document
    section: str = ""  # heading/section title if detected
    page: int = 0      # page number; 0 when unknown (not set by chunk_text)
    metadata: Dict[str, Any] = field(default_factory=dict)  # e.g. {"word_count": int}
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _split_sentences(text: str) -> List[str]:
|
| 22 |
+
"""Split text into sentences using regex."""
|
| 23 |
+
sentences = re.split(r'(?<=[.!?])\s+(?=[A-Z])', text.strip())
|
| 24 |
+
return [s.strip() for s in sentences if s.strip()]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _detect_heading(line: str) -> bool:
|
| 28 |
+
"""Detect if a line looks like a section heading."""
|
| 29 |
+
line = line.strip()
|
| 30 |
+
if not line:
|
| 31 |
+
return False
|
| 32 |
+
if re.match(r'^(\d+[\.\)]\s+|[A-Z][A-Z\s]{3,50}$)', line):
|
| 33 |
+
return True
|
| 34 |
+
if len(line) < 80 and not line.endswith('.') and line[0].isupper():
|
| 35 |
+
if re.match(r'^(Abstract|Introduction|Conclusion|Method|Result|Discussion|Background|Overview|Summary)', line, re.I):
|
| 36 |
+
return True
|
| 37 |
+
return False
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def chunk_text(
    text: str,
    doc_id: str,
    chunk_size: int = 300,
    overlap: int = 50,
) -> List[Chunk]:
    """
    Semantic chunking with section awareness, sentence boundary respect,
    and sliding window overlap.

    Args:
        text: Full document text (newline-separated lines).
        doc_id: Identifier stamped onto every produced Chunk.
        chunk_size: Target chunk size in words; the sentence buffer is
            flushed into a Chunk once it reaches this many words.
        overlap: Max words of trailing sentences carried into the next chunk.

    Returns:
        List of Chunk objects in document order.
    """
    chunks = []
    lines = text.split('\n')

    current_section = "General"
    buffer_sentences = []
    buffer_words = 0
    chunk_index = 0
    paragraph_buffer = []

    def flush_buffer(section: str) -> None:
        """Emit the buffered sentences as one Chunk, keeping an overlap tail."""
        nonlocal chunk_index, buffer_sentences, buffer_words
        if not buffer_sentences:
            return
        chunk_text_val = ' '.join(buffer_sentences)
        chunks.append(Chunk(
            chunk_id=f"{doc_id}_chunk_{chunk_index}",
            doc_id=doc_id,
            text=chunk_text_val,
            chunk_index=chunk_index,
            section=section,
            metadata={"word_count": buffer_words}
        ))
        chunk_index += 1
        # Keep trailing sentences (up to `overlap` words) so the next chunk
        # shares context with this one (sliding window).
        overlap_sentences = []
        overlap_words = 0
        for sent in reversed(buffer_sentences):
            w = len(sent.split())
            if overlap_words + w <= overlap:
                overlap_sentences.insert(0, sent)
                overlap_words += w
            else:
                break
        buffer_sentences = overlap_sentences
        buffer_words = overlap_words

    def ingest_paragraph(flush_on_size: bool = True) -> None:
        """Move the accumulated paragraph into the sentence buffer.

        This logic was previously triplicated (heading branch, blank-line
        branch, end-of-document); factored out so the call sites can't drift.
        When `flush_on_size` is True, a Chunk is emitted each time the buffer
        reaches `chunk_size` words (matching the in-loop behavior); the
        end-of-document call defers flushing to the final flush_buffer().
        """
        nonlocal paragraph_buffer, buffer_words
        if not paragraph_buffer:
            return
        for sent in _split_sentences(' '.join(paragraph_buffer)):
            buffer_sentences.append(sent)
            buffer_words += len(sent.split())
            if flush_on_size and buffer_words >= chunk_size:
                flush_buffer(current_section)
        paragraph_buffer = []

    for line in lines:
        stripped = line.strip()

        if _detect_heading(stripped):
            # Close out the current section before switching to the new one.
            ingest_paragraph()
            flush_buffer(current_section)
            current_section = stripped
            continue

        if stripped:
            paragraph_buffer.append(stripped)
        else:
            # Blank line terminates a paragraph.
            ingest_paragraph()

    # Flush whatever remains at end of document.
    ingest_paragraph(flush_on_size=False)
    flush_buffer(current_section)

    return chunks
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# ββ Debug helper ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 128 |
+
|
| 129 |
+
def debug_raw_doc(raw_doc) -> str:
    """Summarise every attribute of *raw_doc* as a human-readable string."""
    summary = [f"Type: {type(raw_doc).__name__}"]
    try:
        # Pydantic models expose model_dump(); fall back to __dict__.
        attrs = raw_doc.model_dump() if hasattr(raw_doc, 'model_dump') else vars(raw_doc)
        for name, value in attrs.items():
            if isinstance(value, str):
                summary.append(f" str attr '{name}': len={len(value)} preview={repr(value[:80])}")
            elif isinstance(value, list):
                summary.append(f" list attr '{name}': len={len(value)}")
            else:
                summary.append(f" attr '{name}': {type(value).__name__} = {repr(str(value)[:60])}")
    except Exception as exc:
        summary.append(f" (could not introspect: {exc})")
    return '\n'.join(summary)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
# ββ Robust text extraction ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 147 |
+
|
| 148 |
+
def extract_text_from_raw_doc(raw_doc) -> str:
    """
    Robustly extract text from whatever RawDocument the ingestion agent returns.

    Tries, in order: known string attributes, list-like page/section
    containers, a pydantic/vars dump (longest string field), and finally
    str(raw_doc). Returns "" when nothing usable is found.
    """
    # Strategy 1: common direct string attributes.
    for name in ('text_content', 'content', 'text', 'raw_text', 'full_text',
                 'body', 'extracted_text', 'plain_text', 'document_text'):
        candidate = getattr(raw_doc, name, None)
        if candidate and isinstance(candidate, str) and len(candidate.strip()) > 10:
            return candidate.strip()

    # Strategy 2: list of pages / sections; each element may be a string,
    # an object with .text/.content, or a dict.
    for name in ('pages', 'sections', 'chunks', 'paragraphs', 'text_chunks'):
        container = getattr(raw_doc, name, None)
        if not (container and isinstance(container, list)):
            continue
        pieces = []
        for element in container:
            if isinstance(element, str):
                pieces.append(element)
            elif hasattr(element, 'text') and isinstance(element.text, str):
                pieces.append(element.text)
            elif hasattr(element, 'content') and isinstance(element.content, str):
                pieces.append(element.content)
            elif isinstance(element, dict):
                pieces.append(str(element.get('text') or element.get('content') or ''))
        joined = '\n'.join(p for p in pieces if p.strip())
        if len(joined.strip()) > 10:
            return joined.strip()

    # Strategy 3: pydantic model_dump / __dict__ — preferred keys first,
    # then the longest string field of all.
    try:
        dumped = raw_doc.model_dump() if hasattr(raw_doc, 'model_dump') else vars(raw_doc)
        for key in ('text_content', 'content', 'text', 'raw_text', 'full_text', 'body'):
            value = dumped.get(key)
            if isinstance(value, str) and len(value.strip()) > 10:
                return value.strip()
        longest = max(
            ((k, v) for k, v in dumped.items() if isinstance(v, str)),
            key=lambda kv: len(kv[1]),
            default=(None, ''),
        )
        if len(longest[1]) > 100:
            return longest[1].strip()
    except Exception:
        pass

    # Strategy 4: str() fallback — skip repr-style "<...>" output.
    rendered = str(raw_doc)
    if len(rendered) > 50 and not rendered.startswith('<'):
        return rendered

    return ""
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def chunk_document(raw_doc, doc_id: str, chunk_size: int = 300, overlap: int = 50) -> List[Chunk]:
    """
    Chunk a RawDocument object from the ingestion agent.

    Always returns at least one Chunk: a section="Error" placeholder when no
    text could be extracted, or a section="General" fallback (first 500 chars)
    when chunking produced nothing.
    """
    extracted = extract_text_from_raw_doc(raw_doc)

    if not extracted:
        # Embed a truncated attribute dump so the failure is visible in the UI.
        return [Chunk(
            chunk_id=f"{doc_id}_chunk_0",
            doc_id=doc_id,
            text=f"[Could not extract text from {doc_id}. Attributes: {debug_raw_doc(raw_doc)[:200]}]",
            chunk_index=0,
            section="Error",
        )]

    result = chunk_text(extracted, doc_id, chunk_size, overlap)
    if result:
        return result

    return [Chunk(
        chunk_id=f"{doc_id}_chunk_0",
        doc_id=doc_id,
        text=extracted[:500],
        chunk_index=0,
        section="General",
    )]
|
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Graph RAG β Knowledge Graph Builder
|
| 3 |
+
Builds a NetworkX graph where:
|
| 4 |
+
- Nodes = chunks (from doc1 & doc2)
|
| 5 |
+
- Edges = relationships between chunks:
|
| 6 |
+
* sequential : consecutive chunks in same document
|
| 7 |
+
* same_section : chunks sharing the same heading/section
|
| 8 |
+
* cross_similar: high cosine similarity between doc1 chunk & doc2 chunk
|
| 9 |
+
* entity_link : chunks sharing important noun phrases (entities)
|
| 10 |
+
"""
|
| 11 |
+
import re
|
| 12 |
+
import networkx as nx
|
| 13 |
+
from typing import List, Dict, Any, Tuple
|
| 14 |
+
from sentence_transformers import SentenceTransformer
|
| 15 |
+
import numpy as np
|
| 16 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
| 17 |
+
|
| 18 |
+
from .chunker import Chunk
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
_EMBED_MODEL_NAME = "all-MiniLM-L6-v2"
|
| 22 |
+
_CROSS_SIM_THRESHOLD = 0.55 # min similarity to create a cross-doc edge
|
| 23 |
+
_ENTITY_MIN_LEN = 4 # min characters for an entity term
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _extract_noun_phrases(text: str) -> set:
|
| 27 |
+
"""
|
| 28 |
+
Lightweight noun phrase extraction via regex patterns.
|
| 29 |
+
No spacy dependency β works in constrained environments.
|
| 30 |
+
"""
|
| 31 |
+
# Capitalised multi-word phrases and key technical terms
|
| 32 |
+
patterns = [
|
| 33 |
+
r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)+\b', # "Neural Network", "New York"
|
| 34 |
+
r'\b[A-Z]{2,}\b', # acronyms: "RAG", "LLM"
|
| 35 |
+
r'\b\w{5,}\b', # any long word (catch technical terms)
|
| 36 |
+
]
|
| 37 |
+
entities = set()
|
| 38 |
+
for pat in patterns:
|
| 39 |
+
found = re.findall(pat, text)
|
| 40 |
+
entities.update(f.strip().lower() for f in found if len(f) >= _ENTITY_MIN_LEN)
|
| 41 |
+
# Remove very common stopwords
|
| 42 |
+
stopwords = {'which', 'these', 'those', 'their', 'there', 'where', 'about',
|
| 43 |
+
'would', 'could', 'should', 'other', 'being', 'using', 'having'}
|
| 44 |
+
return entities - stopwords
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class GraphBuilder:
    """
    Builds and queries a knowledge graph from doc chunks.

    Nodes are chunk_ids; edges carry a "relation" label (sequential,
    same_section, cross_similar, entity_link) and a "weight" that drives
    the weight-prioritised BFS in `retrieve`.
    """

    def __init__(self):
        # Sentence-transformer used only for cross-document similarity edges.
        self._model = SentenceTransformer(_EMBED_MODEL_NAME)
        self.graph: nx.Graph = nx.Graph()
        self._chunk_map: Dict[str, Chunk] = {}  # chunk_id -> Chunk

    # ------------------------------------------------------------------
    # Build
    # ------------------------------------------------------------------

    def build(self, doc1_chunks: List[Chunk], doc2_chunks: List[Chunk]) -> nx.Graph:
        """
        Full graph construction pipeline.
        Resets any previous graph, then adds nodes and the four edge types.
        Returns the built NetworkX graph.
        """
        self.graph = nx.Graph()
        self._chunk_map = {}

        all_chunks = doc1_chunks + doc2_chunks

        # 1. Add nodes
        for chunk in all_chunks:
            self._chunk_map[chunk.chunk_id] = chunk
            self.graph.add_node(
                chunk.chunk_id,
                text=chunk.text[:200],  # store snippet
                doc_id=chunk.doc_id,
                section=chunk.section,
                chunk_index=chunk.chunk_index,
                entities=list(_extract_noun_phrases(chunk.text)),
            )

        # 2. Sequential edges (within same doc)
        self._add_sequential_edges(doc1_chunks)
        self._add_sequential_edges(doc2_chunks)

        # 3. Same-section edges
        self._add_section_edges(all_chunks)

        # 4. Cross-document similarity edges
        self._add_cross_similarity_edges(doc1_chunks, doc2_chunks)

        # 5. Entity co-occurrence edges
        self._add_entity_edges(all_chunks)

        return self.graph

    def _add_sequential_edges(self, chunks: List[Chunk]) -> None:
        """Link consecutive chunks of one document in reading order."""
        sorted_chunks = sorted(chunks, key=lambda c: c.chunk_index)
        for i in range(len(sorted_chunks) - 1):
            a, b = sorted_chunks[i], sorted_chunks[i + 1]
            self.graph.add_edge(
                a.chunk_id, b.chunk_id,
                relation="sequential",
                weight=0.9,
            )

    def _add_section_edges(self, chunks: List[Chunk]) -> None:
        """Fully connect chunks that share the same (doc_id, section) pair."""
        section_map: Dict[str, List[str]] = {}
        for chunk in chunks:
            key = f"{chunk.doc_id}::{chunk.section}"
            section_map.setdefault(key, []).append(chunk.chunk_id)

        for ids in section_map.values():
            for i in range(len(ids)):
                for j in range(i + 1, len(ids)):
                    # Don't overwrite an existing (e.g. sequential) edge.
                    if not self.graph.has_edge(ids[i], ids[j]):
                        self.graph.add_edge(
                            ids[i], ids[j],
                            relation="same_section",
                            weight=0.6,
                        )

    def _add_cross_similarity_edges(
        self, doc1_chunks: List[Chunk], doc2_chunks: List[Chunk]
    ) -> None:
        """Connect doc1/doc2 chunk pairs whose embedding cosine similarity
        is at least _CROSS_SIM_THRESHOLD; similarity becomes the edge weight."""
        if not doc1_chunks or not doc2_chunks:
            return

        texts1 = [c.text for c in doc1_chunks]
        texts2 = [c.text for c in doc2_chunks]

        emb1 = self._model.encode(texts1, batch_size=32, show_progress_bar=False)
        emb2 = self._model.encode(texts2, batch_size=32, show_progress_bar=False)

        sim_matrix = cosine_similarity(emb1, emb2)

        for i, c1 in enumerate(doc1_chunks):
            for j, c2 in enumerate(doc2_chunks):
                sim = float(sim_matrix[i, j])
                if sim >= _CROSS_SIM_THRESHOLD:
                    self.graph.add_edge(
                        c1.chunk_id, c2.chunk_id,
                        relation="cross_similar",
                        weight=round(sim, 4),
                        similarity=round(sim, 4),
                    )

    def _add_entity_edges(self, chunks: List[Chunk]) -> None:
        """Link chunks from different documents that mention the same entity."""
        entity_to_chunks: Dict[str, List[str]] = {}
        for chunk in chunks:
            entities = _extract_noun_phrases(chunk.text)
            for ent in entities:
                entity_to_chunks.setdefault(ent, []).append(chunk.chunk_id)

        for ent, ids in entity_to_chunks.items():
            if len(ids) < 2:
                continue
            # Only connect cross-doc pairs to avoid too many same-doc entity edges
            # NOTE(review): this dict comprehension keeps only the LAST chunk
            # per doc_id for each entity, so at most one representative pair
            # is linked — confirm that is intended rather than linking all
            # cross-doc pairs.
            doc_ids = {self._chunk_map[cid].doc_id: cid for cid in ids}
            if len(doc_ids) >= 2:
                cids = list(doc_ids.values())
                for i in range(len(cids)):
                    for j in range(i + 1, len(cids)):
                        if not self.graph.has_edge(cids[i], cids[j]):
                            self.graph.add_edge(
                                cids[i], cids[j],
                                relation="entity_link",
                                entity=ent,
                                weight=0.5,
                            )

    # ------------------------------------------------------------------
    # Query
    # ------------------------------------------------------------------

    def retrieve(
        self,
        query: str,
        seed_chunks: List[Dict[str, Any]],  # from VectorStore.search()
        hops: int = 2,
        max_nodes: int = 10,
    ) -> List[Dict[str, Any]]:
        """
        Graph-aware retrieval:
        1. Start from seed chunk nodes (vector search results)
        2. Expand via BFS up to `hops` hops, prioritising high-weight edges
        3. Return unique chunks from both docs, ranked by relevance

        Note: `query` is not used for scoring here; ranking comes solely
        from edge weights multiplied along the BFS path from a seed.
        """
        visited = set()
        result_nodes = []

        seed_ids = [
            f"{s['doc_id']}_chunk_{s['chunk_index']}"
            for s in seed_chunks
            if s.get('chunk_index') is not None
        ]

        # BFS queue: (node_id, remaining_hops, accumulated_weight)
        # NOTE(review): list.pop(0) below is O(n); collections.deque would
        # be cheaper, though the queue stays small (max_nodes-bounded loop).
        queue = [(nid, hops, 1.0) for nid in seed_ids if nid in self.graph]

        while queue and len(result_nodes) < max_nodes:
            node_id, remaining, acc_weight = queue.pop(0)
            if node_id in visited:
                continue
            visited.add(node_id)

            chunk = self._chunk_map.get(node_id)
            if chunk:
                result_nodes.append({
                    "chunk_id": node_id,
                    "text": chunk.text,
                    "doc_id": chunk.doc_id,
                    "section": chunk.section,
                    "relevance": round(acc_weight, 4),
                })

            if remaining > 0:
                neighbors = sorted(
                    self.graph[node_id].items(),
                    key=lambda x: x[1].get("weight", 0),
                    reverse=True,
                )
                for neighbor_id, edge_data in neighbors[:4]:  # top-4 neighbours
                    if neighbor_id not in visited:
                        queue.append((
                            neighbor_id,
                            remaining - 1,
                            acc_weight * edge_data.get("weight", 0.5),
                        ))

        # Sort by relevance
        result_nodes.sort(key=lambda x: x["relevance"], reverse=True)
        return result_nodes[:max_nodes]

    def get_stats(self) -> Dict[str, Any]:
        """Return node/edge counts plus a per-relation edge-count breakdown."""
        edge_types = {}
        for _, _, data in self.graph.edges(data=True):
            rel = data.get("relation", "unknown")
            edge_types[rel] = edge_types.get(rel, 0) + 1
        return {
            "nodes": self.graph.number_of_nodes(),
            "edges": self.graph.number_of_edges(),
            "edge_types": edge_types,
        }
|
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Groq Chat with Graph RAG context injection.
|
| 3 |
+
Uses llama-3.3-70b-versatile (fast + smart) via Groq API.
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
from typing import List, Dict, Any, Generator
|
| 7 |
+
from groq import Groq
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
_DEFAULT_MODEL = "llama-3.3-70b-versatile"
|
| 11 |
+
_MAX_CONTEXT_CHARS = 6000 # stay within context window safely
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _build_context(retrieved_nodes: List[Dict[str, Any]]) -> str:
|
| 15 |
+
"""
|
| 16 |
+
Format retrieved graph nodes into a clean context block for the LLM.
|
| 17 |
+
Groups by document for clarity.
|
| 18 |
+
"""
|
| 19 |
+
doc1_nodes = [n for n in retrieved_nodes if n.get("doc_id") == "doc1"]
|
| 20 |
+
doc2_nodes = [n for n in retrieved_nodes if n.get("doc_id") == "doc2"]
|
| 21 |
+
|
| 22 |
+
parts = []
|
| 23 |
+
|
| 24 |
+
if doc1_nodes:
|
| 25 |
+
parts.append("### Relevant passages from Document 1:")
|
| 26 |
+
for node in doc1_nodes:
|
| 27 |
+
sec = f" [{node['section']}]" if node.get("section") else ""
|
| 28 |
+
parts.append(f"- {node['text'][:500]}{sec}")
|
| 29 |
+
|
| 30 |
+
if doc2_nodes:
|
| 31 |
+
parts.append("\n### Relevant passages from Document 2:")
|
| 32 |
+
for node in doc2_nodes:
|
| 33 |
+
sec = f" [{node['section']}]" if node.get("section") else ""
|
| 34 |
+
parts.append(f"- {node['text'][:500]}{sec}")
|
| 35 |
+
|
| 36 |
+
context = "\n".join(parts)
|
| 37 |
+
return context[:_MAX_CONTEXT_CHARS]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
_SYSTEM_PROMPT = """You are an expert document analyst assistant with access to two documents that have been processed, chunked, and indexed using a Knowledge Graph RAG system.
|
| 41 |
+
|
| 42 |
+
You will be given:
|
| 43 |
+
1. CONTEXT: Relevant passages retrieved from both documents via graph-enhanced semantic search
|
| 44 |
+
2. USER QUESTION: What the user wants to know
|
| 45 |
+
|
| 46 |
+
Your job:
|
| 47 |
+
- Answer using ONLY the provided context
|
| 48 |
+
- Clearly indicate which document (Document 1 or Document 2) information comes from
|
| 49 |
+
- If comparing both documents, highlight similarities and differences
|
| 50 |
+
- If the context doesn't contain the answer, say so honestly
|
| 51 |
+
- Be concise, accurate, and helpful
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class GroqGraphChat:
    """
    Stateful chat session backed by Groq API + GraphRAG context injection.

    Keeps the full multi-turn history in memory; each user turn is sent with
    the retrieved graph context wrapped in <context>/<question> tags.
    """

    def __init__(self, api_key: str, model: str = _DEFAULT_MODEL):
        self._client = Groq(api_key=api_key)
        self._model = model
        self._history: List[Dict[str, str]] = []  # alternating user/assistant turns

    def reset(self) -> None:
        """Clear the conversation history (start a fresh session)."""
        self._history = []

    def chat(
        self,
        user_query: str,
        retrieved_nodes: List[Dict[str, Any]],
        stream: bool = True,
    ) -> str | Generator:
        """
        Send a message with GraphRAG context and get a response.

        Args:
            user_query: The user's question
            retrieved_nodes: Chunks from GraphBuilder.retrieve()
            stream: If True, returns a generator for streaming UI

        Returns:
            Full response string (if stream=False) or generator (if stream=True)
        """
        context = _build_context(retrieved_nodes)

        # Build the user turn with injected context
        augmented_user_message = f"""<context>
{context}
</context>

<question>
{user_query}
</question>"""

        # Append to history
        self._history.append({"role": "user", "content": augmented_user_message})

        # NOTE(review): history (with injected context per turn) grows without
        # bound; long sessions may exceed the model's context window — consider
        # trimming old turns.
        messages = [{"role": "system", "content": _SYSTEM_PROMPT}] + self._history

        if stream:
            return self._stream_response(messages)
        else:
            return self._full_response(messages)

    def _full_response(self, messages: List[Dict]) -> str:
        """Blocking completion: records the answer in history and returns it."""
        response = self._client.chat.completions.create(
            model=self._model,
            messages=messages,
            max_tokens=1024,
            temperature=0.3,
        )
        answer = response.choices[0].message.content
        self._history.append({"role": "assistant", "content": answer})
        return answer

    def _stream_response(self, messages: List[Dict]) -> Generator:
        """Streaming completion: yields text deltas, then records the full
        assembled answer in history once the stream is exhausted."""
        stream = self._client.chat.completions.create(
            model=self._model,
            messages=messages,
            max_tokens=1024,
            temperature=0.3,
            stream=True,
        )
        full_response = ""
        for chunk in stream:
            delta = chunk.choices[0].delta.content or ""
            full_response += delta
            yield delta
        self._history.append({"role": "assistant", "content": full_response})
|
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
RAG Pipeline β wires everything together.
|
| 3 |
+
Used by the Streamlit chat tab.
|
| 4 |
+
"""
|
| 5 |
+
from typing import List, Dict, Any, Optional
|
| 6 |
+
from dataclasses import dataclass, field
|
| 7 |
+
|
| 8 |
+
from .chunker import Chunk, chunk_document
|
| 9 |
+
from .vector_store import VectorStore
|
| 10 |
+
from .graph_builder import GraphBuilder
|
| 11 |
+
from .groq_chat import GroqGraphChat
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
class PipelineState:
    """Holds the built RAG state after ingestion."""
    # Chunks produced for each of the two ingested documents.
    doc1_chunks: List[Chunk] = field(default_factory=list)
    doc2_chunks: List[Chunk] = field(default_factory=list)
    # Populated by GraphRAGPipeline.ingest(); None before ingestion.
    vector_store: Optional[VectorStore] = None
    graph_builder: Optional[GraphBuilder] = None
    # Presumably flipped to True once ingestion completes — confirm in
    # GraphRAGPipeline.ingest() (not fully visible here).
    is_ready: bool = False
    # Summary counts: chunk counts, vector count, graph stats.
    stats: Dict[str, Any] = field(default_factory=dict)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class GraphRAGPipeline:
    """
    End-to-end Graph RAG pipeline.

    Usage:
        pipeline = GraphRAGPipeline(groq_api_key="...")
        state = pipeline.ingest(raw_doc1, raw_doc2)
        answer = pipeline.query("What does doc1 say about climate?", state)
    """

    def __init__(
        self,
        groq_api_key: str,
        chunk_size: int = 300,
        chunk_overlap: int = 50,
        top_k_vector: int = 5,
        graph_hops: int = 2,
        graph_max_nodes: int = 10,
    ):
        """
        Args:
            groq_api_key: API key for the Groq chat backend.
            chunk_size: Target chunk size (words) passed to the chunker.
            chunk_overlap: Word overlap between consecutive chunks.
            top_k_vector: Number of seed chunks from vector search.
            graph_hops: How many hops to expand from the seed nodes.
            graph_max_nodes: Cap on nodes returned by graph retrieval.
        """
        self.groq_api_key = groq_api_key
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.top_k_vector = top_k_vector
        self.graph_hops = graph_hops
        self.graph_max_nodes = graph_max_nodes

        # Chat session; created fresh on each ingest(), or lazily in
        # query() if this pipeline instance never ran ingest() itself.
        self._chat: Optional[GroqGraphChat] = None

    # ------------------------------------------------------------------
    # Ingestion
    # ------------------------------------------------------------------

    def ingest(self, raw_doc1, raw_doc2) -> PipelineState:
        """
        Process both documents: chunk -> embed -> store -> build graph.

        Returns a PipelineState that should be stored in st.session_state.
        """
        state = PipelineState()

        # 1. Chunk both documents.
        state.doc1_chunks = chunk_document(
            raw_doc1, "doc1", self.chunk_size, self.chunk_overlap
        )
        state.doc2_chunks = chunk_document(
            raw_doc2, "doc2", self.chunk_size, self.chunk_overlap
        )

        # 2. Vector store (one shared collection for both documents).
        state.vector_store = VectorStore()
        state.vector_store.add_chunks(state.doc1_chunks)
        state.vector_store.add_chunks(state.doc2_chunks)

        # 3. Knowledge graph over all chunks.
        state.graph_builder = GraphBuilder()
        state.graph_builder.build(state.doc1_chunks, state.doc2_chunks)

        # 4. Stats for display in the UI.
        graph_stats = state.graph_builder.get_stats()
        state.stats = {
            "doc1_chunks": len(state.doc1_chunks),
            "doc2_chunks": len(state.doc2_chunks),
            "total_vectors": state.vector_store.count(),
            **graph_stats,
        }
        state.is_ready = True

        # 5. Fresh chat session (discards any previous conversation).
        self._chat = GroqGraphChat(api_key=self.groq_api_key)

        return state

    # ------------------------------------------------------------------
    # Query
    # ------------------------------------------------------------------

    def query(
        self,
        user_query: str,
        state: PipelineState,
        stream: bool = True,
    ):
        """
        Retrieve relevant context via vector + graph search,
        then pass to Groq for generation.

        Args:
            user_query: The user's question.
            state: A PipelineState previously returned by ingest().
            stream: Forwarded to the Groq chat client.

        Raises:
            RuntimeError: If the given state has not been built yet.
        """
        if not state.is_ready:
            raise RuntimeError("Pipeline not ready. Call ingest() first.")

        # Bug fix: a pipeline instance recreated across Streamlit reruns can
        # be handed an already-built state while self._chat is still None,
        # which previously crashed with AttributeError at step 3. Create the
        # chat session lazily instead.
        if self._chat is None:
            self._chat = GroqGraphChat(api_key=self.groq_api_key)

        # Step 1: Vector search (both docs).
        seed_chunks = state.vector_store.search(
            user_query, n_results=self.top_k_vector
        )

        # Step 2: Graph expansion around the vector hits.
        retrieved_nodes = state.graph_builder.retrieve(
            query=user_query,
            seed_chunks=seed_chunks,
            hops=self.graph_hops,
            max_nodes=self.graph_max_nodes,
        )

        # Fallback: if graph expansion returned nothing, use raw vector results.
        if not retrieved_nodes:
            retrieved_nodes = [
                {
                    "chunk_id": f"{s['doc_id']}_chunk_{s['chunk_index']}",
                    "text": s["text"],
                    "doc_id": s["doc_id"],
                    "section": s.get("section", ""),
                    "relevance": s["score"],
                }
                for s in seed_chunks
            ]

        # Step 3: Generate the answer via Groq.
        return self._chat.chat(
            user_query=user_query,
            retrieved_nodes=retrieved_nodes,
            stream=stream,
        )

    def reset_chat(self) -> None:
        """Clear conversation history (keep the indexed data)."""
        if self._chat:
            self._chat.reset()
|
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Vector Store using ChromaDB (in-memory, HF Spaces compatible)
|
| 3 |
+
Stores and retrieves chunks from both documents via semantic search.
|
| 4 |
+
"""
|
| 5 |
+
import chromadb
|
| 6 |
+
from chromadb.config import Settings
|
| 7 |
+
from sentence_transformers import SentenceTransformer
|
| 8 |
+
from typing import List, Dict, Any, Optional
|
| 9 |
+
import hashlib
|
| 10 |
+
|
| 11 |
+
from .chunker import Chunk
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
_EMBED_MODEL_NAME = "all-MiniLM-L6-v2" # fast, small, works great
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class VectorStore:
    """
    Wraps ChromaDB with a SentenceTransformer embedding function.
    Collection name: 'doc_chunks' -- shared for both documents.
    """

    # Single shared collection for chunks of both documents.
    _COLLECTION_NAME = "doc_chunks"

    def __init__(self, persist_dir: Optional[str] = None):
        """
        Args:
            persist_dir: Directory for a persistent ChromaDB client. When
                None (default), an in-memory ephemeral client is used
                (HF Spaces compatible).
        """
        self._model = SentenceTransformer(_EMBED_MODEL_NAME)

        if persist_dir:
            self._client = chromadb.PersistentClient(path=persist_dir)
        else:
            self._client = chromadb.EphemeralClient()

        self._collection = self._create_collection()

    def _create_collection(self):
        """Get or create the shared chunk collection (cosine space)."""
        return self._client.get_or_create_collection(
            name=self._COLLECTION_NAME,
            metadata={"hnsw:space": "cosine"},
        )

    # ------------------------------------------------------------------
    # Write
    # ------------------------------------------------------------------

    def add_chunks(self, chunks: List[Chunk]) -> None:
        """Embed and upsert chunks into the collection.

        No-op for an empty list. Chunk.metadata values are stringified
        because Chroma metadata only accepts primitive types.
        """
        if not chunks:
            return

        texts = [c.text for c in chunks]
        embeddings = self._model.encode(
            texts, batch_size=32, show_progress_bar=False
        ).tolist()

        ids = [c.chunk_id for c in chunks]
        metadatas = [
            {
                "doc_id": c.doc_id,
                "chunk_index": c.chunk_index,
                "section": c.section,
                "page": c.page,
                **{k: str(v) for k, v in c.metadata.items()},
            }
            for c in chunks
        ]

        # Upsert (not add) so re-ingesting the same chunk_ids overwrites
        # rather than erroring on duplicates.
        self._collection.upsert(
            ids=ids,
            embeddings=embeddings,
            documents=texts,
            metadatas=metadatas,
        )

    def clear(self) -> None:
        """Remove all chunks (useful for re-ingestion)."""
        self._client.delete_collection(self._COLLECTION_NAME)
        self._collection = self._create_collection()

    # ------------------------------------------------------------------
    # Read
    # ------------------------------------------------------------------

    def search(
        self,
        query: str,
        n_results: int = 5,
        doc_filter: Optional[str] = None,  # "doc1" | "doc2" | None
    ) -> List[Dict[str, Any]]:
        """
        Semantic search over stored chunks.

        Args:
            query: Natural-language query to embed and match.
            n_results: Maximum number of hits (capped by collection size).
            doc_filter: Restrict results to one document id, or None for both.

        Returns:
            List of dicts with keys: text, doc_id, section, chunk_index,
            score (cosine similarity, higher is better). Empty list when
            nothing has been indexed yet.
        """
        total = self._collection.count()
        # Robustness fix: previously an empty collection still embedded the
        # query and asked Chroma for `min(n_results, 0 or 1) = 1` result it
        # could not return.
        if total == 0:
            return []

        query_embedding = self._model.encode([query]).tolist()

        where = {"doc_id": doc_filter} if doc_filter else None

        results = self._collection.query(
            query_embeddings=query_embedding,
            n_results=min(n_results, total),
            where=where,
            include=["documents", "metadatas", "distances"],
        )

        hits = []
        for text, meta, dist in zip(
            results["documents"][0],
            results["metadatas"][0],
            results["distances"][0],
        ):
            hits.append({
                "text": text,
                "doc_id": meta.get("doc_id"),
                "section": meta.get("section", ""),
                "chunk_index": meta.get("chunk_index", -1),
                "score": round(1 - dist, 4),  # cosine distance -> similarity
            })

        return hits

    def count(self) -> int:
        """Total number of stored chunks across both documents."""
        return self._collection.count()

    def get_all_chunks_for_doc(self, doc_id: str) -> List[Dict[str, Any]]:
        """Retrieve all stored chunks for a given document, in chunk order."""
        results = self._collection.get(
            where={"doc_id": doc_id},
            include=["documents", "metadatas"],
        )
        items = []
        for text, meta in zip(results["documents"], results["metadatas"]):
            items.append({"text": text, **meta})
        # Sort by chunk_index so callers see document order.
        items.sort(key=lambda x: int(x.get("chunk_index", 0)))
        return items
|
|
@@ -1,10 +1,11 @@
|
|
| 1 |
"""
|
| 2 |
Multi-Agent Document Comparison Streamlit App
|
|
|
|
| 3 |
"""
|
| 4 |
import sys
|
|
|
|
| 5 |
from pathlib import Path
|
| 6 |
|
| 7 |
-
# Add project root to Python path for imports
|
| 8 |
project_root = Path(__file__).parent
|
| 9 |
if str(project_root) not in sys.path:
|
| 10 |
sys.path.insert(0, str(project_root))
|
|
@@ -13,7 +14,6 @@ import streamlit as st
|
|
| 13 |
import asyncio
|
| 14 |
import json
|
| 15 |
|
| 16 |
-
# Import agents and utilities
|
| 17 |
from agents.ingestion_agent import IngestionAgent
|
| 18 |
from agents.text_agent import TextAgent
|
| 19 |
from agents.table_agent import TableAgent
|
|
@@ -28,7 +28,10 @@ from utils.visualization import (
|
|
| 28 |
from models.document import ProcessedDocument
|
| 29 |
import config
|
| 30 |
|
| 31 |
-
#
|
|
|
|
|
|
|
|
|
|
| 32 |
try:
|
| 33 |
from agents.image_agent import ImageAgent
|
| 34 |
IMAGE_AGENT_AVAILABLE = True
|
|
@@ -48,7 +51,6 @@ except ImportError:
|
|
| 48 |
META_AGENT_AVAILABLE = False
|
| 49 |
|
| 50 |
|
| 51 |
-
# Page configuration
|
| 52 |
st.set_page_config(
|
| 53 |
page_title="Multi-Agent Document Comparator",
|
| 54 |
page_icon="π",
|
|
@@ -58,448 +60,379 @@ st.set_page_config(
|
|
| 58 |
|
| 59 |
|
| 60 |
def main():
|
| 61 |
-
"
|
| 62 |
-
|
| 63 |
-
# Header
|
| 64 |
-
st.title("π Multi-Agent Document Comparator")
|
| 65 |
-
st.markdown("**An agentic system to accurately match document similarity**")
|
| 66 |
|
| 67 |
-
# Show architecture diagram
|
| 68 |
with st.expander("ποΈ View System Architecture", expanded=False):
|
| 69 |
arch_path = Path("src/img/multi_agent_doc_similarity_architecture.svg")
|
| 70 |
if arch_path.exists():
|
| 71 |
st.image(str(arch_path), use_container_width=True)
|
| 72 |
-
else:
|
| 73 |
-
st.info("Architecture diagram not found")
|
| 74 |
|
| 75 |
st.markdown("---")
|
| 76 |
|
| 77 |
-
# Sidebar
|
| 78 |
with st.sidebar:
|
| 79 |
st.header("βοΈ Configuration")
|
| 80 |
|
| 81 |
-
# Phase 2 feature toggles
|
| 82 |
st.subheader("Phase 2 Features")
|
| 83 |
enable_phase2 = st.checkbox(
|
| 84 |
"Enable Phase 2 Modalities",
|
| 85 |
value=config.ENABLE_IMAGE_COMPARISON,
|
| 86 |
help="Enable image, layout, and metadata comparison"
|
| 87 |
)
|
| 88 |
-
|
| 89 |
-
# Modality weights
|
| 90 |
st.markdown("---")
|
| 91 |
st.subheader("Modality Weights")
|
| 92 |
|
| 93 |
if enable_phase2:
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
table_weight = st.slider(
|
| 103 |
-
"Table Weight",
|
| 104 |
-
min_value=0.0,
|
| 105 |
-
max_value=1.0,
|
| 106 |
-
value=config.MODALITY_WEIGHTS["table"],
|
| 107 |
-
step=0.05
|
| 108 |
-
)
|
| 109 |
-
image_weight = st.slider(
|
| 110 |
-
"Image Weight",
|
| 111 |
-
min_value=0.0,
|
| 112 |
-
max_value=1.0,
|
| 113 |
-
value=config.MODALITY_WEIGHTS["image"],
|
| 114 |
-
step=0.05
|
| 115 |
-
)
|
| 116 |
-
layout_weight = st.slider(
|
| 117 |
-
"Layout Weight",
|
| 118 |
-
min_value=0.0,
|
| 119 |
-
max_value=1.0,
|
| 120 |
-
value=config.MODALITY_WEIGHTS["layout"],
|
| 121 |
-
step=0.05
|
| 122 |
-
)
|
| 123 |
-
metadata_weight = st.slider(
|
| 124 |
-
"Metadata Weight",
|
| 125 |
-
min_value=0.0,
|
| 126 |
-
max_value=1.0,
|
| 127 |
-
value=config.MODALITY_WEIGHTS["metadata"],
|
| 128 |
-
step=0.05
|
| 129 |
-
)
|
| 130 |
-
|
| 131 |
-
# Normalize weights to sum to 1.0
|
| 132 |
-
total_weight = text_weight + table_weight + image_weight + layout_weight + metadata_weight
|
| 133 |
-
if total_weight > 0:
|
| 134 |
weights = {
|
| 135 |
-
"text":
|
| 136 |
-
"table":
|
| 137 |
-
"image":
|
| 138 |
-
"layout":
|
| 139 |
-
"metadata":
|
| 140 |
}
|
| 141 |
else:
|
| 142 |
weights = config.MODALITY_WEIGHTS
|
| 143 |
-
|
| 144 |
-
st.info(f"Weights normalized to sum to 1.0")
|
| 145 |
-
|
| 146 |
else:
|
| 147 |
-
|
| 148 |
-
text_weight = st.slider(
|
| 149 |
-
"Text Weight",
|
| 150 |
-
min_value=0.0,
|
| 151 |
-
max_value=1.0,
|
| 152 |
-
value=config.MODALITY_WEIGHTS_PHASE1["text"],
|
| 153 |
-
step=0.05
|
| 154 |
-
)
|
| 155 |
table_weight = 1.0 - text_weight
|
| 156 |
st.write(f"Table Weight: {table_weight:.2f}")
|
| 157 |
-
|
| 158 |
weights = {"text": text_weight, "table": table_weight}
|
| 159 |
|
| 160 |
-
# Phase status
|
| 161 |
st.markdown("---")
|
| 162 |
-
st.subheader("π
|
| 163 |
st.write("β
Text comparison")
|
| 164 |
st.write("β
Table comparison")
|
| 165 |
-
|
| 166 |
if enable_phase2:
|
| 167 |
-
st.write(f"{'β
' if IMAGE_AGENT_AVAILABLE
|
| 168 |
st.write(f"{'β
' if LAYOUT_AGENT_AVAILABLE else 'β οΈ'} Layout comparison")
|
| 169 |
-
st.write(f"{'β
' if META_AGENT_AVAILABLE
|
| 170 |
else:
|
| 171 |
-
st.write("βΈοΈ Image
|
| 172 |
-
|
| 173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
|
| 175 |
-
|
| 176 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
|
| 178 |
-
|
| 179 |
-
st.subheader("π€ Document 1 (Main)")
|
| 180 |
-
uploaded_file1 = st.file_uploader(
|
| 181 |
-
"Upload PDF or DOCX",
|
| 182 |
-
type=["pdf", "docx"],
|
| 183 |
-
key="file1",
|
| 184 |
-
help="Maximum file size: 50MB"
|
| 185 |
-
)
|
| 186 |
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
)
|
| 195 |
|
| 196 |
-
|
| 197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
return
|
| 203 |
-
|
| 204 |
-
# Process documents and compare
|
| 205 |
-
with st.spinner("Processing documents..."):
|
| 206 |
-
try:
|
| 207 |
-
# Save uploaded files
|
| 208 |
-
file1_path = save_uploaded_file(uploaded_file1)
|
| 209 |
-
file2_path = save_uploaded_file(uploaded_file2)
|
| 210 |
-
|
| 211 |
-
# Validate files
|
| 212 |
-
valid1, error1 = validate_file(file1_path)
|
| 213 |
-
valid2, error2 = validate_file(file2_path)
|
| 214 |
-
|
| 215 |
-
if not valid1:
|
| 216 |
-
st.error(f"Document 1 error: {error1}")
|
| 217 |
-
return
|
| 218 |
-
if not valid2:
|
| 219 |
-
st.error(f"Document 2 error: {error2}")
|
| 220 |
-
return
|
| 221 |
-
|
| 222 |
-
# Process documents
|
| 223 |
-
report = asyncio.run(process_and_compare(
|
| 224 |
-
file1_path,
|
| 225 |
-
file2_path,
|
| 226 |
-
weights,
|
| 227 |
-
enable_phase2
|
| 228 |
-
))
|
| 229 |
-
|
| 230 |
-
# Display results
|
| 231 |
-
display_results(report)
|
| 232 |
-
|
| 233 |
-
except Exception as e:
|
| 234 |
-
st.error(f"An error occurred: {str(e)}")
|
| 235 |
-
import traceback
|
| 236 |
-
st.code(traceback.format_exc())
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
async def process_and_compare(file1_path: str, file2_path: str, weights: dict, enable_phase2: bool = False):
|
| 240 |
-
"""
|
| 241 |
-
Process two documents and compare them.
|
| 242 |
-
|
| 243 |
-
Args:
|
| 244 |
-
file1_path: Path to first document
|
| 245 |
-
file2_path: Path to second document
|
| 246 |
-
weights: Modality weights
|
| 247 |
-
enable_phase2: Enable Phase 2 modalities (image, layout, metadata)
|
| 248 |
-
|
| 249 |
-
Returns:
|
| 250 |
-
SimilarityReport
|
| 251 |
-
"""
|
| 252 |
-
# Initialize agents
|
| 253 |
ingestion_agent = IngestionAgent()
|
| 254 |
-
text_agent
|
| 255 |
-
table_agent
|
| 256 |
-
orchestrator
|
| 257 |
|
| 258 |
-
|
| 259 |
-
image_agent = ImageAgent() if enable_phase2 and IMAGE_AGENT_AVAILABLE else None
|
| 260 |
layout_agent = LayoutAgent() if enable_phase2 and LAYOUT_AGENT_AVAILABLE else None
|
| 261 |
-
meta_agent
|
| 262 |
|
| 263 |
-
# Progress tracking
|
| 264 |
progress_bar = st.progress(0)
|
| 265 |
-
status_text
|
| 266 |
|
| 267 |
-
# Step 1: Ingest documents
|
| 268 |
status_text.text("β³ Ingesting documents...")
|
| 269 |
progress_bar.progress(10)
|
| 270 |
-
|
| 271 |
raw_doc1 = await ingestion_agent.process(file1_path)
|
| 272 |
raw_doc2 = await ingestion_agent.process(file2_path)
|
| 273 |
-
|
| 274 |
progress_bar.progress(15)
|
| 275 |
|
| 276 |
-
|
| 277 |
-
status_text.text("β³ Extracting and embedding text...")
|
| 278 |
-
|
| 279 |
text_chunks1, text_embeddings1 = await text_agent.process(raw_doc1)
|
| 280 |
text_chunks2, text_embeddings2 = await text_agent.process(raw_doc2)
|
| 281 |
-
|
| 282 |
progress_bar.progress(30)
|
| 283 |
|
| 284 |
-
|
| 285 |
-
status_text.text("β³ Extracting and embedding tables...")
|
| 286 |
-
|
| 287 |
tables1, table_embeddings1 = await table_agent.process(raw_doc1)
|
| 288 |
tables2, table_embeddings2 = await table_agent.process(raw_doc2)
|
| 289 |
-
|
| 290 |
progress_bar.progress(45)
|
| 291 |
|
| 292 |
-
|
| 293 |
-
images1, image_embeddings1 = [], None
|
| 294 |
-
images2, image_embeddings2 = [], None
|
| 295 |
if image_agent:
|
| 296 |
-
status_text.text("β³ Extracting
|
| 297 |
try:
|
| 298 |
images1, image_embeddings1 = await image_agent.process(raw_doc1)
|
| 299 |
images2, image_embeddings2 = await image_agent.process(raw_doc2)
|
| 300 |
except Exception as e:
|
| 301 |
st.warning(f"Image extraction failed: {e}")
|
| 302 |
-
|
| 303 |
progress_bar.progress(60)
|
| 304 |
|
| 305 |
-
|
| 306 |
-
layout1, layout2 = None, None
|
| 307 |
if layout_agent:
|
| 308 |
-
status_text.text("β³
|
| 309 |
try:
|
| 310 |
layout1 = await layout_agent.process(raw_doc1)
|
| 311 |
layout2 = await layout_agent.process(raw_doc2)
|
| 312 |
except Exception as e:
|
| 313 |
st.warning(f"Layout analysis failed: {e}")
|
| 314 |
-
|
| 315 |
progress_bar.progress(70)
|
| 316 |
|
| 317 |
-
|
| 318 |
-
metadata1, metadata2 = None, None
|
| 319 |
if meta_agent:
|
| 320 |
-
status_text.text("β³ Extracting metadata
|
| 321 |
try:
|
| 322 |
metadata1 = await meta_agent.process(raw_doc1)
|
| 323 |
metadata2 = await meta_agent.process(raw_doc2)
|
| 324 |
except Exception as e:
|
| 325 |
st.warning(f"Metadata extraction failed: {e}")
|
| 326 |
-
|
| 327 |
progress_bar.progress(80)
|
| 328 |
|
| 329 |
-
# Create processed documents
|
| 330 |
processed_doc1 = ProcessedDocument(
|
| 331 |
-
filename=raw_doc1.filename,
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
total_pages=raw_doc1.total_pages,
|
| 335 |
-
file_type=raw_doc1.file_type,
|
| 336 |
-
images=images1,
|
| 337 |
-
layout=layout1,
|
| 338 |
-
metadata=metadata1
|
| 339 |
)
|
| 340 |
-
|
| 341 |
processed_doc2 = ProcessedDocument(
|
| 342 |
-
filename=raw_doc2.filename,
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
total_pages=raw_doc2.total_pages,
|
| 346 |
-
file_type=raw_doc2.file_type,
|
| 347 |
-
images=images2,
|
| 348 |
-
layout=layout2,
|
| 349 |
-
metadata=metadata2
|
| 350 |
)
|
| 351 |
|
| 352 |
-
|
| 353 |
-
status_text.text("β³ Comparing documents...")
|
| 354 |
-
|
| 355 |
report = await orchestrator.compare_documents(
|
| 356 |
-
processed_doc1,
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
text_embeddings2,
|
| 361 |
-
table_embeddings2,
|
| 362 |
-
# Phase 2 parameters
|
| 363 |
-
image_embeddings1,
|
| 364 |
-
image_embeddings2,
|
| 365 |
-
layout1,
|
| 366 |
-
layout2,
|
| 367 |
-
metadata1,
|
| 368 |
-
metadata2
|
| 369 |
)
|
| 370 |
|
| 371 |
progress_bar.progress(100)
|
| 372 |
status_text.text("β
Comparison complete!")
|
| 373 |
|
| 374 |
-
|
|
|
|
| 375 |
|
| 376 |
|
| 377 |
def display_results(report):
|
| 378 |
-
"""
|
| 379 |
-
Display comparison results.
|
| 380 |
-
|
| 381 |
-
Args:
|
| 382 |
-
report: SimilarityReport object
|
| 383 |
-
"""
|
| 384 |
st.markdown("---")
|
| 385 |
st.header("π Comparison Results")
|
| 386 |
|
| 387 |
-
# Overall similarity gauge
|
| 388 |
col1, col2 = st.columns([1, 1])
|
| 389 |
-
|
| 390 |
with col1:
|
| 391 |
gauge_fig = create_similarity_gauge(report.overall_score)
|
| 392 |
st.plotly_chart(gauge_fig, use_container_width=True)
|
| 393 |
-
|
| 394 |
with col2:
|
| 395 |
st.markdown(create_score_legend())
|
| 396 |
|
| 397 |
-
# Modality breakdown
|
| 398 |
st.markdown("---")
|
| 399 |
st.subheader("π Per-Modality Breakdown")
|
| 400 |
-
|
| 401 |
breakdown_fig = create_modality_breakdown_chart(report)
|
| 402 |
st.plotly_chart(breakdown_fig, use_container_width=True)
|
| 403 |
|
| 404 |
-
# Detailed scores
|
| 405 |
cols = st.columns(5)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 406 |
|
| 407 |
-
with cols[0]:
|
| 408 |
-
if report.text_score:
|
| 409 |
-
st.metric(
|
| 410 |
-
"Text Similarity",
|
| 411 |
-
f"{report.text_score.score:.1%}",
|
| 412 |
-
f"{report.text_score.details.get('num_matches', 0)} matches"
|
| 413 |
-
)
|
| 414 |
-
|
| 415 |
-
with cols[1]:
|
| 416 |
-
if report.table_score:
|
| 417 |
-
st.metric(
|
| 418 |
-
"Table Similarity",
|
| 419 |
-
f"{report.table_score.score:.1%}",
|
| 420 |
-
f"{report.table_score.details.get('num_matches', 0)} matches"
|
| 421 |
-
)
|
| 422 |
-
|
| 423 |
-
with cols[2]:
|
| 424 |
-
if report.image_score:
|
| 425 |
-
st.metric(
|
| 426 |
-
"Image Similarity",
|
| 427 |
-
f"{report.image_score.score:.1%}",
|
| 428 |
-
f"{report.image_score.details.get('num_matches', 0)} matches"
|
| 429 |
-
)
|
| 430 |
-
|
| 431 |
-
with cols[3]:
|
| 432 |
-
if report.layout_score:
|
| 433 |
-
st.metric(
|
| 434 |
-
"Layout Similarity",
|
| 435 |
-
f"{report.layout_score.score:.1%}",
|
| 436 |
-
f"{report.layout_score.details.get('num_metrics', 0)} metrics"
|
| 437 |
-
)
|
| 438 |
-
|
| 439 |
-
with cols[4]:
|
| 440 |
-
if report.metadata_score:
|
| 441 |
-
st.metric(
|
| 442 |
-
"Metadata Similarity",
|
| 443 |
-
f"{report.metadata_score.score:.1%}",
|
| 444 |
-
f"{report.metadata_score.details.get('num_fields_compared', 0)} fields"
|
| 445 |
-
)
|
| 446 |
-
|
| 447 |
-
# Matched sections
|
| 448 |
st.markdown("---")
|
| 449 |
st.subheader("π Top Matched Sections")
|
| 450 |
-
|
| 451 |
if report.matched_sections:
|
| 452 |
-
|
| 453 |
-
st.markdown(formatted_sections)
|
| 454 |
else:
|
| 455 |
-
st.info("No significant matches found
|
| 456 |
|
| 457 |
-
# Phase 2: Additional modality details
|
| 458 |
if report.image_score or report.layout_score or report.metadata_score:
|
| 459 |
st.markdown("---")
|
| 460 |
st.subheader("π¨ Phase 2 Modality Details")
|
| 461 |
-
|
| 462 |
-
# Image matches
|
| 463 |
if report.image_score and report.image_score.matched_items:
|
| 464 |
-
with st.expander(f"πΌοΈ Image Matches ({len(report.image_score.matched_items)}
|
| 465 |
-
for idx,
|
| 466 |
-
st.markdown(f"**Match {idx}**
|
| 467 |
-
st.write(f"Doc1: Page {match['doc1_page']}, Size: {match['doc1_size']}")
|
| 468 |
-
st.write(f"Doc2: Page {match['doc2_page']}, Size: {match['doc2_size']}")
|
| 469 |
-
st.markdown("---")
|
| 470 |
-
|
| 471 |
-
# Layout details
|
| 472 |
if report.layout_score:
|
| 473 |
-
with st.expander(f"π Layout
|
| 474 |
-
for
|
| 475 |
-
if
|
| 476 |
-
st.metric(
|
| 477 |
-
|
| 478 |
-
# Metadata matches
|
| 479 |
if report.metadata_score and report.metadata_score.matched_items:
|
| 480 |
-
with st.expander(f"π Metadata
|
| 481 |
-
for
|
| 482 |
-
st.markdown(f"**{
|
| 483 |
-
|
| 484 |
-
with col1:
|
| 485 |
-
st.write(f"Doc1: {match['doc1_value']}")
|
| 486 |
-
with col2:
|
| 487 |
-
st.write(f"Doc2: {match['doc2_value']}")
|
| 488 |
-
st.markdown("---")
|
| 489 |
-
|
| 490 |
-
# Download report
|
| 491 |
st.markdown("---")
|
| 492 |
report_json = json.dumps(report.model_dump(), indent=2, default=str)
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
label="π₯ Download Report (JSON)",
|
| 499 |
-
data=report_json,
|
| 500 |
-
file_name=f"similarity_report_{report.timestamp.strftime('%Y%m%d_%H%M%S')}.json",
|
| 501 |
-
mime="application/json"
|
| 502 |
-
)
|
| 503 |
|
| 504 |
|
| 505 |
if __name__ == "__main__":
|
|
|
|
| 1 |
"""
|
| 2 |
Multi-Agent Document Comparison Streamlit App
|
| 3 |
+
+ Graph RAG Chat Tab (new)
|
| 4 |
"""
|
| 5 |
import sys
|
| 6 |
+
import os
|
| 7 |
from pathlib import Path
|
| 8 |
|
|
|
|
| 9 |
project_root = Path(__file__).parent
|
| 10 |
if str(project_root) not in sys.path:
|
| 11 |
sys.path.insert(0, str(project_root))
|
|
|
|
| 14 |
import asyncio
|
| 15 |
import json
|
| 16 |
|
|
|
|
| 17 |
from agents.ingestion_agent import IngestionAgent
|
| 18 |
from agents.text_agent import TextAgent
|
| 19 |
from agents.table_agent import TableAgent
|
|
|
|
| 28 |
from models.document import ProcessedDocument
|
| 29 |
import config
|
| 30 |
|
| 31 |
+
# Graph RAG imports
|
| 32 |
+
from rag.rag_pipeline import GraphRAGPipeline, PipelineState
|
| 33 |
+
|
| 34 |
+
# Phase 2 imports (conditional)
|
| 35 |
try:
|
| 36 |
from agents.image_agent import ImageAgent
|
| 37 |
IMAGE_AGENT_AVAILABLE = True
|
|
|
|
| 51 |
META_AGENT_AVAILABLE = False
|
| 52 |
|
| 53 |
|
|
|
|
| 54 |
st.set_page_config(
|
| 55 |
page_title="Multi-Agent Document Comparator",
|
| 56 |
page_icon="π",
|
|
|
|
| 60 |
|
| 61 |
|
| 62 |
def main():
|
| 63 |
+
st.title("π Multi-Agent Document Comparator + Graph RAG Chat")
|
| 64 |
+
st.markdown("**Agentic document similarity Β· Knowledge Graph RAG Β· Groq-powered chat**")
|
|
|
|
|
|
|
|
|
|
| 65 |
|
|
|
|
| 66 |
with st.expander("ποΈ View System Architecture", expanded=False):
|
| 67 |
arch_path = Path("src/img/multi_agent_doc_similarity_architecture.svg")
|
| 68 |
if arch_path.exists():
|
| 69 |
st.image(str(arch_path), use_container_width=True)
|
|
|
|
|
|
|
| 70 |
|
| 71 |
st.markdown("---")
|
| 72 |
|
| 73 |
+
# ββ Sidebar βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 74 |
with st.sidebar:
|
| 75 |
st.header("βοΈ Configuration")
|
| 76 |
|
|
|
|
| 77 |
st.subheader("Phase 2 Features")
|
| 78 |
enable_phase2 = st.checkbox(
|
| 79 |
"Enable Phase 2 Modalities",
|
| 80 |
value=config.ENABLE_IMAGE_COMPARISON,
|
| 81 |
help="Enable image, layout, and metadata comparison"
|
| 82 |
)
|
|
|
|
|
|
|
| 83 |
st.markdown("---")
|
| 84 |
st.subheader("Modality Weights")
|
| 85 |
|
| 86 |
if enable_phase2:
|
| 87 |
+
text_weight = st.slider("Text Weight", 0.0, 1.0, config.MODALITY_WEIGHTS["text"], 0.05)
|
| 88 |
+
table_weight = st.slider("Table Weight", 0.0, 1.0, config.MODALITY_WEIGHTS["table"], 0.05)
|
| 89 |
+
image_weight = st.slider("Image Weight", 0.0, 1.0, config.MODALITY_WEIGHTS["image"], 0.05)
|
| 90 |
+
layout_weight = st.slider("Layout Weight", 0.0, 1.0, config.MODALITY_WEIGHTS["layout"], 0.05)
|
| 91 |
+
meta_weight = st.slider("Metadata Weight", 0.0, 1.0, config.MODALITY_WEIGHTS["metadata"], 0.05)
|
| 92 |
+
|
| 93 |
+
total = text_weight + table_weight + image_weight + layout_weight + meta_weight
|
| 94 |
+
if total > 0:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
weights = {
|
| 96 |
+
"text": text_weight / total,
|
| 97 |
+
"table": table_weight / total,
|
| 98 |
+
"image": image_weight / total,
|
| 99 |
+
"layout": layout_weight / total,
|
| 100 |
+
"metadata": meta_weight / total,
|
| 101 |
}
|
| 102 |
else:
|
| 103 |
weights = config.MODALITY_WEIGHTS
|
| 104 |
+
st.info("Weights normalised to 1.0")
|
|
|
|
|
|
|
| 105 |
else:
|
| 106 |
+
text_weight = st.slider("Text Weight", 0.0, 1.0, config.MODALITY_WEIGHTS_PHASE1["text"], 0.05)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
table_weight = 1.0 - text_weight
|
| 108 |
st.write(f"Table Weight: {table_weight:.2f}")
|
|
|
|
| 109 |
weights = {"text": text_weight, "table": table_weight}
|
| 110 |
|
|
|
|
| 111 |
st.markdown("---")
|
| 112 |
+
st.subheader("π Status")
|
| 113 |
st.write("β
Text comparison")
|
| 114 |
st.write("β
Table comparison")
|
|
|
|
| 115 |
if enable_phase2:
|
| 116 |
+
st.write(f"{'β
' if IMAGE_AGENT_AVAILABLE else 'β οΈ'} Image comparison")
|
| 117 |
st.write(f"{'β
' if LAYOUT_AGENT_AVAILABLE else 'β οΈ'} Layout comparison")
|
| 118 |
+
st.write(f"{'β
' if META_AGENT_AVAILABLE else 'β οΈ'} Metadata comparison")
|
| 119 |
else:
|
| 120 |
+
st.write("βΈοΈ Image / Layout / Metadata (disabled)")
|
| 121 |
+
|
| 122 |
+
st.markdown("---")
|
| 123 |
+
st.subheader("π Graph RAG Settings")
|
| 124 |
+
chunk_size = st.slider("Chunk size (words)", 100, 600, 300, 50)
|
| 125 |
+
chunk_overlap = st.slider("Overlap (words)", 20, 150, 50, 10)
|
| 126 |
+
top_k = st.slider("Vector top-k", 3, 15, 5, 1)
|
| 127 |
+
graph_hops = st.slider("Graph hops", 1, 4, 2, 1)
|
| 128 |
+
|
| 129 |
+
# ββ Main tabs βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 130 |
+
tab1, tab2 = st.tabs(["π Document Comparison", "π¬ Graph RAG Chat"])
|
| 131 |
+
|
| 132 |
+
# ββ Session state init ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 133 |
+
for key in ["raw_doc1", "raw_doc2", "rag_state", "rag_pipeline", "chat_history"]:
|
| 134 |
+
if key not in st.session_state:
|
| 135 |
+
st.session_state[key] = None if key != "chat_history" else []
|
| 136 |
+
|
| 137 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 138 |
+
# TAB 1 β Comparison
|
| 139 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 140 |
+
with tab1:
|
| 141 |
+
col1, col2 = st.columns(2)
|
| 142 |
+
|
| 143 |
+
with col1:
|
| 144 |
+
st.subheader("π€ Document 1 (Main)")
|
| 145 |
+
uploaded_file1 = st.file_uploader(
|
| 146 |
+
"Upload PDF or DOCX", type=["pdf", "docx"], key="file1",
|
| 147 |
+
help="Maximum file size: 50MB"
|
| 148 |
+
)
|
| 149 |
|
| 150 |
+
with col2:
|
| 151 |
+
st.subheader("π€ Document 2 (Comparison)")
|
| 152 |
+
uploaded_file2 = st.file_uploader(
|
| 153 |
+
"Upload PDF or DOCX", type=["pdf", "docx"], key="file2",
|
| 154 |
+
help="Maximum file size: 50MB"
|
| 155 |
+
)
|
| 156 |
|
| 157 |
+
st.markdown("---")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
|
| 159 |
+
if st.button("🔍 Compare Documents", type="primary", use_container_width=True):
    if not uploaded_file1 or not uploaded_file2:
        st.error("Please upload both documents before comparing.")
    else:
        with st.spinner("Processing documents..."):
            try:
                file1_path = save_uploaded_file(uploaded_file1)
                file2_path = save_uploaded_file(uploaded_file2)

                valid1, error1 = validate_file(file1_path)
                valid2, error2 = validate_file(file2_path)

                # BUG FIX: the original called st.stop() here, inside this
                # try-block. Streamlit's StopException subclasses Exception,
                # so the broad handler below would swallow it and display a
                # spurious traceback. Branching with if/elif/else makes
                # st.stop() unnecessary.
                if not valid1:
                    st.error(f"Document 1 error: {error1}")
                elif not valid2:
                    st.error(f"Document 2 error: {error2}")
                else:
                    report, raw_doc1, raw_doc2 = asyncio.run(
                        process_and_compare(file1_path, file2_path, weights, enable_phase2)
                    )

                    # Store raw docs for the Graph RAG tab and invalidate any
                    # previously-built RAG state (the documents just changed).
                    st.session_state["raw_doc1"] = raw_doc1
                    st.session_state["raw_doc2"] = raw_doc2
                    st.session_state["rag_state"] = None
                    st.session_state["chat_history"] = []

                    display_results(report)

            except Exception as e:
                # Top-level UI boundary: show the error plus full traceback
                # so users can report actionable failures.
                st.error(f"An error occurred: {str(e)}")
                import traceback
                st.code(traceback.format_exc())
+
# ────────────────────────────────────────────────────────────────────────
# TAB 2 — Graph RAG Chat
# ────────────────────────────────────────────────────────────────────────
with tab2:
    st.subheader("💬 Chat with your Documents (Graph RAG + Groq)")

    # Both raw docs must exist (produced by the comparison tab) before an
    # index can be built.
    docs_ready = (
        st.session_state["raw_doc1"] is not None
        and st.session_state["raw_doc2"] is not None
    )

    if not docs_ready:
        st.info("👈 Please upload and compare documents in the **Document Comparison** tab first.")
    else:
        # Load Groq API key from environment (Hugging Face Spaces secrets)
        groq_key = os.environ.get("GROQ_API_KEY", "")

        if not groq_key:
            st.warning("⚠️ GROQ_API_KEY not found in environment. Please set it in Hugging Face Spaces secrets.")

        col_build, col_reset = st.columns([2, 1])

        with col_build:
            build_btn = st.button(
                "🔨 Build Graph RAG Index",
                disabled=not groq_key,  # useless without an API key
                help="Chunks docs → embeds → builds vector DB + knowledge graph",
            )

        with col_reset:
            if st.button("🔄 Reset Chat"):
                st.session_state["chat_history"] = []
                if st.session_state["rag_pipeline"]:
                    st.session_state["rag_pipeline"].reset_chat()
                st.rerun()

        if build_btn:
            with st.spinner("Chunking, embedding, building knowledge graph — this takes ~30s…"):
                try:
                    pipeline = GraphRAGPipeline(
                        groq_api_key=groq_key,
                        chunk_size=chunk_size,
                        chunk_overlap=chunk_overlap,
                        top_k_vector=top_k,
                        graph_hops=graph_hops,
                    )
                    rag_state = pipeline.ingest(
                        st.session_state["raw_doc1"],
                        st.session_state["raw_doc2"],
                    )
                except Exception as e:
                    # ROBUSTNESS: the original had no error handling here, so
                    # any chunking/embedding failure crashed the whole tab.
                    st.error(f"Failed to build Graph RAG index: {e}")
                else:
                    st.session_state["rag_pipeline"] = pipeline
                    st.session_state["rag_state"] = rag_state
                    st.session_state["chat_history"] = []

                    st.success("✅ Graph RAG index ready!")

                    s = rag_state.stats
                    c1, c2, c3, c4 = st.columns(4)
                    c1.metric("Doc 1 Chunks", s.get("doc1_chunks", 0))
                    c2.metric("Doc 2 Chunks", s.get("doc2_chunks", 0))
                    c3.metric("Graph Nodes", s.get("nodes", 0))
                    c4.metric("Graph Edges", s.get("edges", 0))

                    with st.expander("Edge type breakdown"):
                        for etype, cnt in s.get("edge_types", {}).items():
                            st.write(f"**{etype}**: {cnt}")

        # ── Chat UI ──────────────────────────────────────────────────────
        rag_ready = st.session_state["rag_state"] is not None

        if rag_ready:
            # Replay the conversation so far.
            for msg in st.session_state["chat_history"]:
                with st.chat_message(msg["role"]):
                    st.markdown(msg["content"])

            if user_input := st.chat_input("Ask anything about the two documents…"):
                st.session_state["chat_history"].append(
                    {"role": "user", "content": user_input}
                )
                with st.chat_message("user"):
                    st.markdown(user_input)

                with st.chat_message("assistant"):
                    pipeline: GraphRAGPipeline = st.session_state["rag_pipeline"]
                    rag_state_obj: PipelineState = st.session_state["rag_state"]

                    # Stream tokens as they arrive; write_stream returns the
                    # fully-assembled text for history.
                    response_gen = pipeline.query(user_input, rag_state_obj, stream=True)
                    full_response = st.write_stream(response_gen)

                st.session_state["chat_history"].append(
                    {"role": "assistant", "content": full_response}
                )
        else:
            st.info("👆 Click **Build Graph RAG Index** to start chatting. (Ensure GROQ_API_KEY is set in HF Spaces secrets)")
+
# ββ Helpers βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 290 |
+
|
| 291 |
+
async def process_and_compare(file1_path, file2_path, weights, enable_phase2=False):
    """Run the full multi-agent comparison pipeline on two documents.

    Args:
        file1_path: Path to the first (main) document.
        file2_path: Path to the second (comparison) document.
        weights: Per-modality weights forwarded to SimilarityOrchestrator.
        enable_phase2: When True, also run the optional image/layout/metadata
            agents (each only if its import succeeded at module load).

    Returns:
        Tuple of ``(report, raw_doc1, raw_doc2)`` — the similarity report plus
        the raw ingested documents, which the Graph RAG tab reuses for
        chunking/indexing.
    """
    ingestion_agent = IngestionAgent()
    text_agent = TextAgent()
    table_agent = TableAgent()
    orchestrator = SimilarityOrchestrator(weights=weights)

    # Phase-2 agents are optional: instantiated only when both enabled AND
    # importable (the *_AVAILABLE flags come from guarded imports).
    image_agent = ImageAgent() if enable_phase2 and IMAGE_AGENT_AVAILABLE else None
    layout_agent = LayoutAgent() if enable_phase2 and LAYOUT_AGENT_AVAILABLE else None
    meta_agent = MetaAgent() if enable_phase2 and META_AGENT_AVAILABLE else None

    progress_bar = st.progress(0)
    status_text = st.empty()

    status_text.text("⏳ Ingesting documents...")
    progress_bar.progress(10)
    raw_doc1 = await ingestion_agent.process(file1_path)
    raw_doc2 = await ingestion_agent.process(file2_path)
    progress_bar.progress(15)

    status_text.text("⏳ Extracting text…")
    text_chunks1, text_embeddings1 = await text_agent.process(raw_doc1)
    text_chunks2, text_embeddings2 = await text_agent.process(raw_doc2)
    progress_bar.progress(30)

    status_text.text("⏳ Extracting tables…")
    tables1, table_embeddings1 = await table_agent.process(raw_doc1)
    tables2, table_embeddings2 = await table_agent.process(raw_doc2)
    progress_bar.progress(45)

    # BUG FIX: the original wrote `images1 = images2 = ... = []`, binding all
    # four names to ONE shared list object — a latent aliasing hazard. Use
    # independent empty defaults instead.
    images1, image_embeddings1 = [], []
    images2, image_embeddings2 = [], []
    if image_agent:
        status_text.text("⏳ Extracting images…")
        try:
            images1, image_embeddings1 = await image_agent.process(raw_doc1)
            images2, image_embeddings2 = await image_agent.process(raw_doc2)
        except Exception as e:
            # Best-effort modality: warn and fall back to empty results.
            st.warning(f"Image extraction failed: {e}")
    progress_bar.progress(60)

    layout1 = layout2 = None
    if layout_agent:
        status_text.text("⏳ Analysing layout…")
        try:
            layout1 = await layout_agent.process(raw_doc1)
            layout2 = await layout_agent.process(raw_doc2)
        except Exception as e:
            st.warning(f"Layout analysis failed: {e}")
    progress_bar.progress(70)

    metadata1 = metadata2 = None
    if meta_agent:
        status_text.text("⏳ Extracting metadata…")
        try:
            metadata1 = await meta_agent.process(raw_doc1)
            metadata2 = await meta_agent.process(raw_doc2)
        except Exception as e:
            st.warning(f"Metadata extraction failed: {e}")
    progress_bar.progress(80)

    processed_doc1 = ProcessedDocument(
        filename=raw_doc1.filename, text_chunks=text_chunks1, tables=tables1,
        total_pages=raw_doc1.total_pages, file_type=raw_doc1.file_type,
        images=images1, layout=layout1, metadata=metadata1
    )
    processed_doc2 = ProcessedDocument(
        filename=raw_doc2.filename, text_chunks=text_chunks2, tables=tables2,
        total_pages=raw_doc2.total_pages, file_type=raw_doc2.file_type,
        images=images2, layout=layout2, metadata=metadata2
    )

    status_text.text("⏳ Comparing documents…")
    report = await orchestrator.compare_documents(
        processed_doc1, text_embeddings1, table_embeddings1,
        processed_doc2, text_embeddings2, table_embeddings2,
        image_embeddings1, image_embeddings2,
        layout1, layout2, metadata1, metadata2
    )

    progress_bar.progress(100)
    status_text.text("✅ Comparison complete!")

    # Return raw docs alongside the report — the Graph RAG tab ingests them.
    return report, raw_doc1, raw_doc2
| 376 |
def display_results(report):
    """Render the comparison report in the main panel.

    Shows, in order: the overall-similarity gauge with its legend, the
    per-modality breakdown chart and metric row, the top matched sections,
    optional Phase-2 (image/layout/metadata) details, and a JSON download
    button for the full report.
    """
    st.markdown("---")
    st.header("📊 Comparison Results")

    gauge_col, legend_col = st.columns([1, 1])
    with gauge_col:
        st.plotly_chart(
            create_similarity_gauge(report.overall_score),
            use_container_width=True,
        )
    with legend_col:
        st.markdown(create_score_legend())

    st.markdown("---")
    st.subheader("📈 Per-Modality Breakdown")
    st.plotly_chart(
        create_modality_breakdown_chart(report),
        use_container_width=True,
    )

    # One metric column per modality; each row is (label, score object,
    # key of the item-count in that score's details dict).
    metric_cols = st.columns(5)
    modality_rows = [
        ("Text Similarity", report.text_score, "num_matches"),
        ("Table Similarity", report.table_score, "num_matches"),
        ("Image Similarity", report.image_score, "num_matches"),
        ("Layout Similarity", report.layout_score, "num_metrics"),
        ("Metadata Similarity", report.metadata_score, "num_fields_compared"),
    ]
    for column, (label, modality, detail_key) in zip(metric_cols, modality_rows):
        if modality:  # skip modalities that did not run
            column.metric(
                label,
                f"{modality.score:.1%}",
                f"{modality.details.get(detail_key, 0)} items",
            )

    st.markdown("---")
    st.subheader("🔍 Top Matched Sections")
    if report.matched_sections:
        st.markdown(format_matched_sections(report.matched_sections[:10]))
    else:
        st.info("No significant matches found.")

    # Phase-2 detail expanders appear only when at least one such modality ran.
    if report.image_score or report.layout_score or report.metadata_score:
        st.markdown("---")
        st.subheader("🎨 Phase 2 Modality Details")

        if report.image_score and report.image_score.matched_items:
            with st.expander(f"🖼️ Image Matches ({len(report.image_score.matched_items)})"):
                for idx, m in enumerate(report.image_score.matched_items[:5], 1):
                    st.markdown(f"**Match {idx}** → {m['similarity']:.2%}")

        if report.layout_score:
            with st.expander(f"📐 Layout (Score: {report.layout_score.score:.1%})"):
                for k, v in report.layout_score.details.items():
                    if k != "num_metrics":  # count field, not a metric value
                        st.metric(k.replace("_", " ").title(), f"{v:.2%}")

        if report.metadata_score and report.metadata_score.matched_items:
            with st.expander(f"📋 Metadata ({len(report.metadata_score.matched_items)} fields)"):
                for m in report.metadata_score.matched_items:
                    st.markdown(f"**{m['field'].title()}** → {m['similarity']:.2%}")

    st.markdown("---")
    # default=str covers non-JSON-native values such as datetimes.
    report_json = json.dumps(report.model_dump(), indent=2, default=str)
    st.download_button(
        "📥 Download Report (JSON)", data=report_json,
        file_name=f"similarity_report_{report.timestamp.strftime('%Y%m%d_%H%M%S')}.json",
        mime="application/json"
    )
| 436 |
|
| 437 |
|
| 438 |
if __name__ == "__main__":
|
|
Binary files a/src/utils/__pycache__/visualization.cpython-313.pyc and b/src/utils/__pycache__/visualization.cpython-313.pyc differ
|
|
|