""" Query Analyzer — extracts semantic and structural features from incoming queries. V1: heuristic rules + keyword scoring + regex. Designed so that a V2 DistilBERT/MiniLM classifier can drop in by implementing the same QueryFeatures interface. """ from __future__ import annotations import re import logging from dataclasses import dataclass, field from typing import Optional logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Feature schema # --------------------------------------------------------------------------- @dataclass class QueryFeatures: """Structured representation of query characteristics.""" raw_query: str # Structural token_count: int = 0 sentence_count: int = 0 has_code_block: bool = False has_math_notation: bool = False has_numbered_list: bool = False has_question_mark: bool = False has_bullet_points: bool = False # Semantic domain flags domain_code: bool = False domain_math: bool = False domain_science: bool = False domain_creative: bool = False domain_factual: bool = False domain_reasoning: bool = False domain_summarization: bool = False domain_translation: bool = False # Complexity signals multi_step: bool = False # "first ... then ... finally" requires_comparison: bool = False requires_generation: bool = False requires_explanation: bool = False requires_analysis: bool = False requires_debate: bool = False # Estimated output length bucket estimated_output_length: str = "short" # short | medium | long | very_long # Derived primary_domain: str = "general" domain_scores: dict[str, float] = field(default_factory=dict) _expert_signal: bool = False def to_dict(self) -> dict: return self.__dict__.copy() # --------------------------------------------------------------------------- # Keyword lists # --------------------------------------------------------------------------- _CODE_KEYWORDS = { "code", "function", "algorithm", "implement", "debug", "fix bug", "write a script", "python", "javascript", "typescript", "java", "c++", "rust", "golang", "sql", "class", "method", "api", "framework", "library", "module", "refactor", "unit test", "regex", "recursion", "data structure", "linked list", "binary tree", "graph", "sort", "search", "complexity", "big o", "dockerfile", "kubernetes", } _MATH_KEYWORDS = { "prove", "theorem", "equation", "integral", "derivative", "matrix", "eigenvalue", "probability", "statistics", "calculus", "algebra", "geometry", "topology", "series", "differential", "optimize", "maximize", "minimize", "linear programming", "convex", "gradient", "bayesian", "stochastic", "markov", } _SCIENCE_KEYWORDS = { "physics", "chemistry", "biology", "quantum", "relativity", "neural", "neuroscience", "genetics", "dna", "protein", "molecule", "atom", "thermodynamics", "entropy", "evolution", "ecology", "astronomy", "cosmology", "electromagnetism", } _CREATIVE_KEYWORDS = { "write a story", "poem", "song", "creative writing", "fiction", "narrative", "character", "dialogue", "screenplay", "plot", "novel", "essay", "persuasive", "metaphor", "haiku", "limerick", } _REASONING_KEYWORDS = { "analyze", "evaluate", "compare", "contrast", "critique", "argue", "debate", "reason", "logical", "fallacy", "philosophy", "ethics", "should", "why", "explain why", "what would happen", "consequences", "trade-off", "pros and cons", "design", "architecture", "strategy", } _SUMMARIZATION_KEYWORDS = { "summarize", "summary", "tldr", "tl;dr", "brief", "overview", "key points", "main ideas", "condense", "shorten", } _TRANSLATION_KEYWORDS = { "translate", "in french", "in spanish", "in german", "in japanese", "in chinese", "in hindi", "en français", "auf deutsch", } _MULTI_STEP_PATTERNS = [ r"\bstep[s]?\b", r"\bfirst\b.{0,60}\bthen\b", r"\bfinally\b", r"\b(1\.|2\.|3\.)", r"\bplan\b", r"\bworkflow\b", r"\bprocess\b", ] _MATH_NOTATION_PATTERNS = [ r"\$.*?\$", # LaTeX inline r"\\\[.*?\\\]", # LaTeX block r"[∫∑∏√∞∂∇≤≥≠≈±×÷]", # Unicode math symbols r"\b\d+\s*[\+\-\*\/\^]\s*\d+", # arithmetic expressions ] # Output length estimation heuristics (in words of expected response) _OUTPUT_LENGTH_MAP = { "short": ["what is", "define", "who is", "when was", "where is", "yes or no"], "medium": ["explain", "describe", "how does", "difference between"], "long": ["compare", "analyze", "write a", "design", "implement", "step by step"], "very_long": ["comprehensive", "in depth", "detailed guide", "full tutorial", "write an essay", "complete implementation"], } # --------------------------------------------------------------------------- # Analyzer # --------------------------------------------------------------------------- class QueryAnalyzer: """ Extracts semantic + structural features from a raw query string. V2: Uses HuggingFace zero-shot classification if transformers is installed. Falls back to V1 heuristics if not available. """ def __init__(self): self.ml_classifier = None self.ml_labels = [ "code", "math", "science", "creative", "reasoning", "summarization", "translation", "factual" ] import os if os.getenv("USE_ML_ANALYZER", "false").lower() == "true": try: from transformers import pipeline # type: ignore logger.info("Loading ML Zero-Shot Classifier for Query Analyzer...") self.ml_classifier = pipeline( "zero-shot-classification", model="cross-encoder/nli-distilroberta-base", device=-1, local_files_only=True ) except ImportError: logger.info("transformers not found, using V1 heuristic Query Analyzer.") except Exception as e: logger.warning(f"Failed to load ML classifier: {e}. Falling back to V1.") def analyze(self, query: str) -> QueryFeatures: q = query.strip() ql = q.lower() features = QueryFeatures(raw_query=q) self._structural_features(q, ql, features) self._domain_features(ql, features) self._complexity_signals(ql, features) self._output_length_estimate(ql, features) self._primary_domain(features) return features # ------------------------------------------------------------------ # Structural # ------------------------------------------------------------------ def _structural_features(self, q: str, ql: str, f: QueryFeatures) -> None: f.token_count = len(q.split()) f.sentence_count = max(1, len(re.split(r"[.!?]+", q))) f.has_code_block = bool(re.search(r"```|`[^`]+`", q)) f.has_math_notation = any( bool(re.search(p, q)) for p in _MATH_NOTATION_PATTERNS ) f.has_numbered_list = bool(re.search(r"^\s*\d+[\.\)]", q, re.MULTILINE)) f.has_question_mark = "?" in q f.has_bullet_points = bool(re.search(r"^\s*[-*•]", q, re.MULTILINE)) # ------------------------------------------------------------------ # Domain classification # ------------------------------------------------------------------ def _domain_features(self, ql: str, f: QueryFeatures) -> None: scores: dict[str, float] = {} if self.ml_classifier: try: result = self.ml_classifier(f.raw_query, self.ml_labels, multi_label=True) for label, score in zip(result['labels'], result['scores']): scores[label] = score except Exception as e: logger.warning(f"ML inference failed: {e}. Falling back to V1.") if not scores: scores = { "code": self._keyword_score(ql, _CODE_KEYWORDS), "math": self._keyword_score(ql, _MATH_KEYWORDS), "science": self._keyword_score(ql, _SCIENCE_KEYWORDS), "creative": self._keyword_score(ql, _CREATIVE_KEYWORDS), "reasoning": self._keyword_score(ql, _REASONING_KEYWORDS), "summarization": self._keyword_score(ql, _SUMMARIZATION_KEYWORDS), "translation": self._keyword_score(ql, _TRANSLATION_KEYWORDS), } f.domain_scores = scores if self.ml_classifier: f.domain_code = scores.get("code", 0) > 0.4 f.domain_math = scores.get("math", 0) > 0.4 or f.has_math_notation f.domain_science = scores.get("science", 0) > 0.4 f.domain_creative = scores.get("creative", 0) > 0.4 f.domain_reasoning = scores.get("reasoning", 0) > 0.4 f.domain_summarization = scores.get("summarization", 0) > 0.4 f.domain_translation = scores.get("translation", 0) > 0.4 f.domain_factual = scores.get("factual", 0) > 0.4 else: f.domain_code = scores.get("code", 0) > 0 f.domain_math = scores.get("math", 0) > 0 or f.has_math_notation f.domain_science = scores.get("science", 0) > 0 f.domain_creative = scores.get("creative", 0) > 0 f.domain_reasoning = scores.get("reasoning", 0) > 0 f.domain_summarization = scores.get("summarization", 0) > 0 f.domain_translation = scores.get("translation", 0) > 0 f.domain_factual = ( f.has_question_mark and sum(scores.values()) < 0.5 and f.token_count < 25 ) def _keyword_score(self, ql: str, keywords: set) -> float: """Fraction of keywords found; capped to avoid over-scoring long queries.""" hits = sum(1 for kw in keywords if kw in ql) return min(hits / max(len(keywords), 1) * 20, 1.0) # normalize # ------------------------------------------------------------------ # Complexity signals # ------------------------------------------------------------------ def _complexity_signals(self, ql: str, f: QueryFeatures) -> None: f.multi_step = any( bool(re.search(p, ql)) for p in _MULTI_STEP_PATTERNS ) f.requires_comparison = any( kw in ql for kw in ["compare", "contrast", "vs", "versus", "difference between", "pros and cons"] ) f.requires_generation = any( kw in ql for kw in ["write", "generate", "create", "build", "make", "implement", "design"] ) f.requires_explanation = any( kw in ql for kw in ["explain", "how does", "why does", "what is", "describe"] ) f.requires_analysis = any( kw in ql for kw in ["analyze", "evaluate", "assess", "review", "critique"] ) f.requires_debate = any( kw in ql for kw in ["debate", "argue", "advocate", "defend", "persuade"] ) # Expert-level signal: proofs, formal derivations, system design f._expert_signal = any( kw in ql for kw in [ "prove", "proof", "theorem", "derive", "derivation", "from first principles", "consensus algorithm", "paxos", "raft", "distributed system", "compiler", "operating system", "quantum", "eigenvalue", "eigenvector", ] ) # ------------------------------------------------------------------ # Output length # ------------------------------------------------------------------ def _output_length_estimate(self, ql: str, f: QueryFeatures) -> None: # Override for known long-output domains regardless of query length if f.domain_math and any(kw in ql for kw in ["prove", "derive", "theorem", "proof"]): f.estimated_output_length = "long" return for bucket, triggers in reversed(list(_OUTPUT_LENGTH_MAP.items())): if any(t in ql for t in triggers): f.estimated_output_length = bucket return # Default based on token count if f.token_count < 10: f.estimated_output_length = "medium" # even trivial math needs explanation elif f.token_count < 30: f.estimated_output_length = "medium" else: f.estimated_output_length = "long" # ------------------------------------------------------------------ # Primary domain # ------------------------------------------------------------------ def _primary_domain(self, f: QueryFeatures) -> None: scores = f.domain_scores if not scores or max(scores.values()) == 0: f.primary_domain = "general" return f.primary_domain = max(scores.keys(), key=lambda k: scores[k])