syedmohaiminulhoque committed on
Commit
fa9ea37
·
1 Parent(s): c3c7335

feat: Implement Graph RAG pipeline with chunking, vector storage, and graph building

Browse files

- Added `rag` module with core components:
- `chunker.py`: Implements semantic chunking of documents.
- `vector_store.py`: Integrates ChromaDB for storing and retrieving document chunks.
- `graph_builder.py`: Constructs a knowledge graph from document chunks, establishing relationships based on similarity and section headings.
- `groq_chat.py`: Facilitates chat interactions using Groq API with context from the knowledge graph.
- `rag_pipeline.py`: Orchestrates the entire RAG process, from ingestion to querying.
- Introduced `PipelineState` to manage the state of the RAG pipeline.
- Enhanced document processing with robust text extraction and chunking strategies.
- Added support for entity linking and cross-document similarity in the graph.
- Integrated debug utilities for inspecting raw document attributes.

requirements.txt CHANGED
@@ -25,6 +25,15 @@ numpy>=1.26.0
25
  pandas>=2.2.0
26
  Pillow>=10.2.0
27
 
 
 
 
 
 
 
 
 
 
28
  # Utilities
29
  python-dotenv>=1.0.0
30
 
 
25
  pandas>=2.2.0
26
  Pillow>=10.2.0
27
 
28
+ # Vector DB (Graph RAG)
29
+ chromadb>=0.5.0
30
+
31
+ # Knowledge Graph (Graph RAG)
32
+ networkx>=3.2.0
33
+
34
+ # Groq API (Graph RAG Chat)
35
+ groq>=0.9.0
36
+
37
  # Utilities
38
  python-dotenv>=1.0.0
39
 
