Spaces:
Sleeping
Sleeping
| """ | |
| Query Analyzer — extracts semantic and structural features from incoming queries. | |
| V1: heuristic rules + keyword scoring + regex. | |
| Designed so that a V2 DistilBERT/MiniLM classifier can drop in by | |
| implementing the same QueryFeatures interface. | |
| """ | |
| from __future__ import annotations | |
| import re | |
| import logging | |
| from dataclasses import dataclass, field | |
| from typing import Optional | |
| logger = logging.getLogger(__name__) | |
| # --------------------------------------------------------------------------- | |
| # Feature schema | |
| # --------------------------------------------------------------------------- | |
@dataclass
class QueryFeatures:
    """Structured representation of query characteristics.

    Only ``raw_query`` is required; every other field starts at a neutral
    default and is filled in by the analyzer. The class must be a dataclass:
    callers construct it with ``QueryFeatures(raw_query=...)`` and
    ``domain_scores`` relies on ``field(default_factory=dict)`` so that each
    instance gets its own dict.
    """

    raw_query: str

    # Structural
    token_count: int = 0
    sentence_count: int = 0
    has_code_block: bool = False
    has_math_notation: bool = False
    has_numbered_list: bool = False
    has_question_mark: bool = False
    has_bullet_points: bool = False

    # Semantic domain flags
    domain_code: bool = False
    domain_math: bool = False
    domain_science: bool = False
    domain_creative: bool = False
    domain_factual: bool = False
    domain_reasoning: bool = False
    domain_summarization: bool = False
    domain_translation: bool = False

    # Complexity signals
    multi_step: bool = False  # "first ... then ... finally"
    requires_comparison: bool = False
    requires_generation: bool = False
    requires_explanation: bool = False
    requires_analysis: bool = False
    requires_debate: bool = False

    # Estimated output length bucket
    estimated_output_length: str = "short"  # short | medium | long | very_long

    # Derived
    primary_domain: str = "general"
    domain_scores: dict[str, float] = field(default_factory=dict)
    _expert_signal: bool = False

    def to_dict(self) -> dict:
        """Return the features as a plain ``{field name: value}`` dict.

        The returned mapping is a snapshot: ``domain_scores`` is copied as
        well, so mutating the result never alters this instance.
        """
        data = dict(vars(self))
        data["domain_scores"] = dict(self.domain_scores)
        return data
| # --------------------------------------------------------------------------- | |
| # Keyword lists | |
| # --------------------------------------------------------------------------- | |
# Per-domain keyword vocabularies. Each set is passed to
# QueryAnalyzer._keyword_score, which divides the number of hits by the size
# of the set, so adding or removing an entry shifts that domain's score.
_CODE_KEYWORDS = {
    # general programming vocabulary
    "code", "function", "algorithm", "implement", "debug",
    "fix bug", "write a script",
    # languages
    "python", "javascript", "typescript", "java", "c++", "rust", "golang", "sql",
    # software structure and tooling
    "class", "method", "api", "framework", "library", "module",
    "refactor", "unit test", "regex", "recursion",
    # data structures and algorithms
    "data structure", "linked list", "binary tree", "graph",
    "sort", "search", "complexity", "big o",
    # infrastructure
    "dockerfile", "kubernetes",
}

_MATH_KEYWORDS = {
    "prove", "theorem", "equation",
    "integral", "derivative", "matrix", "eigenvalue",
    "probability", "statistics",
    "calculus", "algebra", "geometry", "topology",
    "series", "differential",
    "optimize", "maximize", "minimize",
    "linear programming", "convex", "gradient",
    "bayesian", "stochastic", "markov",
}

_SCIENCE_KEYWORDS = {
    "physics", "chemistry", "biology",
    "quantum", "relativity",
    "neural", "neuroscience",
    "genetics", "dna", "protein", "molecule", "atom",
    "thermodynamics", "entropy",
    "evolution", "ecology",
    "astronomy", "cosmology", "electromagnetism",
}

_CREATIVE_KEYWORDS = {
    "write a story", "poem", "song", "creative writing",
    "fiction", "narrative", "character", "dialogue",
    "screenplay", "plot", "novel", "essay",
    "persuasive", "metaphor", "haiku", "limerick",
}

