import logging import os import re import threading from collections import Counter, defaultdict from dataclasses import dataclass from huggingface_hub import InferenceClient from langchain_text_splitters import RecursiveCharacterTextSplitter log = logging.getLogger(__name__) FALLBACK_MODELS = [ "Qwen/Qwen2.5-72B-Instruct", "meta-llama/Llama-3.3-70B-Instruct", "mistralai/Mistral-Small-24B-Instruct-2501", ] SYSTEM_PROMPT = """\ You are a Graph-RAG reasoning engine. Your job is NOT to answer directly from text. Your job is to construct answers by traversing entity relationships. STRICT RULES: 1. ENTITY EXTRACTION - Identify all entities in the question. - Do not answer yet. 2. RELATIONSHIP TRAVERSAL - Use retrieved context to find explicit relationships between entities. - Build a step-by-step path using only these relationships. 3. MULTI-HOP ENFORCEMENT - If the question requires multiple steps, you MUST show intermediate entities. - Do NOT skip steps even if the answer is obvious. 4. NO SHORTCUTS - Do NOT answer from a single chunk if it contains the full answer. - You MUST validate at least one intermediate relationship (bridge). 5. FAITHFULNESS - Use ONLY retrieved context. Do NOT use prior knowledge. - If the path cannot be constructed -> return INSUFFICIENT_CONTEXT. 6. NEGATIVE GUARD - If the question asks for unknown/private/unsupported info -> return INSUFFICIENT_CONTEXT. 7. OUTPUT FORMAT (MANDATORY) - Return JSON ONLY: { "answer": "...", "reasoning_type": "direct | multi-hop | insufficient", "path": ["Entity1 -> Entity2", "Entity2 -> Entity3"], "used_chunks": ["0", "1"], "justification": "Explain briefly how the path leads to the answer using retrieved data only." } 8. FAILURE CONDITIONS - Return: { "answer": "INSUFFICIENT_CONTEXT", "reasoning_type": "insufficient", "path": [], "used_chunks": [], "justification": "No valid relationship path found in retrieved context." } IF: No relationship path exists | Only one-hop shortcut found | Information is missing """ STOPWORDS = { "a", "an", "and", "are", "as", "at", "be", "by", "for", "from", "has", "he", "in", "is", "it", "its", "of", "on", "that", "the", "to", "was", "were", "will", "with", "this", "these", "those", "or", "if", "then", "than", "into", "can", "could", "should", "would", "about", "over", "under", "after", "before", "between", "during", "also", "such", "their", "there", "them", "they", "you", "your", "we", "our", "i", "me", "my", "mine", "his", "her", "hers", "what", "which", "who", "whom", "when", "where", "why", "how", "do", "does", "did", "done", "not", "no", "yes", } @dataclass class RetrievedChunk: index: int score: float text: str class RAGEngine: def __init__( self, embed_provider: str = "graph", chunk_size: int = 180, chunk_overlap: int = 40, top_k: int = 3, llm_model: str = "Qwen/Qwen2.5-72B-Instruct", ): self.chunk_size = chunk_size self.chunk_overlap = chunk_overlap self.top_k = top_k self.llm_model = llm_model self.embed_provider = "graph" self.chunks: list[str] = [] self._ingest_progress: dict = {"state": "idle", "embedded": 0, "total": 0} self._graph_lock = threading.Lock() self._ingest_generation = 0 self._chunk_terms: dict[int, Counter[str]] = {} self._chunk_entities: dict[int, set[str]] = {} self._entity_to_chunks: dict[str, set[int]] = defaultdict(set) self._entity_graph: dict[str, dict[str, float]] = defaultdict(dict) @property def ingest_progress(self) -> dict: return dict(self._ingest_progress) @property def is_ready(self) -> bool: return bool(self.chunks) and self._ingest_progress.get("state") != "running" def vector_store_status(self) -> dict: node_count = len(self._entity_to_chunks) edge_count = sum(len(v) for v in self._entity_graph.values()) // 2 return { "provider": "knowledge_graph", "is_ready": self.is_ready, "loaded_chunks": len(self.chunks), "ingest_state": self._ingest_progress.get("state", "idle"), "graph_nodes": node_count, "graph_edges": edge_count, "detail": "GraphRAG index active. Retrieval is entity and relation based, not vector similarity.", } def ingest(self, text: str) -> int: self._do_split(text) self._do_embed() return len(self.chunks) def start_ingest(self, text: str) -> int: self._do_split(text) return len(self.chunks) def _do_split(self, text: str): splitter = RecursiveCharacterTextSplitter( chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap, ) self.chunks = splitter.split_text(text) self._ingest_generation += 1 total = len(self.chunks) self._ingest_progress = {"state": "running", "embedded": 0, "total": total} log.info("Chunked into %d pieces (size=%d). Building graph index...", total, self.chunk_size) def _do_embed(self): gen = self._ingest_generation try: with self._graph_lock: if gen != self._ingest_generation: return total = len(self.chunks) if total == 0: self._clear_graph() self._ingest_progress = {"state": "done", "embedded": 0, "total": 0} return self._clear_graph() for idx, chunk in enumerate(self.chunks): if gen != self._ingest_generation: return terms = self._extract_terms(chunk) entities = self._extract_entities(chunk) if not entities: entities = set(terms[:5]) self._chunk_terms[idx] = Counter(terms) self._chunk_entities[idx] = entities for ent in entities: self._entity_to_chunks[ent].add(idx) ent_list = sorted(entities) for i in range(len(ent_list)): left = ent_list[i] for j in range(i + 1, len(ent_list)): right = ent_list[j] w = 1.0 / max(len(ent_list), 1) self._add_undirected_edge(left, right, w) self._ingest_progress["embedded"] = idx + 1 self._ingest_progress["state"] = "done" log.info( "Graph index built: %d chunks, %d nodes, %d edges", len(self.chunks), len(self._entity_to_chunks), sum(len(v) for v in self._entity_graph.values()) // 2, ) except Exception as e: log.exception("Graph index build failed") self._ingest_progress = {"state": "error", "embedded": 0, "total": 0, "error": str(e)} def _clear_graph(self): self._chunk_terms = {} self._chunk_entities = {} self._entity_to_chunks = defaultdict(set) self._entity_graph = defaultdict(dict) def _add_undirected_edge(self, left: str, right: str, weight: float): self._entity_graph[left][right] = self._entity_graph[left].get(right, 0.0) + weight self._entity_graph[right][left] = self._entity_graph[right].get(left, 0.0) + weight def _extract_terms(self, text: str) -> list[str]: tokens = re.findall(r"[A-Za-z][A-Za-z0-9_-]{2,}", text.lower()) return [t for t in tokens if t not in STOPWORDS] def _extract_entities(self, text: str) -> set[str]: entities: set[str] = set() for phrase in re.findall(r"\b([A-Z][a-z]+(?:\s+[A-Z][a-z]+){0,3})\b", text): norm = phrase.strip().lower() if len(norm) > 2 and norm not in STOPWORDS: entities.add(norm) for acronym in re.findall(r"\b[A-Z]{2,}\b", text): entities.add(acronym.lower()) terms = self._extract_terms(text) counts = Counter(terms) for token, count in counts.most_common(6): if count >= 2 and token not in STOPWORDS: entities.add(token) return entities def _expand_query_entities(self, query_entities: set[str], hop_limit: int = 1, per_node_limit: int = 6) -> set[str]: expanded = set(query_entities) frontier = set(query_entities) for _ in range(hop_limit): next_frontier: set[str] = set() for ent in frontier: neighbors = sorted( self._entity_graph.get(ent, {}).items(), key=lambda kv: kv[1], reverse=True, )[:per_node_limit] for nbr, _weight in neighbors: if nbr not in expanded: expanded.add(nbr) next_frontier.add(nbr) frontier = next_frontier if not frontier: break return expanded def _candidate_chunks(self, query_entities: set[str], expanded_entities: set[str]) -> set[int]: candidates: set[int] = set() for ent in expanded_entities: candidates.update(self._entity_to_chunks.get(ent, set())) if candidates: return candidates # No graph-entity match. Fall back to all chunks for lexical retrieval. return set(range(len(self.chunks))) def _keyword_overlap(self, query_terms: list[str], chunk_terms: Counter[str]) -> float: if not query_terms: return 0.0 hit = sum(1 for t in set(query_terms) if t in chunk_terms) return hit / max(len(set(query_terms)), 1) def _entity_coverage(self, query_entities: set[str], chunk_entities: set[str]) -> float: if not query_entities: return 0.0 inter = len(query_entities.intersection(chunk_entities)) return inter / max(len(query_entities), 1) def _neighbor_support(self, query_entities: set[str], chunk_entities: set[str]) -> float: if not query_entities: return 0.0 support = 0.0 for q in query_entities: nbrs = self._entity_graph.get(q, {}) if not nbrs: continue top = sorted(nbrs.items(), key=lambda kv: kv[1], reverse=True)[:8] total_weight = sum(w for _, w in top) or 1.0 chunk_weight = sum(w for e, w in top if e in chunk_entities) support += chunk_weight / total_weight return support / max(len(query_entities), 1) def retrieve(self, query: str) -> list[RetrievedChunk]: if not self.chunks: return [] query_terms = self._extract_terms(query) query_entities = self._extract_entities(query) expanded = self._expand_query_entities(query_entities) candidates = self._candidate_chunks(query_entities, expanded) scored: list[RetrievedChunk] = [] for idx in candidates: chunk_terms = self._chunk_terms.get(idx, Counter()) chunk_entities = self._chunk_entities.get(idx, set()) entity_score = self._entity_coverage(query_entities, chunk_entities) neighbor_score = self._neighbor_support(query_entities, chunk_entities) lexical_score = self._keyword_overlap(query_terms, chunk_terms) if query_entities: score = (0.55 * entity_score) + (0.25 * neighbor_score) + (0.20 * lexical_score) else: score = (0.85 * lexical_score) + (0.15 * (1.0 if chunk_entities else 0.0)) if score > 0: scored.append(RetrievedChunk(index=idx, score=float(score), text=self.chunks[idx])) if not scored: return [] scored.sort(key=lambda c: c.score, reverse=True) candidate_count = max(self.top_k * 4, self.top_k) candidates = scored[:candidate_count] top_score = candidates[0].score min_abs = float(os.environ.get("RETRIEVE_MIN_SCORE", "0.15")) min_ratio = float(os.environ.get("RETRIEVE_RELATIVE_RATIO", "0.35")) filtered = [c for c in candidates if c.score >= min_abs and c.score >= (top_score * min_ratio)] if not filtered: filtered = [candidates[0]] return filtered[: self.top_k] def _confidence_from_retrieved(self, retrieved: list[RetrievedChunk]) -> tuple[float, str]: top_score = float(retrieved[0].score) if retrieved else 0.0 strong_support = sum(1 for c in retrieved if c.score >= 0.5) if top_score >= 0.75 and strong_support >= 2: return top_score, "High" if top_score >= 0.45: return top_score, "Medium" return top_score, "Low" def answer(self, query: str, history: list[dict] | None = None, answer_mode: str = "balanced") -> dict: retrieved = self.retrieve(query) chunks_payload = [{"index": c.index, "score": round(c.score, 4), "text": c.text} for c in retrieved] top_score, confidence_label = self._confidence_from_retrieved(retrieved) if answer_mode == "strict_grounded" and confidence_label == "Low": return { "answer": "INSUFFICIENT_CONTEXT", "chunks": chunks_payload, "model_used": None, "top_score": round(top_score, 4), "confidence_label": confidence_label, "reasoning_type": "insufficient", "path": [], "used_chunks": [], "justification": "Low retrieval confidence; no valid path found.", } context_block = "\n\n".join( f"[chunk_id: {c.index}]\n{c.text}" for c in retrieved ) hf_token = os.environ.get("HF_TOKEN") if not hf_token: return { "answer": "HF_TOKEN not set.", "chunks": chunks_payload, "model_used": None, "top_score": round(top_score, 4), "confidence_label": confidence_label, "reasoning_type": "insufficient", "path": [], "used_chunks": [], "justification": "", } messages = self._build_messages(query, context_block, history) return self._call_llm(hf_token, messages, chunks_payload, top_score, confidence_label) def stream_answer(self, query: str, history: list[dict] | None = None, answer_mode: str = "balanced"): retrieved = self.retrieve(query) chunks_payload = [{"index": c.index, "score": round(c.score, 4), "text": c.text} for c in retrieved] top_score, confidence_label = self._confidence_from_retrieved(retrieved) yield {"type": "meta", "top_score": round(top_score, 4), "confidence_label": confidence_label} yield {"type": "chunks", "data": chunks_payload} if answer_mode == "strict_grounded" and confidence_label == "Low": yield {"type": "token", "data": "INSUFFICIENT_CONTEXT"} yield { "type": "done", "model_used": None, "reasoning_type": "insufficient", "path": [], "used_chunks": [], "justification": "Low retrieval confidence; no valid path found.", } return hf_token = os.environ.get("HF_TOKEN") if not hf_token: yield {"type": "token", "data": "HF_TOKEN not set."} yield {"type": "done", "model_used": None, "reasoning_type": "insufficient", "path": [], "used_chunks": [], "justification": ""} return context_block = "\n\n".join( f"[chunk_id: {c.index}]\n{c.text}" for c in retrieved ) messages = self._build_messages(query, context_block, history) # Buffer full LLM output so we can parse the JSON before emitting clean answer tokens. full_text, model_used = self._buffer_llm(hf_token, messages) graph = self._parse_graph_response(full_text) answer_text = graph.get("answer", full_text) # Emit answer text word-by-word so the frontend stream still assembles naturally. for token_chunk in re.split(r'(\s+)', answer_text): if token_chunk: yield {"type": "token", "data": token_chunk} yield { "type": "done", "model_used": model_used, "reasoning_type": graph.get("reasoning_type", "direct"), "path": graph.get("path", []), "used_chunks": graph.get("used_chunks", []), "justification": graph.get("justification", ""), } def _buffer_llm(self, token: str, messages: list[dict]) -> tuple[str, str | None]: """Non-streaming call used by stream_answer to enable JSON parsing before token emission.""" client = InferenceClient(api_key=token) candidates = list(dict.fromkeys([self.llm_model] + FALLBACK_MODELS)) for model in candidates: try: resp = client.chat_completion(model=model, messages=messages, max_tokens=600, temperature=0.2) return resp.choices[0].message.content, model except Exception as e: log.warning("Buffered LLM %s failed: %s", model, e) continue return "All candidate models failed.", None def _parse_graph_response(self, text: str) -> dict: """Extract and parse the mandatory JSON object from the LLM response.""" import json # Strip optional markdown fences cleaned = re.sub(r'^```(?:json)?\s*|\s*```$', '', text.strip(), flags=re.MULTILINE) # Grab outermost JSON object match = re.search(r'\{.*\}', cleaned, re.DOTALL) if match: try: return json.loads(match.group()) except json.JSONDecodeError: pass # Fallback: treat raw text as answer, mark as direct return { "answer": text.strip(), "reasoning_type": "direct", "path": [], "used_chunks": [], "justification": "JSON parse failed; raw answer returned.", } def _build_messages(self, query: str, context: str, history: list[dict] | None) -> list[dict]: messages: list[dict] = [{"role": "system", "content": SYSTEM_PROMPT}] if history: messages.extend({"role": m["role"], "content": m["content"]} for m in history) messages.append( { "role": "user", "content": ( f"Question:\n{query}\n\n" "Retrieved Context (use the chunk_id values in your used_chunks field):\n" f"{context}\n\n" "Instructions:\n" "- Step 1: Extract entities from the question.\n" "- Step 2: Traverse relationships across chunks to build a path.\n" "- Step 3: Return ONLY a valid JSON object matching the mandatory format.\n" "- Do NOT include any text outside the JSON object." ), } ) return messages def _call_llm( self, token: str, messages: list[dict], chunks_payload: list[dict], top_score: float, confidence_label: str, ) -> dict: client = InferenceClient(api_key=token) candidates = list(dict.fromkeys([self.llm_model] + FALLBACK_MODELS)) for model in candidates: try: resp = client.chat_completion(model=model, messages=messages, max_tokens=600, temperature=0.2) raw = resp.choices[0].message.content graph = self._parse_graph_response(raw) return { "answer": graph.get("answer", raw), "chunks": chunks_payload, "model_used": model, "top_score": round(top_score, 4), "confidence_label": confidence_label, "reasoning_type": graph.get("reasoning_type", "direct"), "path": graph.get("path", []), "used_chunks": graph.get("used_chunks", []), "justification": graph.get("justification", ""), } except Exception as e: log.warning("LLM %s failed: %s", model, e) continue return { "answer": "All candidate models failed.", "chunks": chunks_payload, "model_used": None, "top_score": round(top_score, 4), "confidence_label": confidence_label, "reasoning_type": "insufficient", "path": [], "used_chunks": [], "justification": "All LLM candidates failed.", }