src/rag/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .chunker import Chunk, chunk_document, chunk_text
2
+ from .vector_store import VectorStore
3
+ from .graph_builder import GraphBuilder
4
+ from .groq_chat import GroqGraphChat
5
+
6
+ __all__ = [
7
+ "Chunk", "chunk_document", "chunk_text",
8
+ "VectorStore",
9
+ "GraphBuilder",
10
+ "GroqGraphChat",
11
+ ]
12
+
src/rag/chunker.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Smart Semantic Chunker
3
+ Chunks documents efficiently using sentence boundaries + structural signals.
4
+ """
5
+ import re
6
+ from typing import List, Dict, Any
7
+ from dataclasses import dataclass, field
8
+
9
+
10
@dataclass
class Chunk:
    """A single chunk of document text plus its provenance metadata."""
    chunk_id: str            # unique id, formatted as "{doc_id}_chunk_{chunk_index}"
    doc_id: str              # "doc1" or "doc2"
    text: str                # the chunk's text content
    chunk_index: int         # 0-based position of this chunk within its document
    section: str = ""        # heading/section title if detected
    page: int = 0            # source page number (0 when unknown)
    metadata: Dict[str, Any] = field(default_factory=dict)  # extra info, e.g. "word_count"
19
+
20
+
21
+ def _split_sentences(text: str) -> List[str]:
22
+ """Split text into sentences using regex."""
23
+ sentences = re.split(r'(?<=[.!?])\s+(?=[A-Z])', text.strip())
24
+ return [s.strip() for s in sentences if s.strip()]
25
+
26
+
27
+ def _detect_heading(line: str) -> bool:
28
+ """Detect if a line looks like a section heading."""
29
+ line = line.strip()
30
+ if not line:
31
+ return False
32
+ if re.match(r'^(\d+[\.\)]\s+|[A-Z][A-Z\s]{3,50}$)', line):
33
+ return True
34
+ if len(line) < 80 and not line.endswith('.') and line[0].isupper():
35
+ if re.match(r'^(Abstract|Introduction|Conclusion|Method|Result|Discussion|Background|Overview|Summary)', line, re.I):
36
+ return True
37
+ return False
38
+
39
+
40
def chunk_text(
    text: str,
    doc_id: str,
    chunk_size: int = 300,
    overlap: int = 50,
) -> List[Chunk]:
    """
    Semantic chunking with section awareness, sentence boundary respect,
    and sliding window overlap.

    Args:
        text: Raw document text; scanned line by line for headings/paragraphs.
        doc_id: Identifier stamped onto every produced Chunk.
        chunk_size: Target chunk size in words; the sentence buffer is flushed
            once it reaches this many words.
        overlap: Maximum number of words of trailing sentences carried over
            into the next chunk for continuity.

    Returns:
        List of Chunk objects in document order.
    """
    chunks = []
    lines = text.split('\n')

    current_section = "General"
    buffer_sentences = []   # sentences accumulated for the chunk being built
    buffer_words = 0        # running word count of buffer_sentences
    chunk_index = 0

    def flush_buffer(section: str) -> None:
        # Emit the buffered sentences as one Chunk, then re-seed the buffer
        # with the trailing sentences (<= `overlap` words) as the sliding window.
        nonlocal chunk_index, buffer_sentences, buffer_words
        if not buffer_sentences:
            return
        chunk_text_val = ' '.join(buffer_sentences)
        chunks.append(Chunk(
            chunk_id=f"{doc_id}_chunk_{chunk_index}",
            doc_id=doc_id,
            text=chunk_text_val,
            chunk_index=chunk_index,
            section=section,
            metadata={"word_count": buffer_words}
        ))
        chunk_index += 1
        # Build the overlap window backwards from the end of the flushed
        # buffer, keeping whole sentences only.
        overlap_sentences = []
        overlap_words = 0
        for sent in reversed(buffer_sentences):
            w = len(sent.split())
            if overlap_words + w <= overlap:
                overlap_sentences.insert(0, sent)
                overlap_words += w
            else:
                break
        buffer_sentences = overlap_sentences
        buffer_words = overlap_words

    paragraph_buffer = []  # lines of the paragraph currently being collected

    for line in lines:
        stripped = line.strip()

        if _detect_heading(stripped):
            # Finish the paragraph in progress before switching sections.
            if paragraph_buffer:
                full_text = ' '.join(paragraph_buffer)
                sentences = _split_sentences(full_text)
                for sent in sentences:
                    buffer_sentences.append(sent)
                    buffer_words += len(sent.split())
                    if buffer_words >= chunk_size:
                        flush_buffer(current_section)
                paragraph_buffer = []
            # Flush remaining sentences under the *old* section title, then
            # adopt the heading as the new section.
            flush_buffer(current_section)
            current_section = stripped
            continue

        if stripped:
            paragraph_buffer.append(stripped)
        else:
            # Blank line ends a paragraph: fold it into the sentence buffer.
            if paragraph_buffer:
                full_text = ' '.join(paragraph_buffer)
                sentences = _split_sentences(full_text)
                for sent in sentences:
                    buffer_sentences.append(sent)
                    buffer_words += len(sent.split())
                    if buffer_words >= chunk_size:
                        flush_buffer(current_section)
                paragraph_buffer = []

    # Trailing paragraph / buffered sentences at end of document.
    if paragraph_buffer:
        full_text = ' '.join(paragraph_buffer)
        sentences = _split_sentences(full_text)
        for sent in sentences:
            buffer_sentences.append(sent)
            buffer_words += len(sent.split())
    flush_buffer(current_section)

    return chunks
125
+
126
+
127
+ # ── Debug helper ──────────────────────────────────────────────────────────────
128
+
129
def debug_raw_doc(raw_doc) -> str:
    """Summarise every attribute of *raw_doc* as a printable string (for debugging)."""
    report = [f"Type: {type(raw_doc).__name__}"]
    try:
        attrs = raw_doc.model_dump() if hasattr(raw_doc, 'model_dump') else vars(raw_doc)
        for name, value in attrs.items():
            if isinstance(value, str):
                report.append(f" str attr '{name}': len={len(value)} preview={repr(value[:80])}")
            elif isinstance(value, list):
                report.append(f" list attr '{name}': len={len(value)}")
            else:
                report.append(f" attr '{name}': {type(value).__name__} = {repr(str(value)[:60])}")
    except Exception as exc:
        report.append(f" (could not introspect: {exc})")
    return '\n'.join(report)
144
+
145
+
146
+ # ── Robust text extraction ────────────────────────────────────────────────────
147
+
148
def extract_text_from_raw_doc(raw_doc) -> str:
    """
    Robustly extract text from whatever RawDocument the ingestion agent returns.
    Tries all known attribute names and fallback strategies, in order of
    preference, and returns "" when nothing usable is found.
    """
    direct_attrs = ['text_content', 'content', 'text', 'raw_text', 'full_text', 'body',
                    'extracted_text', 'plain_text', 'document_text']
    list_attrs = ['pages', 'sections', 'chunks', 'paragraphs', 'text_chunks']

    # Strategy 1: common direct string attributes.
    for name in direct_attrs:
        candidate = getattr(raw_doc, name, None)
        if candidate and isinstance(candidate, str) and len(candidate.strip()) > 10:
            return candidate.strip()

    # Strategy 2: a list of pages / sections whose items carry the text.
    for name in list_attrs:
        seq = getattr(raw_doc, name, None)
        if not (seq and isinstance(seq, list)):
            continue
        parts = []
        for item in seq:
            if isinstance(item, str):
                parts.append(item)
            elif hasattr(item, 'text') and isinstance(item.text, str):
                parts.append(item.text)
            elif hasattr(item, 'content') and isinstance(item.content, str):
                parts.append(item.content)
            elif isinstance(item, dict):
                parts.append(str(item.get('text') or item.get('content') or ''))
        combined = '\n'.join(p for p in parts if p.strip())
        if len(combined.strip()) > 10:
            return combined.strip()

    # Strategy 3: pydantic model_dump / __dict__ — preferred keys first,
    # then the single longest string field.
    try:
        dumped = raw_doc.model_dump() if hasattr(raw_doc, 'model_dump') else vars(raw_doc)
        for key in ['text_content', 'content', 'text', 'raw_text', 'full_text', 'body']:
            if key in dumped and isinstance(dumped[key], str) and len(dumped[key].strip()) > 10:
                return dumped[key].strip()
        longest = max(
            ((k, v) for k, v in dumped.items() if isinstance(v, str)),
            key=lambda kv: len(kv[1]),
            default=(None, ''),
        )
        if len(longest[1]) > 100:
            return longest[1].strip()
    except Exception:
        pass

    # Strategy 4: plain str() fallback — rejected when it looks like a
    # default object repr ("<...>").
    rendered = str(raw_doc)
    if len(rendered) > 50 and not rendered.startswith('<'):
        return rendered

    return ""
202
+
203
+
204
def chunk_document(raw_doc, doc_id: str, chunk_size: int = 300, overlap: int = 50) -> List[Chunk]:
    """
    Chunk a RawDocument object from the ingestion agent.
    Robustly handles any attribute structure; always returns at least one Chunk.
    """
    text = extract_text_from_raw_doc(raw_doc)

    # No text at all: emit a single diagnostic chunk instead of failing.
    if not text:
        return [Chunk(
            chunk_id=f"{doc_id}_chunk_0",
            doc_id=doc_id,
            text=f"[Could not extract text from {doc_id}. Attributes: {debug_raw_doc(raw_doc)[:200]}]",
            chunk_index=0,
            section="Error",
        )]

    produced = chunk_text(text, doc_id, chunk_size, overlap)
    if produced:
        return produced

    # Chunker produced nothing (e.g. extremely short text): fall back to a
    # single chunk holding the first 500 characters.
    return [Chunk(
        chunk_id=f"{doc_id}_chunk_0",
        doc_id=doc_id,
        text=text[:500],
        chunk_index=0,
        section="General",
    )]
src/rag/graph_builder.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Graph RAG β€” Knowledge Graph Builder
3
+ Builds a NetworkX graph where:
4
+ - Nodes = chunks (from doc1 & doc2)
5
+ - Edges = relationships between chunks:
6
+ * sequential : consecutive chunks in same document
7
+ * same_section : chunks sharing the same heading/section
8
+ * cross_similar: high cosine similarity between doc1 chunk & doc2 chunk
9
+ * entity_link : chunks sharing important noun phrases (entities)
10
+ """
11
+ import re
12
+ import networkx as nx
13
+ from typing import List, Dict, Any, Tuple
14
+ from sentence_transformers import SentenceTransformer
15
+ import numpy as np
16
+ from sklearn.metrics.pairwise import cosine_similarity
17
+
18
+ from .chunker import Chunk
19
+
20
+
21
+ _EMBED_MODEL_NAME = "all-MiniLM-L6-v2"
22
+ _CROSS_SIM_THRESHOLD = 0.55 # min similarity to create a cross-doc edge
23
+ _ENTITY_MIN_LEN = 4 # min characters for an entity term
24
+
25
+
26
+ def _extract_noun_phrases(text: str) -> set:
27
+ """
28
+ Lightweight noun phrase extraction via regex patterns.
29
+ No spacy dependency β€” works in constrained environments.
30
+ """
31
+ # Capitalised multi-word phrases and key technical terms
32
+ patterns = [
33
+ r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)+\b', # "Neural Network", "New York"
34
+ r'\b[A-Z]{2,}\b', # acronyms: "RAG", "LLM"
35
+ r'\b\w{5,}\b', # any long word (catch technical terms)
36
+ ]
37
+ entities = set()
38
+ for pat in patterns:
39
+ found = re.findall(pat, text)
40
+ entities.update(f.strip().lower() for f in found if len(f) >= _ENTITY_MIN_LEN)
41
+ # Remove very common stopwords
42
+ stopwords = {'which', 'these', 'those', 'their', 'there', 'where', 'about',
43
+ 'would', 'could', 'should', 'other', 'being', 'using', 'having'}
44
+ return entities - stopwords
45
+
46
+
47
class GraphBuilder:
    """
    Builds and queries a knowledge graph from doc chunks.

    Nodes are chunk ids; edges carry a `relation` label (sequential /
    same_section / cross_similar / entity_link) and a `weight` that
    `retrieve` uses to rank graph expansion.
    """

    def __init__(self):
        # NOTE(review): SentenceTransformer may download the model on first
        # construction — presumably acceptable for this app; confirm.
        self._model = SentenceTransformer(_EMBED_MODEL_NAME)
        self.graph: nx.Graph = nx.Graph()
        self._chunk_map: Dict[str, Chunk] = {}  # chunk_id -> Chunk

    # ------------------------------------------------------------------
    # Build
    # ------------------------------------------------------------------

    def build(self, doc1_chunks: List[Chunk], doc2_chunks: List[Chunk]) -> nx.Graph:
        """
        Full graph construction pipeline.
        Resets any previous graph, then adds nodes and the four edge types.
        Returns the built NetworkX graph.
        """
        self.graph = nx.Graph()
        self._chunk_map = {}

        all_chunks = doc1_chunks + doc2_chunks

        # 1. Add nodes (one per chunk, with a 200-char text snippet).
        for chunk in all_chunks:
            self._chunk_map[chunk.chunk_id] = chunk
            self.graph.add_node(
                chunk.chunk_id,
                text=chunk.text[:200],  # store snippet
                doc_id=chunk.doc_id,
                section=chunk.section,
                chunk_index=chunk.chunk_index,
                entities=list(_extract_noun_phrases(chunk.text)),
            )

        # 2. Sequential edges (within same doc)
        self._add_sequential_edges(doc1_chunks)
        self._add_sequential_edges(doc2_chunks)

        # 3. Same-section edges
        self._add_section_edges(all_chunks)

        # 4. Cross-document similarity edges
        self._add_cross_similarity_edges(doc1_chunks, doc2_chunks)

        # 5. Entity co-occurrence edges
        self._add_entity_edges(all_chunks)

        return self.graph

    def _add_sequential_edges(self, chunks: List[Chunk]) -> None:
        """Connect consecutive chunks of one document in reading order."""
        sorted_chunks = sorted(chunks, key=lambda c: c.chunk_index)
        for i in range(len(sorted_chunks) - 1):
            a, b = sorted_chunks[i], sorted_chunks[i + 1]
            self.graph.add_edge(
                a.chunk_id, b.chunk_id,
                relation="sequential",
                weight=0.9,
            )

    def _add_section_edges(self, chunks: List[Chunk]) -> None:
        """Fully connect chunks that share the same (doc_id, section) pair."""
        section_map: Dict[str, List[str]] = {}
        for chunk in chunks:
            key = f"{chunk.doc_id}::{chunk.section}"
            section_map.setdefault(key, []).append(chunk.chunk_id)

        for ids in section_map.values():
            for i in range(len(ids)):
                for j in range(i + 1, len(ids)):
                    # Don't overwrite an existing (e.g. sequential) edge.
                    if not self.graph.has_edge(ids[i], ids[j]):
                        self.graph.add_edge(
                            ids[i], ids[j],
                            relation="same_section",
                            weight=0.6,
                        )

    def _add_cross_similarity_edges(
        self, doc1_chunks: List[Chunk], doc2_chunks: List[Chunk]
    ) -> None:
        """Link doc1/doc2 chunk pairs whose embedding cosine similarity clears the threshold."""
        if not doc1_chunks or not doc2_chunks:
            return

        texts1 = [c.text for c in doc1_chunks]
        texts2 = [c.text for c in doc2_chunks]

        emb1 = self._model.encode(texts1, batch_size=32, show_progress_bar=False)
        emb2 = self._model.encode(texts2, batch_size=32, show_progress_bar=False)

        sim_matrix = cosine_similarity(emb1, emb2)

        for i, c1 in enumerate(doc1_chunks):
            for j, c2 in enumerate(doc2_chunks):
                sim = float(sim_matrix[i, j])
                if sim >= _CROSS_SIM_THRESHOLD:
                    self.graph.add_edge(
                        c1.chunk_id, c2.chunk_id,
                        relation="cross_similar",
                        weight=round(sim, 4),
                        similarity=round(sim, 4),
                    )

    def _add_entity_edges(self, chunks: List[Chunk]) -> None:
        """Link chunks from different documents that mention the same entity."""
        entity_to_chunks: Dict[str, List[str]] = {}
        for chunk in chunks:
            entities = _extract_noun_phrases(chunk.text)
            for ent in entities:
                entity_to_chunks.setdefault(ent, []).append(chunk.chunk_id)

        for ent, ids in entity_to_chunks.items():
            if len(ids) < 2:
                continue
            # Only connect cross-doc pairs to avoid too many same-doc entity edges
            # NOTE: the dict keeps one representative chunk per doc_id
            # (the last one seen for that entity).
            doc_ids = {self._chunk_map[cid].doc_id: cid for cid in ids}
            if len(doc_ids) >= 2:
                cids = list(doc_ids.values())
                for i in range(len(cids)):
                    for j in range(i + 1, len(cids)):
                        if not self.graph.has_edge(cids[i], cids[j]):
                            self.graph.add_edge(
                                cids[i], cids[j],
                                relation="entity_link",
                                entity=ent,
                                weight=0.5,
                            )

    # ------------------------------------------------------------------
    # Query
    # ------------------------------------------------------------------

    def retrieve(
        self,
        query: str,
        seed_chunks: List[Dict[str, Any]],  # from VectorStore.search()
        hops: int = 2,
        max_nodes: int = 10,
    ) -> List[Dict[str, Any]]:
        """
        Graph-aware retrieval:
        1. Start from seed chunk nodes (vector search results)
        2. Expand via BFS up to `hops` hops, prioritising high-weight edges
        3. Return unique chunks from both docs, ranked by relevance

        Relevance of an expanded node is the product of edge weights along
        the path that reached it (seeds start at 1.0).
        """
        visited = set()
        result_nodes = []

        # Reconstruct node ids from the vector-store hit metadata.
        seed_ids = [
            f"{s['doc_id']}_chunk_{s['chunk_index']}"
            for s in seed_chunks
            if s.get('chunk_index') is not None
        ]

        # BFS queue: (node_id, remaining_hops, accumulated_weight)
        queue = [(nid, hops, 1.0) for nid in seed_ids if nid in self.graph]

        while queue and len(result_nodes) < max_nodes:
            node_id, remaining, acc_weight = queue.pop(0)
            if node_id in visited:
                continue
            visited.add(node_id)

            chunk = self._chunk_map.get(node_id)
            if chunk:
                result_nodes.append({
                    "chunk_id": node_id,
                    "text": chunk.text,
                    "doc_id": chunk.doc_id,
                    "section": chunk.section,
                    "relevance": round(acc_weight, 4),
                })

            if remaining > 0:
                # Expand through the strongest edges first.
                neighbors = sorted(
                    self.graph[node_id].items(),
                    key=lambda x: x[1].get("weight", 0),
                    reverse=True,
                )
                for neighbor_id, edge_data in neighbors[:4]:  # top-4 neighbours
                    if neighbor_id not in visited:
                        queue.append((
                            neighbor_id,
                            remaining - 1,
                            acc_weight * edge_data.get("weight", 0.5),
                        ))

        # Sort by relevance
        result_nodes.sort(key=lambda x: x["relevance"], reverse=True)
        return result_nodes[:max_nodes]

    def get_stats(self) -> Dict[str, Any]:
        """Return node/edge counts plus a per-relation edge histogram."""
        edge_types = {}
        for _, _, data in self.graph.edges(data=True):
            rel = data.get("relation", "unknown")
            edge_types[rel] = edge_types.get(rel, 0) + 1
        return {
            "nodes": self.graph.number_of_nodes(),
            "edges": self.graph.number_of_edges(),
            "edge_types": edge_types,
        }