_REASONING_KEYWORDS = {
    "analyze", "evaluate", "compare", "contrast", "critique",
    "argue", "debate", "reason", "logical", "fallacy",
    "philosophy", "ethics",
    "should", "why", "explain why", "what would happen",
    "consequences", "trade-off", "pros and cons",
    "design", "architecture", "strategy",
}

_SUMMARIZATION_KEYWORDS = {
    "summarize", "summary", "tldr", "tl;dr",
    "brief", "overview", "key points", "main ideas",
    "condense", "shorten",
}

_TRANSLATION_KEYWORDS = {
    "translate",
    "in french", "in spanish", "in german", "in japanese",
    "in chinese", "in hindi",
    "en français", "auf deutsch",
}

# Regexes that mark a query as multi-step; a single match is enough.
_MULTI_STEP_PATTERNS = [
    r"\bstep[s]?\b",
    r"\bfirst\b.{0,60}\bthen\b",
    r"\bfinally\b",
    r"\b(1\.|2\.|3\.)",
    r"\bplan\b",
    r"\bworkflow\b",
    r"\bprocess\b",
]

# Regexes that detect mathematical notation in the raw query text.
_MATH_NOTATION_PATTERNS = [
    r"\$.*?\$",  # LaTeX inline
    r"\\\[.*?\\\]",  # LaTeX block
    r"[∫∑∏√∞∂∇≤≥≠≈±×÷]",  # Unicode math symbols
    r"\b\d+\s*[\+\-\*\/\^]\s*\d+",  # arithmetic expressions
]

