""" LLM Rationale Generator ======================== Generates brief, grounded rationales for why each retrieved standard is relevant to the query - using ONLY free, local, open-source methods. Zero paid APIs. Zero external calls. CRITICAL DESIGN PRINCIPLE: - The `retrieved_standards` list (IS numbers) comes ENTIRELY from the retriever. - The LLM is NEVER asked to suggest, generate, or modify standard IDs. - The LLM is ONLY asked to write 1-sentence rationales based on provided context. - This guarantees zero hallucination of standard IDs. Two modes (both free): 1. "extract" - sentence-scoring from chunk text (default, instant, zero deps) 2. "llm" - local HuggingFace model via transformers (better quality, needs GPU/RAM) Fallback: If the local LLM fails for any reason, automatically drops to extract mode. This ensures the pipeline always produces valid output. """ import re import logging from typing import Optional from collections import Counter logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Public API # --------------------------------------------------------------------------- def generate_rationales( query: str, retrieved: list, use_llm: bool = False, model_path: Optional[str] = None, api_key: Optional[str] = None, # kept for signature compat, ignored ) -> list: """ Generate rationales for retrieved standards. Args: query : Original product description query. retrieved : List of dicts from BISRetriever.retrieve(). Each must have: standard_id, title, rationale_context, category, score. use_llm : If True, attempt a local HuggingFace model first, then fall back to extract mode on any error. model_path : HuggingFace model ID or local path for LLM mode. Defaults to "TinyLlama/TinyLlama-1.1B-Chat-v1.0". api_key : Ignored - kept only for drop-in compatibility. Returns: List of dicts with an added 'rationale' field. standard_id values are NEVER modified. GUARANTEE: This function NEVER modifies the standard_id field. All IS codes in output come verbatim from `retrieved`. """ if not retrieved: return [] if use_llm: try: return _llm_rationales(query, retrieved, model_path) except Exception as e: logger.warning(f"Local LLM rationale failed: {e}. Falling back to extract mode.") return _extract_rationales(query, retrieved) def generate_rationales_batch( queries: list, retrieved_batch: list, use_llm: bool = False, model_path: Optional[str] = None, api_key: Optional[str] = None, ) -> list: """ Batch version of generate_rationales. Processes each query independently. In extract mode (default) all calls are near-instant. """ return [ generate_rationales(q, r, use_llm=use_llm, model_path=model_path) for q, r in zip(queries, retrieved_batch) ] # --------------------------------------------------------------------------- # Mode 1: Sentence-extraction rationale (default - fast, zero deps) # --------------------------------------------------------------------------- _STOPWORDS = { "the", "a", "an", "of", "for", "in", "is", "to", "and", "or", "be", "this", "with", "at", "by", "as", "on", "that", "are", "its", "it", "was", "from", "has", "have", "not", "shall", "which", "also", "such", "may", "all", "any", "been", "their", "they", "more", "used", "when", "where", "while", "than", "into", "these", "those", "can", "will", } def _tokenize(text: str) -> list: return [w for w in re.findall(r'\b[a-z]{3,}\b', text.lower()) if w not in _STOPWORDS] def _score_sentence(sent: str, query_tokens: set) -> float: """Score a sentence by token overlap with query, normalised by length.""" sent_tokens = set(_tokenize(sent)) if not sent_tokens: return 0.0 overlap = len(query_tokens & sent_tokens) length_bonus = min(len(sent_tokens) / 8, 1.0) return (overlap / len(query_tokens | sent_tokens)) * (0.6 + 0.4 * length_bonus) def _best_sentences(text: str, query_tokens: set, n: int = 2) -> str: """Return the top-n most relevant sentences from text, joined.""" sentences = re.split(r'(?<=[.!?])\s+', text.strip()) sentences = [s.strip() for s in sentences if len(s.split()) >= 5] if not sentences: words = text.split()[:40] return " ".join(words) + ("..." if len(text.split()) > 40 else "") scored = sorted(sentences, key=lambda s: _score_sentence(s, query_tokens), reverse=True) top = scored[:n] order = {s: i for i, s in enumerate(sentences)} top.sort(key=lambda s: order.get(s, 999)) result = " ".join(top) return result[:350] + ("..." if len(result) > 350 else "") def _extract_rationales(query: str, retrieved: list) -> list: """ Extract the most query-relevant sentences from each chunk's text. No model needed - fast, deterministic, zero hallucination. """ query_tokens = set(_tokenize(query)) if not query_tokens: query_tokens = set(query.lower().split()) result = [] for item in retrieved: context = ( item.get("rationale_context") or item.get("text") or item.get("chunk_text") or "" ) if context.strip(): rationale = _best_sentences(context, query_tokens, n=2) else: rationale = _template_rationale_single(query, item) result.append({**item, "rationale": rationale}) return result # --------------------------------------------------------------------------- # Mode 2: Local HuggingFace LLM (optional, better quality) # --------------------------------------------------------------------------- _llm_pipeline = None # cached after first load def _load_llm(model_path: Optional[str]): """Load (and cache) a local HuggingFace text-generation pipeline.""" global _llm_pipeline if _llm_pipeline is not None: return _llm_pipeline from transformers import pipeline as hf_pipeline model_id = model_path or "TinyLlama/TinyLlama-1.1B-Chat-v1.0" logger.info(f"Loading local LLM: {model_id} (first call only)") _llm_pipeline = hf_pipeline( "text-generation", model=model_id, max_new_tokens=180, temperature=0.05, do_sample=False, device_map="auto", return_full_text=False, ) logger.info("Local LLM loaded.") return _llm_pipeline def _build_prompt(query: str, item: dict) -> str: """Build a tightly scoped prompt for one standard.""" sid = item["standard_id"] title = item.get("title", "") context = ( item.get("rationale_context") or item.get("text") or item.get("chunk_text") or "" )[:400] return ( f"<|system|>You are a BIS Standards compliance expert. " f"Answer strictly from the provided excerpt. " f"Do NOT mention any standard IDs other than the one given.\n" f"<|user|>Product query: {query}\n\n" f"Standard: {sid} - {title}\n" f"Excerpt: {context}\n\n" f"Write ONE sentence (max 50 words) explaining why {sid} is relevant to the query above. " f"Start directly with the explanation.\n" f"<|assistant|>" ) def _llm_rationales(query: str, retrieved: list, model_path: Optional[str]) -> list: """ Generate rationales using a local HuggingFace model. Falls back per-item to extract mode if generation fails. """ pipe = _load_llm(model_path) query_tokens = set(_tokenize(query)) result = [] for item in retrieved: try: prompt = _build_prompt(query, item) out = pipe(prompt) raw = out[0]["generated_text"].strip() own_id = item["standard_id"] raw = _strip_foreign_standard_ids(raw, own_id) rationale = raw[:300] if raw else _template_rationale_single(query, item) except Exception as e: logger.warning(f"LLM generation failed for {item['standard_id']}: {e}") rationale = _best_sentences( item.get("rationale_context") or item.get("text") or "", query_tokens ) or _template_rationale_single(query, item) result.append({**item, "rationale": rationale}) return result def _strip_foreign_standard_ids(text: str, own_id: str) -> str: """ Remove any IS code references from generated text that are NOT own_id. Prevents the LLM from sneaking in hallucinated standard IDs. """ own_num = re.search(r'\d+', own_id) own_num_str = own_num.group() if own_num else "" def _replacer(m): if own_num_str and own_num_str in m.group(): return m.group() return "(the standard)" return re.sub(r'\bIS[\s:.\-]?\d{2,6}\b', _replacer, text, flags=re.IGNORECASE) # --------------------------------------------------------------------------- # Template fallback (used only when chunk text is empty) # --------------------------------------------------------------------------- def _template_rationale_single(query: str, item: dict) -> str: """Last-resort template rationale built purely from metadata.""" sid = item["standard_id"] title = item.get("title") or "this building material standard" category = item.get("category", "Building Materials") _templates = { "Cement": f"{sid} specifies requirements for {title[:80]}, directly applicable to your cement query.", "Aggregates": f"{sid} covers {title[:80]}, relevant to your aggregates compliance requirements.", "Concrete Products": f"{sid} provides specifications for {title[:80]}, applicable to your concrete product.", "Steel": f"{sid} specifies requirements for {title[:80]}, relevant to your steel product query.", "Masonry": f"{sid} covers {title[:80]}, applicable to your masonry product requirements.", "Roofing": f"{sid} specifies {title[:80]}, directly relevant to your roofing material query.", "Pipes": f"{sid} provides specifications for {title[:80]}, applicable to your pipe query.", } return _templates.get( category, f"{sid} - {title[:80]} - is relevant to your product based on material type and application." ) # --------------------------------------------------------------------------- # Quick test # --------------------------------------------------------------------------- if __name__ == "__main__": test_retrieved = [ { "standard_id": "IS 269", "title": "Specification for Ordinary Portland Cement, 33 Grade", "category": "Cement", "score": 0.95, "rationale_context": ( "This standard covers 33 grade ordinary portland cement. " "Chemical requirements include lime saturation factor, silica ratio, " "alumina iron ratio. Physical tests cover fineness, setting time, " "soundness, and compressive strength at 3 and 7 days." ), }, { "standard_id": "IS 1489", "title": "Portland Pozzolana Cement", "category": "Cement", "score": 0.88, "rationale_context": ( "This standard covers Portland pozzolana cement made by intergrinding " "Portland cement clinker with pozzolanic materials. Suitable for structures " "exposed to sulphate attack and marine environments." ), }, ] print("=== Extract mode (default) ===") results = generate_rationales( query="33 Grade Ordinary Portland Cement for RCC construction", retrieved=test_retrieved, use_llm=False, ) for r in results: print(f"\n{r['standard_id']}:\n {r['rationale']}")