+ }
src/rag/groq_chat.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Groq Chat with Graph RAG context injection.
3
+ Uses llama-3.3-70b-versatile (fast + smart) via Groq API.
4
+ """
5
+ import os
6
+ from typing import List, Dict, Any, Generator
7
+ from groq import Groq
8
+
9
+
10
+ _DEFAULT_MODEL = "llama-3.3-70b-versatile"
11
+ _MAX_CONTEXT_CHARS = 6000 # stay within context window safely
12
+
13
+
14
+ def _build_context(retrieved_nodes: List[Dict[str, Any]]) -> str:
15
+ """
16
+ Format retrieved graph nodes into a clean context block for the LLM.
17
+ Groups by document for clarity.
18
+ """
19
+ doc1_nodes = [n for n in retrieved_nodes if n.get("doc_id") == "doc1"]
20
+ doc2_nodes = [n for n in retrieved_nodes if n.get("doc_id") == "doc2"]
21
+
22
+ parts = []
23
+
24
+ if doc1_nodes:
25
+ parts.append("### Relevant passages from Document 1:")
26
+ for node in doc1_nodes:
27
+ sec = f" [{node['section']}]" if node.get("section") else ""
28
+ parts.append(f"- {node['text'][:500]}{sec}")
29
+
30
+ if doc2_nodes:
31
+ parts.append("\n### Relevant passages from Document 2:")
32
+ for node in doc2_nodes:
33
+ sec = f" [{node['section']}]" if node.get("section") else ""
34
+ parts.append(f"- {node['text'][:500]}{sec}")
35
+
36
+ context = "\n".join(parts)
37
+ return context[:_MAX_CONTEXT_CHARS]
38
+
39
+
40
+ _SYSTEM_PROMPT = """You are an expert document analyst assistant with access to two documents that have been processed, chunked, and indexed using a Knowledge Graph RAG system.
41
+
42
+ You will be given:
43
+ 1. CONTEXT: Relevant passages retrieved from both documents via graph-enhanced semantic search
44
+ 2. USER QUESTION: What the user wants to know
45
+
46
+ Your job:
47
+ - Answer using ONLY the provided context
48
+ - Clearly indicate which document (Document 1 or Document 2) information comes from
49
+ - If comparing both documents, highlight similarities and differences
50
+ - If the context doesn't contain the answer, say so honestly
51
+ - Be concise, accurate, and helpful
52
+ """
53
+
54
+
55
+ class GroqGraphChat:
56
+ """
57
+ Stateful chat session backed by Groq API + GraphRAG context injection.
58
+ """
59
+
60
+ def __init__(self, api_key: str, model: str = _DEFAULT_MODEL):
61
+ self._client = Groq(api_key=api_key)
62
+ self._model = model
63
+ self._history: List[Dict[str, str]] = []
64
+
65
+ def reset(self) -> None:
66
+ self._history = []
67
+
68
+ def chat(
69
+ self,
70
+ user_query: str,
71
+ retrieved_nodes: List[Dict[str, Any]],
72
+ stream: bool = True,
73
+ ) -> str | Generator:
74
+ """
75
+ Send a message with GraphRAG context and get a response.
76
+
77
+ Args:
78
+ user_query: The user's question
79
+ retrieved_nodes: Chunks from GraphBuilder.retrieve()
80
+ stream: If True, returns a generator for streaming UI
81
+
82
+ Returns:
83
+ Full response string (if stream=False) or generator (if stream=True)
84
+ """
85
+ context = _build_context(retrieved_nodes)
86
+
87
+ # Build the user turn with injected context
88
+ augmented_user_message = f"""<context>
89
+ {context}
90
+ </context>
91
+
92
+ <question>
93
+ {user_query}
94
+ </question>"""
95
+
96
+ # Append to history
97
+ self._history.append({"role": "user", "content": augmented_user_message})
98
+
99
+ messages = [{"role": "system", "content": _SYSTEM_PROMPT}] + self._history
100
+
101
+ if stream:
102
+ return self._stream_response(messages)
103
+ else:
104
+ return self._full_response(messages)
105
+
106
+ def _full_response(self, messages: List[Dict]) -> str:
107
+ response = self._client.chat.completions.create(
108
+ model=self._model,
109
+ messages=messages,
110
+ max_tokens=1024,
111
+ temperature=0.3,
112
+ )
113
+ answer = response.choices[0].message.content
114
+ self._history.append({"role": "assistant", "content": answer})
115
+ return answer
116
+
117
+ def _stream_response(self, messages: List[Dict]) -> Generator:
118
+ stream = self._client.chat.completions.create(
119
+ model=self._model,
120
+ messages=messages,
121
+ max_tokens=1024,
122
+ temperature=0.3,
123
+ stream=True,
124
+ )
125
+ full_response = ""
126
+ for chunk in stream:
127
+ delta = chunk.choices[0].delta.content or ""
128
+ full_response += delta
129
+ yield delta
130
+ self._history.append({"role": "assistant", "content": full_response})
src/rag/rag_pipeline.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RAG Pipeline β€” wires everything together.
3
+ Used by the Streamlit chat tab.
4
+ """
5
+ from typing import List, Dict, Any, Optional
6
+ from dataclasses import dataclass, field
7
+
8
+ from .chunker import Chunk, chunk_document
9
+ from .vector_store import VectorStore
10
+ from .graph_builder import GraphBuilder
11
+ from .groq_chat import GroqGraphChat
12
+
13
+
14
@dataclass
class PipelineState:
    """Holds the built RAG state after ingestion."""
    doc1_chunks: List[Chunk] = field(default_factory=list)  # chunks for document 1
    doc2_chunks: List[Chunk] = field(default_factory=list)  # chunks for document 2
    vector_store: Optional[VectorStore] = None              # populated by ingest()
    graph_builder: Optional[GraphBuilder] = None            # populated by ingest()
    is_ready: bool = False                                  # True once ingest() completed
    stats: Dict[str, Any] = field(default_factory=dict)     # chunk/vector/graph counts