# Output-length bucket -> trigger phrases that select it. Key order matters:
# the analyzer walks the buckets in reverse, so the longest matching bucket
# takes precedence.
_OUTPUT_LENGTH_MAP = {
    "short": [
        "what is", "define", "who is", "when was", "where is", "yes or no",
    ],
    "medium": [
        "explain", "describe", "how does", "difference between",
    ],
    "long": [
        "compare", "analyze", "write a", "design", "implement", "step by step",
    ],
    "very_long": [
        "comprehensive", "in depth", "detailed guide", "full tutorial",
        "write an essay", "complete implementation",
    ],
}
| # --------------------------------------------------------------------------- | |
| # Analyzer | |
| # --------------------------------------------------------------------------- | |
class QueryAnalyzer:
    """
    Extracts semantic + structural features from a raw query string.

    V1 (default): keyword scoring + regex heuristics.
    V2 (opt-in): a HuggingFace zero-shot classifier, loaded only when the
    ``USE_ML_ANALYZER`` environment variable is ``"true"``. If the model
    cannot be loaded, or inference fails for a particular query, the V1
    heuristics are used instead.
    """

    # Minimum zero-shot score for a domain flag to be set when ML scores are used.
    _ML_THRESHOLD = 0.4

    def __init__(self):
        """Set up the analyzer and, if requested via env var, the ML classifier."""
        self.ml_classifier = None
        self.ml_labels = [
            "code", "math", "science", "creative",
            "reasoning", "summarization", "translation", "factual"
        ]
        import os

        if os.getenv("USE_ML_ANALYZER", "false").lower() != "true":
            return
        try:
            from transformers import pipeline  # type: ignore
            logger.info("Loading ML Zero-Shot Classifier for Query Analyzer...")
            self.ml_classifier = pipeline(
                "zero-shot-classification",
                model="cross-encoder/nli-distilroberta-base",
                device=-1,
                local_files_only=True
            )
        except ImportError:
            logger.info("transformers not found, using V1 heuristic Query Analyzer.")
        except Exception as e:
            # Best effort: any load failure leaves ml_classifier as None (V1 mode).
            logger.warning("Failed to load ML classifier: %s. Falling back to V1.", e)

    def analyze(self, query: str) -> QueryFeatures:
        """Run every feature extractor on ``query`` and return the result.

        The query is stripped of surrounding whitespace; a lower-cased copy
        is used for all keyword matching. The extractors run in a fixed
        order because later ones read fields set by earlier ones.
        """
        q = query.strip()
        ql = q.lower()
        features = QueryFeatures(raw_query=q)
        self._structural_features(q, ql, features)
        self._domain_features(ql, features)
        self._complexity_signals(ql, features)
        self._output_length_estimate(ql, features)
        self._primary_domain(features)
        return features

    # ------------------------------------------------------------------
    # Matching helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _has_term(text: str, term: str) -> bool:
        """Return True if ``term`` occurs in ``text`` starting at a word boundary.

        Plain substring tests produce false positives on common words
        ("capital" contains "api", "improve" contains "prove", "draft"
        contains "raft"). Requiring that the term is not preceded by a word
        character removes those, while still letting suffixes through
        ("functions", "sorting"). A lookbehind is used instead of ``\\b`` so
        that terms containing punctuation, such as "c++", work too.
        """
        return re.search(r"(?<!\w)" + re.escape(term), text) is not None

    @classmethod
    def _has_any_term(cls, text: str, terms) -> bool:
        """Return True if any of ``terms`` matches ``text`` (see ``_has_term``)."""
        return any(cls._has_term(text, term) for term in terms)

    @staticmethod
    def _count_sentences(text: str) -> int:
        """Count the non-empty segments between ``.``, ``!`` and ``?`` runs.

        Empty segments (e.g. the one after trailing punctuation) are ignored,
        so "Hello." counts as one sentence. The result is never below 1.
        """
        segments = [seg for seg in re.split(r"[.!?]+", text) if seg.strip()]
        return max(1, len(segments))

    # ------------------------------------------------------------------
    # Structural
    # ------------------------------------------------------------------
    def _structural_features(self, q: str, ql: str, f: QueryFeatures) -> None:
        """Fill in the surface-level fields of ``f`` from the raw query ``q``.

        ``ql`` is accepted for signature symmetry with the other extractors
        but is not used here.
        """
        f.token_count = len(q.split())  # whitespace-delimited tokens
        f.sentence_count = self._count_sentences(q)
        f.has_code_block = bool(re.search(r"```|`[^`]+`", q))
        f.has_math_notation = any(
            re.search(pattern, q) is not None for pattern in _MATH_NOTATION_PATTERNS
        )
        f.has_numbered_list = bool(re.search(r"^\s*\d+[\.\)]", q, re.MULTILINE))
        f.has_question_mark = "?" in q
        f.has_bullet_points = bool(re.search(r"^\s*[-*•]", q, re.MULTILINE))

    # ------------------------------------------------------------------
    # Domain classification
    # ------------------------------------------------------------------
    def _domain_features(self, ql: str, f: QueryFeatures) -> None:
        """Score each domain and set the ``domain_*`` flags on ``f``.

        Uses the zero-shot classifier when one is loaded; otherwise, or when
        inference raises, falls back to keyword scores. The threshold and the
        ``domain_factual`` rule follow whichever scorer actually produced the
        scores, not merely whether a classifier exists.
        """
        scores: dict[str, float] = {}
        if self.ml_classifier:
            try:
                result = self.ml_classifier(f.raw_query, self.ml_labels, multi_label=True)
                scores = dict(zip(result["labels"], result["scores"]))
            except Exception as e:
                logger.warning("ML inference failed: %s. Falling back to V1.", e)
        # True only if the classifier really returned scores for this query.
        used_ml = bool(scores)

        if not used_ml:
            scores = {
                "code": self._keyword_score(ql, _CODE_KEYWORDS),
                "math": self._keyword_score(ql, _MATH_KEYWORDS),
                "science": self._keyword_score(ql, _SCIENCE_KEYWORDS),
                "creative": self._keyword_score(ql, _CREATIVE_KEYWORDS),
                "reasoning": self._keyword_score(ql, _REASONING_KEYWORDS),
                "summarization": self._keyword_score(ql, _SUMMARIZATION_KEYWORDS),
                "translation": self._keyword_score(ql, _TRANSLATION_KEYWORDS),
            }
        f.domain_scores = scores

        # Heuristic scores flag a domain on any hit; ML scores need confidence.
        threshold = self._ML_THRESHOLD if used_ml else 0
        f.domain_code = scores.get("code", 0) > threshold
        f.domain_math = scores.get("math", 0) > threshold or f.has_math_notation
        f.domain_science = scores.get("science", 0) > threshold
        f.domain_creative = scores.get("creative", 0) > threshold
        f.domain_reasoning = scores.get("reasoning", 0) > threshold
        f.domain_summarization = scores.get("summarization", 0) > threshold
        f.domain_translation = scores.get("translation", 0) > threshold
        if used_ml:
            f.domain_factual = scores.get("factual", 0) > threshold
        else:
            # No "factual" keyword list exists: treat a short question with
            # little domain signal as factual.
            f.domain_factual = (
                f.has_question_mark
                and sum(scores.values()) < 0.5
                and f.token_count < 25
            )

    def _keyword_score(self, ql: str, keywords: set) -> float:
        """Return a 0.0-1.0 score for how strongly ``ql`` matches ``keywords``.

        The score is the fraction of keywords found, scaled by 20 and capped
        at 1.0. Keywords are matched with ``_has_term`` (word-boundary start),
        not raw substring search. An empty keyword set scores 0.0.
        """
        if not keywords:
            return 0.0
        hits = sum(1 for kw in keywords if self._has_term(ql, kw))
        return min(hits / len(keywords) * 20, 1.0)  # normalize

    # ------------------------------------------------------------------
    # Complexity signals
    # ------------------------------------------------------------------
    def _complexity_signals(self, ql: str, f: QueryFeatures) -> None:
        """Set the ``multi_step``, ``requires_*`` and ``_expert_signal`` flags."""
        f.multi_step = any(
            re.search(pattern, ql) is not None for pattern in _MULTI_STEP_PATTERNS
        )
        f.requires_comparison = self._has_any_term(
            ql, ("compare", "contrast", "vs", "versus", "difference between", "pros and cons")
        )
        f.requires_generation = self._has_any_term(
            ql, ("write", "generate", "create", "build", "make", "implement", "design")
        )
        f.requires_explanation = self._has_any_term(
            ql, ("explain", "how does", "why does", "what is", "describe")
        )
        f.requires_analysis = self._has_any_term(
            ql, ("analyze", "evaluate", "assess", "review", "critique")
        )
        f.requires_debate = self._has_any_term(
            ql, ("debate", "argue", "advocate", "defend", "persuade")
        )
        # Expert-level signal: proofs, formal derivations, system design
        f._expert_signal = self._has_any_term(
            ql,
            (
                "prove", "proof", "theorem", "derive", "derivation",
                "from first principles", "consensus algorithm", "paxos", "raft",
                "distributed system", "compiler", "operating system",
                "quantum", "eigenvalue", "eigenvector",
            ),
        )

    # ------------------------------------------------------------------
    # Output length
    # ------------------------------------------------------------------
    def _output_length_estimate(self, ql: str, f: QueryFeatures) -> None:
        """Pick the ``estimated_output_length`` bucket for the query.

        Order of precedence: math proof/derivation override, then trigger
        phrases from ``_OUTPUT_LENGTH_MAP`` (longest bucket first), then a
        default based on the query's token count.
        """
        # Override for known long-output domains regardless of query length
        if f.domain_math and self._has_any_term(ql, ("prove", "derive", "theorem", "proof")):
            f.estimated_output_length = "long"
            return
        # Walk buckets from longest to shortest so the longest match wins.
        for bucket, triggers in reversed(list(_OUTPUT_LENGTH_MAP.items())):
            if self._has_any_term(ql, triggers):
                f.estimated_output_length = bucket
                return
        # Default based on token count. Short queries still default to
        # "medium": even trivial math needs explanation.
        f.estimated_output_length = "medium" if f.token_count < 30 else "long"

    # ------------------------------------------------------------------
    # Primary domain
    # ------------------------------------------------------------------
    def _primary_domain(self, f: QueryFeatures) -> None:
        """Set ``primary_domain`` to the highest-scoring domain.

        Falls back to "general" when there are no scores or the best score
        is zero. Ties resolve to the first domain in the dict's order.
        """
        scores = f.domain_scores
        if not scores:
            f.primary_domain = "general"
            return
        best = max(scores, key=scores.get)
        f.primary_domain = "general" if scores[best] == 0 else best