23
+
24
+
25
class GraphRAGPipeline:
    """
    End-to-end Graph RAG pipeline: chunk -> embed -> store -> graph -> chat.

    Usage:
        pipeline = GraphRAGPipeline(groq_api_key="...")
        state = pipeline.ingest(raw_doc1, raw_doc2)
        answer = pipeline.query("What does doc1 say about climate?", state)
    """

    def __init__(
        self,
        groq_api_key: str,
        chunk_size: int = 300,
        chunk_overlap: int = 50,
        top_k_vector: int = 5,
        graph_hops: int = 2,
        graph_max_nodes: int = 10,
    ):
        """
        Args:
            groq_api_key: API key for the Groq chat backend.
            chunk_size: Target chunk size in words (passed to the chunker).
            chunk_overlap: Word overlap between consecutive chunks.
            top_k_vector: Number of vector-search seed results per query.
            graph_hops: BFS expansion depth in the knowledge graph.
            graph_max_nodes: Max chunks returned by graph retrieval.
        """
        self.groq_api_key = groq_api_key
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.top_k_vector = top_k_vector
        self.graph_hops = graph_hops
        self.graph_max_nodes = graph_max_nodes

        # Chat session; created in ingest() and lazily in query() if missing.
        self._chat: Optional[GroqGraphChat] = None

    # ------------------------------------------------------------------
    # Ingestion
    # ------------------------------------------------------------------

    def ingest(self, raw_doc1, raw_doc2) -> PipelineState:
        """
        Process both documents: chunk -> embed -> store -> build graph.
        Returns a PipelineState that should be stored in st.session_state.
        """
        state = PipelineState()

        # 1. Chunk both documents.
        state.doc1_chunks = chunk_document(
            raw_doc1, "doc1", self.chunk_size, self.chunk_overlap
        )
        state.doc2_chunks = chunk_document(
            raw_doc2, "doc2", self.chunk_size, self.chunk_overlap
        )

        # 2. Vector store (one shared collection for both docs).
        state.vector_store = VectorStore()
        state.vector_store.add_chunks(state.doc1_chunks)
        state.vector_store.add_chunks(state.doc2_chunks)

        # 3. Knowledge graph.
        state.graph_builder = GraphBuilder()
        state.graph_builder.build(state.doc1_chunks, state.doc2_chunks)

        # 4. Stats for the UI.
        graph_stats = state.graph_builder.get_stats()
        state.stats = {
            "doc1_chunks": len(state.doc1_chunks),
            "doc2_chunks": len(state.doc2_chunks),
            "total_vectors": state.vector_store.count(),
            **graph_stats,
        }
        state.is_ready = True

        # 5. Fresh chat session.
        self._chat = GroqGraphChat(api_key=self.groq_api_key)

        return state

    # ------------------------------------------------------------------
    # Query
    # ------------------------------------------------------------------

    def query(
        self,
        user_query: str,
        state: PipelineState,
        stream: bool = True,
    ):
        """
        Retrieve relevant context via vector + graph search,
        then pass to Groq for generation.

        Raises:
            RuntimeError: if `state` was not produced by a completed ingest().
        """
        if not state.is_ready:
            raise RuntimeError("Pipeline not ready. Call ingest() first.")

        # FIX: the chat session lives on this pipeline object while the
        # indexed data lives in `state`. If the pipeline was re-instantiated
        # after ingestion (e.g. an app rerun restoring `state` from session
        # storage), `_chat` would still be None and `.chat(...)` below would
        # raise AttributeError — create the session lazily instead.
        if self._chat is None:
            self._chat = GroqGraphChat(api_key=self.groq_api_key)

        # Step 1: Vector search (both docs).
        seed_chunks = state.vector_store.search(
            user_query, n_results=self.top_k_vector
        )

        # Step 2: Graph expansion from the vector-search seeds.
        retrieved_nodes = state.graph_builder.retrieve(
            query=user_query,
            seed_chunks=seed_chunks,
            hops=self.graph_hops,
            max_nodes=self.graph_max_nodes,
        )

        # Fallback: if graph expansion returned nothing, use raw vector results.
        if not retrieved_nodes:
            retrieved_nodes = [
                {
                    "chunk_id": f"{s['doc_id']}_chunk_{s['chunk_index']}",
                    "text": s["text"],
                    "doc_id": s["doc_id"],
                    "section": s.get("section", ""),
                    "relevance": s["score"],
                }
                for s in seed_chunks
            ]

        # Step 3: Generate answer via Groq.
        return self._chat.chat(
            user_query=user_query,
            retrieved_nodes=retrieved_nodes,
            stream=stream,
        )

    def reset_chat(self) -> None:
        """Clear conversation history (keep the indexed data)."""
        if self._chat:
            self._chat.reset()
src/rag/vector_store.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Vector Store using ChromaDB (in-memory, HF Spaces compatible)
3
+ Stores and retrieves chunks from both documents via semantic search.
4
+ """
5
+ import chromadb
6
+ from chromadb.config import Settings
7
+ from sentence_transformers import SentenceTransformer
8
+ from typing import List, Dict, Any, Optional
9
+ import hashlib
10
+
11
+ from .chunker import Chunk
12
+
13
+
14
_EMBED_MODEL_NAME = "all-MiniLM-L6-v2"  # fast, small, works great


class VectorStore:
    """
    Wraps ChromaDB with a SentenceTransformer embedding function.
    Collection name: 'doc_chunks' — shared for both documents.
    """

    def __init__(self, persist_dir: Optional[str] = None):
        # persist_dir: directory for on-disk persistence; None keeps the
        # collection purely in memory (HF Spaces friendly).
        self._model = SentenceTransformer(_EMBED_MODEL_NAME)

        if persist_dir:
            self._client = chromadb.PersistentClient(path=persist_dir)
        else:
            self._client = chromadb.EphemeralClient()

        # Cosine distance space, so `1 - distance` in search() is similarity.
        self._collection = self._client.get_or_create_collection(
            name="doc_chunks",
            metadata={"hnsw:space": "cosine"},
        )

    # ------------------------------------------------------------------
    # Write
    # ------------------------------------------------------------------

    def add_chunks(self, chunks: List[Chunk]) -> None:
        """Embed and upsert chunks into the collection. No-op for an empty list."""
        if not chunks:
            return

        texts = [c.text for c in chunks]
        embeddings = self._model.encode(texts, batch_size=32, show_progress_bar=False).tolist()

        ids = [c.chunk_id for c in chunks]
        # Extra chunk metadata values are stringified to keep Chroma happy.
        metadatas = [
            {
                "doc_id": c.doc_id,
                "chunk_index": c.chunk_index,
                "section": c.section,
                "page": c.page,
                **{k: str(v) for k, v in c.metadata.items()},
            }
            for c in chunks
        ]

        # upsert (not add) so re-ingesting the same chunk_ids overwrites.
        self._collection.upsert(
            ids=ids,
            embeddings=embeddings,
            documents=texts,
            metadatas=metadatas,
        )

    def clear(self) -> None:
        """Remove all chunks (useful for re-ingestion)."""
        # Drop and recreate rather than deleting entries one by one.
        self._client.delete_collection("doc_chunks")
        self._collection = self._client.get_or_create_collection(
            name="doc_chunks",
            metadata={"hnsw:space": "cosine"},
        )

    # ------------------------------------------------------------------
    # Read
    # ------------------------------------------------------------------

    def search(
        self,
        query: str,
        n_results: int = 5,
        doc_filter: Optional[str] = None,  # "doc1" | "doc2" | None
    ) -> List[Dict[str, Any]]:
        """
        Semantic search over stored chunks.
        Returns list of dicts with keys: text, doc_id, section, chunk_index, score.
        """
        query_embedding = self._model.encode([query]).tolist()

        where = {"doc_id": doc_filter} if doc_filter else None

        # Clamp n_results to the collection size (Chroma errors when asked
        # for more results than stored vectors); `or 1` guards the empty case.
        results = self._collection.query(
            query_embeddings=query_embedding,
            n_results=min(n_results, self._collection.count() or 1),
            where=where,
            include=["documents", "metadatas", "distances"],
        )

        hits = []
        # Index [0] selects results for our single query embedding.
        for text, meta, dist in zip(
            results["documents"][0],
            results["metadatas"][0],
            results["distances"][0],
        ):
            hits.append({
                "text": text,
                "doc_id": meta.get("doc_id"),
                "section": meta.get("section", ""),
                "chunk_index": meta.get("chunk_index", -1),
                "score": round(1 - dist, 4),  # cosine similarity
            })

        return hits

    def count(self) -> int:
        # Total number of stored vectors across both documents.
        return self._collection.count()

    def get_all_chunks_for_doc(self, doc_id: str) -> List[Dict[str, Any]]:
        """Retrieve all stored chunks for a given document, in chunk order."""
        results = self._collection.get(
            where={"doc_id": doc_id},
            include=["documents", "metadatas"],
        )
        items = []
        for text, meta in zip(results["documents"], results["metadatas"]):
            items.append({"text": text, **meta})
        # Sort by chunk_index
        items.sort(key=lambda x: int(x.get("chunk_index", 0)))
        return items
src/streamlit_app.py CHANGED
@@ -1,10 +1,11 @@
1
  """
2
  Multi-Agent Document Comparison Streamlit App
 
3
  """
4
  import sys
 
5
  from pathlib import Path
6
 
7
- # Add project root to Python path for imports
8
  project_root = Path(__file__).parent
9
  if str(project_root) not in sys.path:
10
  sys.path.insert(0, str(project_root))
@@ -13,7 +14,6 @@ import streamlit as st
13
  import asyncio
14
  import json
15
 
16
- # Import agents and utilities
17
  from agents.ingestion_agent import IngestionAgent
18
  from agents.text_agent import TextAgent
19
  from agents.table_agent import TableAgent
@@ -28,7 +28,10 @@ from utils.visualization import (
28
  from models.document import ProcessedDocument
29
  import config
30
 
31
- # Phase 2 imports (conditional based on availability)
 
 
 
32
  try:
33
  from agents.image_agent import ImageAgent
34
  IMAGE_AGENT_AVAILABLE = True
@@ -48,7 +51,6 @@ except ImportError:
48
  META_AGENT_AVAILABLE = False
49
 
50
 
51
- # Page configuration
52
  st.set_page_config(
53
  page_title="Multi-Agent Document Comparator",
54
  page_icon="πŸ“„",
@@ -58,448 +60,379 @@ st.set_page_config(
58
 
59
 
60
  def main():
61
- """Main application function."""
62
-
63
- # Header
64
- st.title("πŸ“„ Multi-Agent Document Comparator")
65
- st.markdown("**An agentic system to accurately match document similarity**")
66
 
67
- # Show architecture diagram
68
  with st.expander("πŸ—οΈ View System Architecture", expanded=False):
69
  arch_path = Path("src/img/multi_agent_doc_similarity_architecture.svg")
70
  if arch_path.exists():
71
  st.image(str(arch_path), use_container_width=True)
72
- else:
73
- st.info("Architecture diagram not found")
74
 
75
  st.markdown("---")
76
 
77
- # Sidebar configuration
78
  with st.sidebar:
79
  st.header("βš™οΈ Configuration")
80
 
81
- # Phase 2 feature toggles
82
  st.subheader("Phase 2 Features")
83
  enable_phase2 = st.checkbox(
84
  "Enable Phase 2 Modalities",
85
  value=config.ENABLE_IMAGE_COMPARISON,
86
  help="Enable image, layout, and metadata comparison"
87
  )
88
-
89
- # Modality weights
90
  st.markdown("---")
91
  st.subheader("Modality Weights")
92
 
93
  if enable_phase2:
94
- # Phase 2: All 5 modalities
95
- text_weight = st.slider(
96
- "Text Weight",
97
- min_value=0.0,
98
- max_value=1.0,
99
- value=config.MODALITY_WEIGHTS["text"],
100
- step=0.05
101
- )
102
- table_weight = st.slider(
103
- "Table Weight",
104
- min_value=0.0,
105
- max_value=1.0,
106
- value=config.MODALITY_WEIGHTS["table"],
107
- step=0.05
108
- )
109
- image_weight = st.slider(
110
- "Image Weight",
111
- min_value=0.0,
112
- max_value=1.0,
113
- value=config.MODALITY_WEIGHTS["image"],
114
- step=0.05
115
- )
116
- layout_weight = st.slider(
117
- "Layout Weight",
118
- min_value=0.0,
119
- max_value=1.0,
120
- value=config.MODALITY_WEIGHTS["layout"],
121
- step=0.05
122
- )
123
- metadata_weight = st.slider(
124
- "Metadata Weight",
125
- min_value=0.0,
126
- max_value=1.0,
127
- value=config.MODALITY_WEIGHTS["metadata"],
128
- step=0.05
129
- )
130
-
131
- # Normalize weights to sum to 1.0
132
- total_weight = text_weight + table_weight + image_weight + layout_weight + metadata_weight
133
- if total_weight > 0:
134
  weights = {
135
- "text": text_weight / total_weight,
136
- "table": table_weight / total_weight,
137
- "image": image_weight / total_weight,
138
- "layout": layout_weight / total_weight,
139
- "metadata": metadata_weight / total_weight
140
  }
141
  else:
142
  weights = config.MODALITY_WEIGHTS
143
-
144
- st.info(f"Weights normalized to sum to 1.0")
145
-
146
  else:
147
- # Phase 1: Only text and tables
148
- text_weight = st.slider(
149
- "Text Weight",
150
- min_value=0.0,
151
- max_value=1.0,
152
- value=config.MODALITY_WEIGHTS_PHASE1["text"],
153
- step=0.05
154
- )
155
  table_weight = 1.0 - text_weight
156
  st.write(f"Table Weight: {table_weight:.2f}")
157
-
158
  weights = {"text": text_weight, "table": table_weight}
159
 
160
- # Phase status
161
  st.markdown("---")
162
- st.subheader("πŸ“‹ Implementation Status")
163
  st.write("βœ… Text comparison")
164
  st.write("βœ… Table comparison")
165
-
166
  if enable_phase2:
167
- st.write(f"{'βœ…' if IMAGE_AGENT_AVAILABLE else '⚠️'} Image comparison")
168
  st.write(f"{'βœ…' if LAYOUT_AGENT_AVAILABLE else '⚠️'} Layout comparison")
169
- st.write(f"{'βœ…' if META_AGENT_AVAILABLE else '⚠️'} Metadata comparison")
170
  else:
171
- st.write("⏸️ Image comparison (disabled)")
172
- st.write("⏸️ Layout comparison (disabled)")
173
- st.write("⏸️ Metadata comparison (disabled)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
 
175
- # Main content area
176
- col1, col2 = st.columns(2)
 
 
 
 
177
 
178
- with col1:
179
- st.subheader("πŸ“€ Document 1 (Main)")
180
- uploaded_file1 = st.file_uploader(
181
- "Upload PDF or DOCX",
182
- type=["pdf", "docx"],
183
- key="file1",
184
- help="Maximum file size: 50MB"
185
- )
186
 
187
- with col2:
188
- st.subheader("πŸ“€ Document 2 (Comparison)")
189
- uploaded_file2 = st.file_uploader(
190
- "Upload PDF or DOCX",
191
- type=["pdf", "docx"],
192
- key="file2",
193
- help="Maximum file size: 50MB"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  )
195
 
196
- # Compare button
197
- st.markdown("---")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
- if st.button("πŸ” Compare Documents", type="primary", use_container_width=True):
200
- if not uploaded_file1 or not uploaded_file2:
201
- st.error("Please upload both documents before comparing.")
202
- return
203
-
204
- # Process documents and compare
205
- with st.spinner("Processing documents..."):
206
- try:
207
- # Save uploaded files
208
- file1_path = save_uploaded_file(uploaded_file1)
209
- file2_path = save_uploaded_file(uploaded_file2)
210
-
211
- # Validate files
212
- valid1, error1 = validate_file(file1_path)
213
- valid2, error2 = validate_file(file2_path)
214
-
215
- if not valid1:
216
- st.error(f"Document 1 error: {error1}")
217
- return
218
- if not valid2:
219
- st.error(f"Document 2 error: {error2}")
220
- return
221
-
222
- # Process documents
223
- report = asyncio.run(process_and_compare(
224
- file1_path,
225
- file2_path,
226
- weights,
227
- enable_phase2
228
- ))
229
-
230
- # Display results
231
- display_results(report)
232
-
233
- except Exception as e:
234
- st.error(f"An error occurred: {str(e)}")
235
- import traceback
236
- st.code(traceback.format_exc())
237
-
238
-
239
- async def process_and_compare(file1_path: str, file2_path: str, weights: dict, enable_phase2: bool = False):
240
- """
241
- Process two documents and compare them.
242
-
243
- Args:
244
- file1_path: Path to first document
245
- file2_path: Path to second document
246
- weights: Modality weights
247
- enable_phase2: Enable Phase 2 modalities (image, layout, metadata)
248
-
249
- Returns:
250
- SimilarityReport
251
- """
252
- # Initialize agents
253
  ingestion_agent = IngestionAgent()
254
- text_agent = TextAgent()
255
- table_agent = TableAgent()
256
- orchestrator = SimilarityOrchestrator(weights=weights)
257
 
258
- # Phase 2 agents (conditional)
259
- image_agent = ImageAgent() if enable_phase2 and IMAGE_AGENT_AVAILABLE else None
260
  layout_agent = LayoutAgent() if enable_phase2 and LAYOUT_AGENT_AVAILABLE else None
261
- meta_agent = MetaAgent() if enable_phase2 and META_AGENT_AVAILABLE else None
262
 
263
- # Progress tracking
264
  progress_bar = st.progress(0)
265
- status_text = st.empty()
266
 
267
- # Step 1: Ingest documents
268
  status_text.text("⏳ Ingesting documents...")
269
  progress_bar.progress(10)
270
-
271
  raw_doc1 = await ingestion_agent.process(file1_path)
272
  raw_doc2 = await ingestion_agent.process(file2_path)
273
-
274
  progress_bar.progress(15)
275
 
276
- # Step 2: Extract text
277
- status_text.text("⏳ Extracting and embedding text...")
278
-
279
  text_chunks1, text_embeddings1 = await text_agent.process(raw_doc1)
280
  text_chunks2, text_embeddings2 = await text_agent.process(raw_doc2)
281
-
282
  progress_bar.progress(30)
283
 
284
- # Step 3: Extract tables
285
- status_text.text("⏳ Extracting and embedding tables...")
286
-
287
  tables1, table_embeddings1 = await table_agent.process(raw_doc1)
288
  tables2, table_embeddings2 = await table_agent.process(raw_doc2)
289
-
290
  progress_bar.progress(45)
291
 
292
- # Phase 2: Extract images
293
- images1, image_embeddings1 = [], None
294
- images2, image_embeddings2 = [], None
295
  if image_agent:
296
- status_text.text("⏳ Extracting and embedding images...")
297
  try:
298
  images1, image_embeddings1 = await image_agent.process(raw_doc1)
299
  images2, image_embeddings2 = await image_agent.process(raw_doc2)
300
  except Exception as e:
301
  st.warning(f"Image extraction failed: {e}")
302
-
303
  progress_bar.progress(60)
304
 
305
- # Phase 2: Extract layout
306
- layout1, layout2 = None, None
307
  if layout_agent:
308
- status_text.text("⏳ Analyzing document structure...")
309
  try:
310
  layout1 = await layout_agent.process(raw_doc1)
311
  layout2 = await layout_agent.process(raw_doc2)
312
  except Exception as e:
313
  st.warning(f"Layout analysis failed: {e}")
314
-
315
  progress_bar.progress(70)
316
 
317
- # Phase 2: Extract metadata
318
- metadata1, metadata2 = None, None
319
  if meta_agent:
320
- status_text.text("⏳ Extracting metadata...")
321
  try:
322
  metadata1 = await meta_agent.process(raw_doc1)
323
  metadata2 = await meta_agent.process(raw_doc2)
324
  except Exception as e:
325
  st.warning(f"Metadata extraction failed: {e}")
326
-
327
  progress_bar.progress(80)
328
 
329
- # Create processed documents
330
  processed_doc1 = ProcessedDocument(
331
- filename=raw_doc1.filename,
332
- text_chunks=text_chunks1,
333
- tables=tables1,
334
- total_pages=raw_doc1.total_pages,
335
- file_type=raw_doc1.file_type,
336
- images=images1,
337
- layout=layout1,
338
- metadata=metadata1
339
  )
340
-
341
  processed_doc2 = ProcessedDocument(
342
- filename=raw_doc2.filename,
343
- text_chunks=text_chunks2,
344
- tables=tables2,
345
- total_pages=raw_doc2.total_pages,
346
- file_type=raw_doc2.file_type,
347
- images=images2,
348
- layout=layout2,
349
- metadata=metadata2
350
  )
351
 
352
- # Compare documents
353
- status_text.text("⏳ Comparing documents...")
354
-
355
  report = await orchestrator.compare_documents(
356
- processed_doc1,
357
- text_embeddings1,
358
- table_embeddings1,
359
- processed_doc2,
360
- text_embeddings2,
361
- table_embeddings2,
362
- # Phase 2 parameters
363
- image_embeddings1,
364
- image_embeddings2,
365
- layout1,
366
- layout2,
367
- metadata1,
368
- metadata2
369
  )
370
 
371
  progress_bar.progress(100)
372
  status_text.text("βœ… Comparison complete!")
373
 
374
- return report
 
375
 
376
 
377
  def display_results(report):
378
- """
379
- Display comparison results.
380
-
381
- Args:
382
- report: SimilarityReport object
383
- """
384
  st.markdown("---")
385
  st.header("πŸ“Š Comparison Results")
386
 
387
- # Overall similarity gauge
388
  col1, col2 = st.columns([1, 1])
389
-
390
  with col1:
391
  gauge_fig = create_similarity_gauge(report.overall_score)
392
  st.plotly_chart(gauge_fig, use_container_width=True)
393
-
394
  with col2:
395
  st.markdown(create_score_legend())
396
 
397
- # Modality breakdown
398
  st.markdown("---")
399
  st.subheader("πŸ“ˆ Per-Modality Breakdown")
400
-
401
  breakdown_fig = create_modality_breakdown_chart(report)
402
  st.plotly_chart(breakdown_fig, use_container_width=True)
403
 
404
- # Detailed scores
405
  cols = st.columns(5)
 
 
 
 
 
 
 
 
 
 
 
406
 
407
- with cols[0]:
408
- if report.text_score:
409
- st.metric(
410
- "Text Similarity",
411
- f"{report.text_score.score:.1%}",
412
- f"{report.text_score.details.get('num_matches', 0)} matches"
413
- )
414
-
415
- with cols[1]:
416
- if report.table_score:
417
- st.metric(
418
- "Table Similarity",
419
- f"{report.table_score.score:.1%}",
420
- f"{report.table_score.details.get('num_matches', 0)} matches"
421
- )
422
-
423
- with cols[2]:
424
- if report.image_score:
425
- st.metric(
426
- "Image Similarity",
427
- f"{report.image_score.score:.1%}",
428
- f"{report.image_score.details.get('num_matches', 0)} matches"
429
- )
430
-
431
- with cols[3]:
432
- if report.layout_score:
433
- st.metric(
434
- "Layout Similarity",
435
- f"{report.layout_score.score:.1%}",
436
- f"{report.layout_score.details.get('num_metrics', 0)} metrics"
437
- )
438
-
439
- with cols[4]:
440
- if report.metadata_score:
441
- st.metric(
442
- "Metadata Similarity",
443
- f"{report.metadata_score.score:.1%}",
444
- f"{report.metadata_score.details.get('num_fields_compared', 0)} fields"
445
- )
446
-
447
- # Matched sections
448
  st.markdown("---")
449
  st.subheader("πŸ”— Top Matched Sections")
450
-
451
  if report.matched_sections:
452
- formatted_sections = format_matched_sections(report.matched_sections[:10])
453
- st.markdown(formatted_sections)
454
  else:
455
- st.info("No significant matches found between documents.")
456
 
457
- # Phase 2: Additional modality details
458
  if report.image_score or report.layout_score or report.metadata_score:
459
  st.markdown("---")
460
  st.subheader("🎨 Phase 2 Modality Details")
461
-
462
- # Image matches
463
  if report.image_score and report.image_score.matched_items:
464
- with st.expander(f"πŸ–ΌοΈ Image Matches ({len(report.image_score.matched_items)} found)", expanded=False):
465
- for idx, match in enumerate(report.image_score.matched_items[:5], 1):
466
- st.markdown(f"**Match {idx}** - Similarity: {match['similarity']:.2%}")
467
- st.write(f"Doc1: Page {match['doc1_page']}, Size: {match['doc1_size']}")
468
- st.write(f"Doc2: Page {match['doc2_page']}, Size: {match['doc2_size']}")
469
- st.markdown("---")
470
-
471
- # Layout details
472
  if report.layout_score:
473
- with st.expander(f"πŸ“ Layout Analysis (Score: {report.layout_score.score:.1%})", expanded=False):
474
- for metric, value in report.layout_score.details.items():
475
- if metric != "num_metrics":
476
- st.metric(metric.replace("_", " ").title(), f"{value:.2%}")
477
-
478
- # Metadata matches
479
  if report.metadata_score and report.metadata_score.matched_items:
480
- with st.expander(f"πŸ“‹ Metadata Comparison ({len(report.metadata_score.matched_items)} fields)", expanded=False):
481
- for match in report.metadata_score.matched_items:
482
- st.markdown(f"**{match['field'].title()}** - Similarity: {match['similarity']:.2%}")
483
- col1, col2 = st.columns(2)
484
- with col1:
485
- st.write(f"Doc1: {match['doc1_value']}")
486
- with col2:
487
- st.write(f"Doc2: {match['doc2_value']}")
488
- st.markdown("---")
489
-
490
- # Download report
491
  st.markdown("---")
492
  report_json = json.dumps(report.model_dump(), indent=2, default=str)
493
-
494
- col1, col2, col3 = st.columns([1, 1, 2])
495
-
496
- with col1:
497
- st.download_button(
498
- label="πŸ“₯ Download Report (JSON)",
499
- data=report_json,
500
- file_name=f"similarity_report_{report.timestamp.strftime('%Y%m%d_%H%M%S')}.json",
501
- mime="application/json"
502
- )
503
 
504
 
505
  if __name__ == "__main__":
 
1
  """
2
  Multi-Agent Document Comparison Streamlit App
3
+ + Graph RAG Chat Tab (new)
4
  """
5
  import sys
6
+ import os
7
  from pathlib import Path
8
 
 
9
  project_root = Path(__file__).parent
10
  if str(project_root) not in sys.path:
11
  sys.path.insert(0, str(project_root))
 
14
  import asyncio
15
  import json
16
 
 
17
  from agents.ingestion_agent import IngestionAgent
18
  from agents.text_agent import TextAgent
19
  from agents.table_agent import TableAgent
 
28
  from models.document import ProcessedDocument
29
  import config
30
 
31
+ # Graph RAG imports
32
+ from rag.rag_pipeline import GraphRAGPipeline, PipelineState
33
+
34
+ # Phase 2 imports (conditional)
35
  try:
36
  from agents.image_agent import ImageAgent
37
  IMAGE_AGENT_AVAILABLE = True
 
51
  META_AGENT_AVAILABLE = False
52
 
53
 
 
54
  st.set_page_config(
55
  page_title="Multi-Agent Document Comparator",
56
  page_icon="πŸ“„",
 
60
 
61
 
62
def main():
    """Render the Streamlit UI: a document-comparison tab and a Graph RAG chat tab."""
    st.title("πŸ“„ Multi-Agent Document Comparator + Graph RAG Chat")
    st.markdown("**Agentic document similarity Β· Knowledge Graph RAG Β· Groq-powered chat**")

    with st.expander("πŸ—οΈ View System Architecture", expanded=False):
        arch_path = Path("src/img/multi_agent_doc_similarity_architecture.svg")
        if arch_path.exists():
            st.image(str(arch_path), use_container_width=True)

    st.markdown("---")

    # ── Sidebar ───────────────────────────────────────────────────────────────
    with st.sidebar:
        st.header("βš™οΈ Configuration")

        st.subheader("Phase 2 Features")
        enable_phase2 = st.checkbox(
            "Enable Phase 2 Modalities",
            value=config.ENABLE_IMAGE_COMPARISON,
            help="Enable image, layout, and metadata comparison"
        )
        st.markdown("---")
        st.subheader("Modality Weights")

        if enable_phase2:
            text_weight = st.slider("Text Weight", 0.0, 1.0, config.MODALITY_WEIGHTS["text"], 0.05)
            table_weight = st.slider("Table Weight", 0.0, 1.0, config.MODALITY_WEIGHTS["table"], 0.05)
            image_weight = st.slider("Image Weight", 0.0, 1.0, config.MODALITY_WEIGHTS["image"], 0.05)
            layout_weight = st.slider("Layout Weight", 0.0, 1.0, config.MODALITY_WEIGHTS["layout"], 0.05)
            meta_weight = st.slider("Metadata Weight", 0.0, 1.0, config.MODALITY_WEIGHTS["metadata"], 0.05)

            # Re-normalise so the five sliders always sum to 1.0.
            total = text_weight + table_weight + image_weight + layout_weight + meta_weight
            if total > 0:
                weights = {
                    "text": text_weight / total,
                    "table": table_weight / total,
                    "image": image_weight / total,
                    "layout": layout_weight / total,
                    "metadata": meta_weight / total,
                }
            else:
                weights = config.MODALITY_WEIGHTS
            st.info("Weights normalised to 1.0")
        else:
            text_weight = st.slider("Text Weight", 0.0, 1.0, config.MODALITY_WEIGHTS_PHASE1["text"], 0.05)
            table_weight = 1.0 - text_weight
            st.write(f"Table Weight: {table_weight:.2f}")
            weights = {"text": text_weight, "table": table_weight}

        st.markdown("---")
        st.subheader("πŸ“‹ Status")
        st.write("βœ… Text comparison")
        st.write("βœ… Table comparison")
        if enable_phase2:
            st.write(f"{'βœ…' if IMAGE_AGENT_AVAILABLE else '⚠️'} Image comparison")
            st.write(f"{'βœ…' if LAYOUT_AGENT_AVAILABLE else '⚠️'} Layout comparison")
            st.write(f"{'βœ…' if META_AGENT_AVAILABLE else '⚠️'} Metadata comparison")
        else:
            st.write("⏸️ Image / Layout / Metadata (disabled)")

        st.markdown("---")
        st.subheader("πŸ”— Graph RAG Settings")
        chunk_size = st.slider("Chunk size (words)", 100, 600, 300, 50)
        chunk_overlap = st.slider("Overlap (words)", 20, 150, 50, 10)
        top_k = st.slider("Vector top-k", 3, 15, 5, 1)
        graph_hops = st.slider("Graph hops", 1, 4, 2, 1)

    # ── Main tabs ─────────────────────────────────────────────────────────────
    tab1, tab2 = st.tabs(["πŸ“Š Document Comparison", "πŸ’¬ Graph RAG Chat"])

    # ── Session state init ────────────────────────────────────────────────────
    for key in ["raw_doc1", "raw_doc2", "rag_state", "rag_pipeline", "chat_history"]:
        if key not in st.session_state:
            st.session_state[key] = None if key != "chat_history" else []

    # ════════════════════════════════════════════════════════════════════════
    # TAB 1 β€” Comparison
    # ════════════════════════════════════════════════════════════════════════
    with tab1:
        col1, col2 = st.columns(2)

        with col1:
            st.subheader("πŸ“€ Document 1 (Main)")
            uploaded_file1 = st.file_uploader(
                "Upload PDF or DOCX", type=["pdf", "docx"], key="file1",
                help="Maximum file size: 50MB"
            )

        with col2:
            st.subheader("πŸ“€ Document 2 (Comparison)")
            uploaded_file2 = st.file_uploader(
                "Upload PDF or DOCX", type=["pdf", "docx"], key="file2",
                help="Maximum file size: 50MB"
            )

        st.markdown("---")

        if st.button("πŸ” Compare Documents", type="primary", use_container_width=True):
            if not uploaded_file1 or not uploaded_file2:
                st.error("Please upload both documents before comparing.")
            else:
                with st.spinner("Processing documents..."):
                    try:
                        file1_path = save_uploaded_file(uploaded_file1)
                        file2_path = save_uploaded_file(uploaded_file2)

                        valid1, error1 = validate_file(file1_path)
                        valid2, error2 = validate_file(file2_path)

                        # BUG FIX: previously `st.error(...); st.stop()`.
                        # st.stop() raises StopException (an Exception
                        # subclass), so the broad `except` below caught it
                        # and showed a spurious traceback, and it also
                        # aborted rendering of the chat tab. Plain
                        # branching avoids both problems.
                        if not valid1:
                            st.error(f"Document 1 error: {error1}")
                        elif not valid2:
                            st.error(f"Document 2 error: {error2}")
                        else:
                            report, raw_doc1, raw_doc2 = asyncio.run(
                                process_and_compare(file1_path, file2_path, weights, enable_phase2)
                            )

                            # Store raw docs for the Graph RAG tab.
                            st.session_state["raw_doc1"] = raw_doc1
                            st.session_state["raw_doc2"] = raw_doc2
                            # Reset any previous RAG state.
                            st.session_state["rag_state"] = None
                            st.session_state["chat_history"] = []

                            display_results(report)

                    except Exception as e:
                        st.error(f"An error occurred: {str(e)}")
                        import traceback
                        st.code(traceback.format_exc())

    # ════════════════════════════════════════════════════════════════════════
    # TAB 2 β€” Graph RAG Chat
    # ════════════════════════════════════════════════════════════════════════
    with tab2:
        st.subheader("πŸ’¬ Chat with your Documents (Graph RAG + Groq)")

        docs_ready = (
            st.session_state["raw_doc1"] is not None
            and st.session_state["raw_doc2"] is not None
        )

        if not docs_ready:
            st.info("πŸ“‚ Please upload and compare documents in the **Document Comparison** tab first.")
        else:
            # Load Groq API key from environment (Hugging Face Spaces secrets).
            groq_key = os.environ.get("GROQ_API_KEY", "")

            if not groq_key:
                st.warning("⚠️ GROQ_API_KEY not found in environment. Please set it in Hugging Face Spaces secrets.")

            col_build, col_reset = st.columns([2, 1])

            with col_build:
                build_btn = st.button(
                    "πŸ”¨ Build Graph RAG Index",
                    disabled=not groq_key,
                    help="Chunks docs β†’ embeds β†’ builds vector DB + knowledge graph",
                )

            with col_reset:
                if st.button("πŸ”„ Reset Chat"):
                    st.session_state["chat_history"] = []
                    if st.session_state["rag_pipeline"]:
                        st.session_state["rag_pipeline"].reset_chat()
                    st.rerun()

            if build_btn:
                with st.spinner("Chunking, embedding, building knowledge graph β€” this takes ~30s…"):
                    pipeline = GraphRAGPipeline(
                        groq_api_key=groq_key,
                        chunk_size=chunk_size,
                        chunk_overlap=chunk_overlap,
                        top_k_vector=top_k,
                        graph_hops=graph_hops,
                    )
                    rag_state = pipeline.ingest(
                        st.session_state["raw_doc1"],
                        st.session_state["raw_doc2"],
                    )
                    st.session_state["rag_pipeline"] = pipeline
                    st.session_state["rag_state"] = rag_state
                    st.session_state["chat_history"] = []

                st.success("βœ… Graph RAG index ready!")

                s = rag_state.stats
                c1, c2, c3, c4 = st.columns(4)
                c1.metric("Doc 1 Chunks", s.get("doc1_chunks", 0))
                c2.metric("Doc 2 Chunks", s.get("doc2_chunks", 0))
                c3.metric("Graph Nodes", s.get("nodes", 0))
                c4.metric("Graph Edges", s.get("edges", 0))

                with st.expander("Edge type breakdown"):
                    for etype, cnt in s.get("edge_types", {}).items():
                        st.write(f"**{etype}**: {cnt}")

            # ── Chat UI ───────────────────────────────────────────────────────
            rag_ready = st.session_state["rag_state"] is not None

            if rag_ready:
                # Replay history so the transcript survives reruns.
                for msg in st.session_state["chat_history"]:
                    with st.chat_message(msg["role"]):
                        st.markdown(msg["content"])

                if user_input := st.chat_input("Ask anything about the two documents…"):
                    st.session_state["chat_history"].append(
                        {"role": "user", "content": user_input}
                    )
                    with st.chat_message("user"):
                        st.markdown(user_input)

                    with st.chat_message("assistant"):
                        pipeline: GraphRAGPipeline = st.session_state["rag_pipeline"]
                        rag_state_obj: PipelineState = st.session_state["rag_state"]

                        # Stream tokens into the UI; write_stream returns the
                        # concatenated text for the history.
                        response_gen = pipeline.query(user_input, rag_state_obj, stream=True)
                        full_response = st.write_stream(response_gen)

                    st.session_state["chat_history"].append(
                        {"role": "assistant", "content": full_response}
                    )
            else:
                st.info("πŸ‘† Click **Build Graph RAG Index** to start chatting. (Ensure GROQ_API_KEY is set in HF Spaces secrets)")
287
+
288
 
289
+ # ── Helpers ───────────────────────────────────────────────────────────────────
290
+
291
async def process_and_compare(file1_path, file2_path, weights, enable_phase2=False):
    """
    Ingest two documents, extract all enabled modalities, and compare them.

    Args:
        file1_path: Path to the first (main) document.
        file2_path: Path to the second (comparison) document.
        weights: Modality weight dict passed to the orchestrator.
        enable_phase2: Enable image / layout / metadata agents when available.

    Returns:
        (report, raw_doc1, raw_doc2) β€” raw docs are reused by the Graph RAG tab.
    """
    ingestion_agent = IngestionAgent()
    text_agent = TextAgent()
    table_agent = TableAgent()
    orchestrator = SimilarityOrchestrator(weights=weights)

    image_agent = ImageAgent() if enable_phase2 and IMAGE_AGENT_AVAILABLE else None
    layout_agent = LayoutAgent() if enable_phase2 and LAYOUT_AGENT_AVAILABLE else None
    meta_agent = MetaAgent() if enable_phase2 and META_AGENT_AVAILABLE else None

    progress_bar = st.progress(0)
    status_text = st.empty()

    status_text.text("⏳ Ingesting documents...")
    progress_bar.progress(10)
    raw_doc1 = await ingestion_agent.process(file1_path)
    raw_doc2 = await ingestion_agent.process(file2_path)
    progress_bar.progress(15)

    status_text.text("⏳ Extracting text…")
    text_chunks1, text_embeddings1 = await text_agent.process(raw_doc1)
    text_chunks2, text_embeddings2 = await text_agent.process(raw_doc2)
    progress_bar.progress(30)

    status_text.text("⏳ Extracting tables…")
    tables1, table_embeddings1 = await table_agent.process(raw_doc1)
    tables2, table_embeddings2 = await table_agent.process(raw_doc2)
    progress_bar.progress(45)

    # BUG FIX: the chained assignment
    #   images1 = images2 = image_embeddings1 = image_embeddings2 = []
    # bound all four names to ONE shared mutable list, and silently changed
    # the "no images" embedding sentinel from None to []. Give each image
    # list its own object and keep embeddings None when absent, matching
    # what compare_documents() received before this refactor.
    images1, images2 = [], []
    image_embeddings1, image_embeddings2 = None, None
    if image_agent:
        status_text.text("⏳ Extracting images…")
        try:
            images1, image_embeddings1 = await image_agent.process(raw_doc1)
            images2, image_embeddings2 = await image_agent.process(raw_doc2)
        except Exception as e:
            # Best-effort: Phase 2 failures degrade gracefully to a warning.
            st.warning(f"Image extraction failed: {e}")
    progress_bar.progress(60)

    layout1 = layout2 = None
    if layout_agent:
        status_text.text("⏳ Analysing layout…")
        try:
            layout1 = await layout_agent.process(raw_doc1)
            layout2 = await layout_agent.process(raw_doc2)
        except Exception as e:
            st.warning(f"Layout analysis failed: {e}")
    progress_bar.progress(70)

    metadata1 = metadata2 = None
    if meta_agent:
        status_text.text("⏳ Extracting metadata…")
        try:
            metadata1 = await meta_agent.process(raw_doc1)
            metadata2 = await meta_agent.process(raw_doc2)
        except Exception as e:
            st.warning(f"Metadata extraction failed: {e}")
    progress_bar.progress(80)

    processed_doc1 = ProcessedDocument(
        filename=raw_doc1.filename, text_chunks=text_chunks1, tables=tables1,
        total_pages=raw_doc1.total_pages, file_type=raw_doc1.file_type,
        images=images1, layout=layout1, metadata=metadata1
    )
    processed_doc2 = ProcessedDocument(
        filename=raw_doc2.filename, text_chunks=text_chunks2, tables=tables2,
        total_pages=raw_doc2.total_pages, file_type=raw_doc2.file_type,
        images=images2, layout=layout2, metadata=metadata2
    )

    status_text.text("⏳ Comparing documents…")
    report = await orchestrator.compare_documents(
        processed_doc1, text_embeddings1, table_embeddings1,
        processed_doc2, text_embeddings2, table_embeddings2,
        image_embeddings1, image_embeddings2,
        layout1, layout2, metadata1, metadata2
    )

    progress_bar.progress(100)
    status_text.text("βœ… Comparison complete!")

    # Return report + raw docs (needed for Graph RAG).
    return report, raw_doc1, raw_doc2
374
 
375
 
376
def display_results(report):
    """
    Render a SimilarityReport: overall gauge, per-modality breakdown,
    matched sections, Phase 2 details, and a JSON download button.
    """
    st.markdown("---")
    st.header("πŸ“Š Comparison Results")

    col1, col2 = st.columns([1, 1])
    with col1:
        gauge_fig = create_similarity_gauge(report.overall_score)
        st.plotly_chart(gauge_fig, use_container_width=True)
    with col2:
        st.markdown(create_score_legend())

    st.markdown("---")
    st.subheader("πŸ“ˆ Per-Modality Breakdown")
    breakdown_fig = create_modality_breakdown_chart(report)
    st.plotly_chart(breakdown_fig, use_container_width=True)

    cols = st.columns(5)
    # (label, score object, detail key, noun for the delta caption).
    # FIX: the refactor collapsed every caption to a generic "items";
    # restore the modality-specific nouns the UI showed before.
    scores = [
        ("Text Similarity", report.text_score, "num_matches", "matches"),
        ("Table Similarity", report.table_score, "num_matches", "matches"),
        ("Image Similarity", report.image_score, "num_matches", "matches"),
        ("Layout Similarity", report.layout_score, "num_metrics", "metrics"),
        ("Metadata Similarity", report.metadata_score, "num_fields_compared", "fields"),
    ]
    for col, (label, score_obj, detail_key, noun) in zip(cols, scores):
        if score_obj:
            col.metric(label, f"{score_obj.score:.1%}",
                       f"{score_obj.details.get(detail_key, 0)} {noun}")

    st.markdown("---")
    st.subheader("πŸ”— Top Matched Sections")
    if report.matched_sections:
        st.markdown(format_matched_sections(report.matched_sections[:10]))
    else:
        st.info("No significant matches found.")

    if report.image_score or report.layout_score or report.metadata_score:
        st.markdown("---")
        st.subheader("🎨 Phase 2 Modality Details")

        if report.image_score and report.image_score.matched_items:
            with st.expander(f"πŸ–ΌοΈ Image Matches ({len(report.image_score.matched_items)})"):
                for idx, m in enumerate(report.image_score.matched_items[:5], 1):
                    st.markdown(f"**Match {idx}** β€” {m['similarity']:.2%}")

        if report.layout_score:
            with st.expander(f"πŸ“ Layout (Score: {report.layout_score.score:.1%})"):
                for k, v in report.layout_score.details.items():
                    if k != "num_metrics":
                        st.metric(k.replace("_", " ").title(), f"{v:.2%}")

        if report.metadata_score and report.metadata_score.matched_items:
            with st.expander(f"πŸ“‹ Metadata ({len(report.metadata_score.matched_items)} fields)"):
                for m in report.metadata_score.matched_items:
                    st.markdown(f"**{m['field'].title()}** β€” {m['similarity']:.2%}")

    st.markdown("---")
    report_json = json.dumps(report.model_dump(), indent=2, default=str)
    st.download_button(
        "πŸ“₯ Download Report (JSON)", data=report_json,
        file_name=f"similarity_report_{report.timestamp.strftime('%Y%m%d_%H%M%S')}.json",
        mime="application/json"
    )
 
 
 
 
 
436
 
437
 
438
  if __name__ == "__main__":
src/utils/__pycache__/visualization.cpython-313.pyc CHANGED
Binary files a/src/utils/__pycache__/visualization.cpython-313.pyc and b/src/utils/__pycache__/visualization.cpython-313.pyc differ