| """ | |
| Knowledge Universe — Request Orchestration (Blend Mode — John's Per-Crawler Timeout Fix) | |
| JOHN'S FIX: | |
| _crawl_with_timeout() now uses settings.get_crawler_timeout(crawler_name) | |
| instead of the global settings.CRAWLER_TIMEOUT for all crawlers. | |
| RICK'S FIX (Enterprise Audit): | |
| Added Feature 6: Difficulty Drift Detection. Calculates standard deviation of | |
| difficulty across returned sources and injects `difficulty_coherence` into | |
| the coverage_intelligence block. | |
| """ | |
import asyncio
import hashlib
import json
import time
import math
import logging
from typing import List, Dict, Any, Set, Optional
from collections import defaultdict

from src.scoring.diversity_filter import PLATFORM_QUALITY_FLOORS
from src.api.models import DiscoveryRequest, Source
from src.cache.redis_manager import RedisManager
from src.crawlers.crawler_pool import CrawlerPool
from src.scoring.ranker import UnifiedRanker
from src.scoring.diversity_filter import DiversityFilter
from src.scoring.coverage_confidence import CoverageConfidenceScorer
from src.integrations.local_llm_reranker import LocalLLMReranker
from config.settings import get_settings

settings = get_settings()
logger = logging.getLogger(__name__)
_SOURCE_FIELDS: Set[str] = {
    "id", "title", "authors", "quality_score", "pedagogical_fit", "difficulty",
    "links", "formats", "summary", "prerequisites", "tags", "language",
    "citation_count", "peer_reviewed", "open_access", "publication_date",
    "last_updated", "views", "likes", "rating", "stars", "forks", "downloads",
    "source_platform", "thumbnail_url", "url",
    "duration_seconds", "file_size_bytes", "page_count", "kernel_type",
    "dataset_rows", "dataset_cols", "license",
    "decay_report", "_ranking_signals",
    "retraction_status", "related_sources",  # Added the two new enterprise features
}

_DATASET_PLATFORMS: Set[str] = {
    "huggingface", "kaggle", "crossref",
    "paperswithcode", "semantic_scholar", "documentation",
    "distill", "observablehq", "sketchfab", "freesound", "wolfram",
}
_PLATFORM_PRIMARY_FORMAT: Dict[str, str] = {
    "github": "github",
    "gharchive": "github",
    "kaggle": "kaggle",
    "youtube": "video",
    "arxiv": "pdf",
    "stackoverflow": "stackoverflow",
    "wikipedia": "html",
    "openlibrary": "epub",
    "huggingface": "dataset",
    "mit_ocw": "html",
    "podcast": "podcast",
    "common_crawl": "html",
    "paperswithcode": "pdf",
    "documentation": "html",
    "distill": "html",
    "observablehq": "sandbox",
    "crossref": "pdf",
    "sketchfab": "3d_model",
    "freesound": "audio",
    "wolfram": "simulation",
}
_MAX_PER_PLATFORM_DEFAULT = 2
_MAX_PER_PLATFORM_ADVANCED = 3  # difficulty >= 4: arXiv/GitHub get 3 slots
_FLOOR_MAX_QUALITY = 1.5
_FLOOR_MIN_DECAY = 0.70
_DIFFICULTY_HARD_CEILING = 2
_MIN_RESULTS_BEFORE_RELAX = 3
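# How the floor/ceiling constants above are applied in _execute_pipeline:
#   - Hard floor: a source is dropped only when BOTH quality_score < _FLOOR_MAX_QUALITY (1.5)
#     AND its decay_score exceeds _FLOOR_MIN_DECAY (0.70), i.e. low quality that is also stale.
#   - Difficulty ceiling: sources more than _DIFFICULTY_HARD_CEILING (2) levels from the
#     requested difficulty are cut; the ceiling relaxes to 3 when fewer than
#     _MIN_RESULTS_BEFORE_RELAX (3) results would survive. Wikipedia is exempt from the ceiling.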
def _sanitize_for_source(src: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only fields the Source model accepts and patch obviously bad values."""
    clean = {k: v for k, v in src.items() if k in _SOURCE_FIELDS}
    clean.setdefault("quality_score", 5.0)
    clean.setdefault("pedagogical_fit", 0.5)

    from src.api.models import SourceFormat
    valid_formats = {f.value for f in SourceFormat}
    if "formats" in clean and isinstance(clean["formats"], list):
        clean["formats"] = [
            f for f in clean["formats"]
            if f in valid_formats
        ]
        if not clean["formats"]:
            platform = clean.get("source_platform", "")
            fallback = {
                "github": "github", "youtube": "video",
                "arxiv": "pdf", "stackoverflow": "html",
                "wikipedia": "html", "huggingface": "dataset",
                "kaggle": "kaggle", "podcast": "podcast",
                "mit_ocw": "html", "openlibrary": "epub",
            }
            clean["formats"] = [fallback.get(platform, "html")]

    # Source expects full ISO timestamps; pad date-only strings.
    for date_field in ("publication_date", "last_updated"):
        val = clean.get(date_field)
        if isinstance(val, str) and len(val) == 10 and "T" not in val:
            clean[date_field] = val + "T00:00:00"
    return clean
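# Example of the sanitation above: a crawler date of "2024-03-07" becomes "2024-03-07T00:00:00"
# so it parses as a datetime, and an unrecognised format list falls back to the platform default.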
class RequestOrchestrator:
    """Coordinates the crawl, rank, filter, rerank, and respond pipeline for a discovery request."""

    def __init__(self, redis_manager: RedisManager):
        self.redis = redis_manager
        self.crawler_pool = CrawlerPool()
        self.ranker = UnifiedRanker()
        self.diversity_filter = DiversityFilter()
        self.llm_reranker = LocalLLMReranker()
        self.confidence_scorer = CoverageConfidenceScorer()
        self.was_cache_hit = False
        self.processing_time_ms = 0.0
        self.coverage_intelligence: Dict = {}
        self._in_flight_requests: Dict[str, asyncio.Future] = {}
    async def handle_request(self, request: DiscoveryRequest) -> List[Source]:
        start_time = time.time()
        cache_key = self._generate_cache_key(request)

        cached = await self.redis.get_json(cache_key)
        if cached and not self._is_stale(cached):
            self.was_cache_hit = True
            self.processing_time_ms = (time.time() - start_time) * 1000
            self.coverage_intelligence = cached.get("coverage_intelligence", {})
            sources = []
            for s in cached["sources"]:
                try:
                    sources.append(Source(**_sanitize_for_source(s)))
                except Exception as e:
                    logger.warning(
                        f"Cache hit: bad source skipped id={s.get('id','?')} "
                        f"platform={s.get('source_platform','?')} error={e}"
                    )
            return sources

        self.was_cache_hit = False
        if cache_key in self._in_flight_requests:
            return await self._in_flight_requests[cache_key]

        loop = asyncio.get_running_loop()
        future = loop.create_future()
        self._in_flight_requests[cache_key] = future
        try:
            sources = await self._execute_pipeline(request)
            self.processing_time_ms = (time.time() - start_time) * 1000
            await self._cache_result(cache_key, sources)
            future.set_result(sources)
            return sources
        except Exception as exc:
            future.set_exception(exc)
            raise
        finally:
            self._in_flight_requests.pop(cache_key, None)
    async def _execute_pipeline(self, request: DiscoveryRequest) -> List[Source]:
        raw_sources = await self._parallel_crawl(request)
        if not raw_sources:
            self.coverage_intelligence = self.confidence_scorer._no_results_response(request.topic)
            return []

        seen_ids: Set[str] = set()
        unique_raw: List[Dict[str, Any]] = []
        for raw in raw_sources:
            src_id = str(raw.get("id", ""))
            if src_id and src_id in seen_ids:
                continue
            if src_id:
                seen_ids.add(src_id)
            unique_raw.append(raw)

        normalized = []
        for raw in unique_raw:
            src = dict(raw)
            src = self._enforce_platform_formats(src)
            src = self._normalize_formats(src)
            src = self._ensure_links(src)
            normalized.append(src)

        filtered = self._filter_by_requested_formats(normalized, request)
        filtered = self._semantic_prefilter(filtered, request.topic, threshold=0.25)
        if not filtered:
            return []
        # ── RANKING & ENRICHMENT PIPELINE ──
        try:
            # Feature 7: Pass customer context for half-life overrides
            customer = getattr(request, "state", {}).get("customer")
            scored = self.ranker.rank_batch(filtered, request, customer=customer)
        except Exception as e:
            logger.error(f"Ranking failed: {e}")
            scored = filtered

        # Feature 5: Retraction status enrichment
        try:
            from src.crawlers.retraction_checker import enrich_sources_with_retraction_status
            scored = await enrich_sources_with_retraction_status(scored)
        except ImportError:
            pass
        except Exception as _re:
            logger.warning(f"Retraction enrichment non-fatal: {_re}")

        # Feature 9: Citation graph enrichment
        try:
            from src.crawlers.citation_graph import enrich_sources_with_citation_graph
            scored = await enrich_sources_with_citation_graph(scored)
        except ImportError:
            pass
        except Exception as _ce:
            logger.warning(f"Citation graph enrichment non-fatal: {_ce}")
        pre_floor = len(scored)
        scored = [
            src for src in scored
            if not (
                src.get("quality_score", 0) < _FLOOR_MAX_QUALITY
                and src.get("decay_report", {}).get("decay_score", 0) > _FLOOR_MIN_DECAY
            )
        ]
        if len(scored) < pre_floor:
            logger.info(f"Hard floor removed {pre_floor - len(scored)} sources")

        req_difficulty = request.difficulty
        within_ceiling = [
            src for src in scored
            if abs(int(src.get("difficulty", req_difficulty)) - req_difficulty)
            <= _DIFFICULTY_HARD_CEILING
            or src.get("source_platform") == "wikipedia"
        ]
        if len(within_ceiling) >= _MIN_RESULTS_BEFORE_RELAX:
            scored = within_ceiling
        else:
            relaxed = _DIFFICULTY_HARD_CEILING + 1
            scored = [
                src for src in scored
                if abs(int(src.get("difficulty", req_difficulty)) - req_difficulty)
                <= relaxed
            ]
            if not scored:
                scored = filtered

        def _diff_sort_key(src: Dict) -> float:
            try:
                gap = abs(int(src.get("difficulty", req_difficulty)) - req_difficulty)
            except (ValueError, TypeError):
                gap = 0
            return src.get("quality_score", 0) - (gap * 2.0)

        scored.sort(key=_diff_sort_key, reverse=True)

        unique = await self.diversity_filter.filter(scored)
        if not unique:
            return []
        if req_difficulty <= 2:
            learning_platforms = [
                "wikipedia", "youtube", "arxiv", "github", "stackoverflow",
                "kaggle", "huggingface", "mit_ocw", "openlibrary",
                "podcast", "common_crawl",
            ]
        elif req_difficulty >= 4:
            learning_platforms = [
                "arxiv", "github", "stackoverflow", "youtube",
                "wikipedia", "kaggle", "huggingface", "mit_ocw",
                "openlibrary", "podcast", "common_crawl",
            ]
        else:
            learning_platforms = [
                "youtube", "arxiv", "github", "stackoverflow",
                "wikipedia", "kaggle", "huggingface", "mit_ocw",
                "openlibrary", "podcast", "common_crawl",
            ]

        buckets: Dict[str, List] = defaultdict(list)
        for src in unique:
            buckets[src.get("source_platform", "unknown")].append(src)

        guaranteed: List[Dict[str, Any]] = []
        used_obj_ids: Set[int] = set()
        for platform in learning_platforms:
            if buckets.get(platform):
                floor = PLATFORM_QUALITY_FLOORS.get(platform, 0)
                eligible = [
                    s for s in buckets[platform]
                    if s.get("quality_score", 0) >= floor
                ]
                if eligible:
                    best = max(eligible, key=lambda x: x.get("quality_score", 0))
                    guaranteed.append(best)
                    used_obj_ids.add(id(best))

        remaining = [
            src for src in unique
            if id(src) not in used_obj_ids
            and src.get("quality_score", 0) >= (settings.MIN_QUALITY_SCORE - 2.0)
        ]
        remaining.sort(key=lambda x: x.get("quality_score", 0), reverse=True)

        platform_counts: Dict[str, int] = defaultdict(int)
        capped: List[Dict[str, Any]] = []
        for src in guaranteed:
            capped.append(src)
            platform_counts[src.get("source_platform", "unknown")] += 1

        max_per_plat = (
            _MAX_PER_PLATFORM_ADVANCED
            if req_difficulty >= 4
            else _MAX_PER_PLATFORM_DEFAULT
        )
        for src in remaining:
            plat = src.get("source_platform", "unknown")
            if platform_counts[plat] < max_per_plat:
                capped.append(src)
                platform_counts[plat] += 1

        combined = capped
        # Step 8 — Rerank with shared embeddings
        query_emb = None
        doc_embs = None
        try:
            reranked, query_emb, doc_embs = self.llm_reranker.rerank_with_embeddings(
                request.topic,
                combined,
                requested_difficulty=request.difficulty,
            )
        except Exception as e:
            logger.warning(f"LLM rerank failed: {e}")
            reranked = combined

        # Re-enforce min_freshness AFTER reranking (the reranker can reorder past the pre-filter)
        if getattr(request, "min_freshness", None) is not None:
            from src.scoring.decay_engine import KnowledgeDecayEngine
            _de = KnowledgeDecayEngine()
            reranked = [
                src for src in reranked
                if _de.compute_from_dict(src).freshness >= request.min_freshness
            ]
            if not reranked:
                logger.warning(
                    f"min_freshness={request.min_freshness} dropped ALL post-rerank results. "
                    f"Returning empty — client should broaden freshness threshold."
                )

        # Step 8b — Coverage confidence (fast path, shared embeddings)
        try:
            self.coverage_intelligence = self.confidence_scorer.compute_from_embeddings(
                query=request.topic,
                sources=reranked,
                query_emb=query_emb,
                doc_embs=doc_embs,
                top_k=5,
            )
        except Exception as e:
            logger.warning(f"Confidence scoring failed: {e}")
            self.coverage_intelligence = self.confidence_scorer._unavailable_response()
        results: List[Source] = []
        for src in reranked[:request.max_results]:
            try:
                if "_ranking_signals" in src:
                    align = src["_ranking_signals"].get("difficulty_alignment", 5.0)
                    src["pedagogical_fit"] = round(align / 10.0, 3)
                clean = _sanitize_for_source(src)
                results.append(Source(**clean))
            except Exception as e:
                logger.error(f"Source({src.get('id')}) skipped: {e}")

        # ── RICK'S FIX: Feature 6 - Difficulty Drift Detection ──
        if results and isinstance(self.coverage_intelligence, dict):
            difficulties = [
                src.difficulty for src in results
                if getattr(src, "difficulty", None) is not None
            ]
            if len(difficulties) > 1:
                mean_diff = sum(difficulties) / len(difficulties)
                variance = sum((x - mean_diff) ** 2 for x in difficulties) / len(difficulties)
                std_dev = math.sqrt(variance)
                # Map std_dev (typically 0.0 to 2.0) to a 1.0 to 0.0 coherence score
                coherence = max(0.0, round(1.0 - (std_dev / 2.0), 3))
            else:
                coherence = 1.0
            self.coverage_intelligence["difficulty_coherence"] = coherence
        # ────────────────────────────────────────────────────────
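        # Worked example of the coherence mapping above: difficulties [2, 3, 3] give a mean of
        # ~2.67, a population std_dev of ~0.47, and difficulty_coherence = 1.0 - 0.47/2 ≈ 0.764.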
        return results
    async def _parallel_crawl(self, request: DiscoveryRequest) -> List[Dict[str, Any]]:
        crawlers = self.crawler_pool.get_active_crawlers()
        tasks = [
            asyncio.create_task(
                self._crawl_with_timeout(crawler, request.topic, request.difficulty)
            )
            for crawler in crawlers
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        all_sources: List[Dict[str, Any]] = []
        for result in results:
            if isinstance(result, list):
                all_sources.extend(result)
            elif isinstance(result, Exception):
                logger.debug(f"Crawler exception: {result}")
        return all_sources
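    # Each crawler above runs as its own task with its own per-crawler timeout (see
    # _crawl_with_timeout below); return_exceptions=True means a crashing crawler is
    # logged and skipped rather than failing the whole gather.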
    async def _crawl_with_timeout(
        self, crawler, topic: str, difficulty: int
    ) -> List[Dict]:
        crawler_name = crawler.__class__.__name__
        timeout = settings.get_crawler_timeout(crawler_name)
        start = time.time()
        try:
            result = await asyncio.wait_for(
                crawler.crawl(topic, difficulty),
                timeout=timeout,
            )
            elapsed = round((time.time() - start) * 1000)
            if elapsed > 3000:
                logger.info(
                    f"{crawler_name} slow: {elapsed}ms, "
                    f"{len(result)} results (timeout={timeout}s)"
                )
            return result
        except asyncio.TimeoutError:
            elapsed = round((time.time() - start) * 1000)
            logger.info(
                f"{crawler_name} TIMEOUT after {elapsed}ms "
                f"(limit={timeout}s)"
            )
            return []
        except Exception as e:
            logger.debug(f"{crawler_name} failed: {e}")
            return []
    def _enforce_platform_formats(self, src: Dict) -> Dict:
        platform = src.get("source_platform", "")
        formats = list(src.get("formats") or [])
        if platform in ("github", "gharchive"):
            formats = [f for f in formats if f not in ("dataset", "repo")]
            if not formats or "github" not in formats:
                formats = ["github"]
        if platform not in _DATASET_PLATFORMS and "dataset" in formats:
            formats = [f for f in formats if f != "dataset"]
        primary = _PLATFORM_PRIMARY_FORMAT.get(platform)
        if primary and primary not in formats:
            formats = [primary] + formats
        src["formats"] = formats if formats else [
            _PLATFORM_PRIMARY_FORMAT.get(platform, "html")
        ]
        return src
    def _filter_by_requested_formats(
        self, sources: List[Dict], request: DiscoveryRequest
    ) -> List[Dict]:
        requested = {f.value for f in request.formats}
        filtered = []
        for src in sources:
            formats = set(src.get("formats") or [])
            if not formats:
                platform = src.get("source_platform", "")
                fallback = _PLATFORM_PRIMARY_FORMAT.get(platform, "html")
                formats = {fallback}
                src["formats"] = list(formats)
            if formats & requested:
                filtered.append(src)
        return filtered
    # Title-only fast prefilter: keyword pass first, semantic pass only when needed (no full encode).
    def _semantic_prefilter(
        self,
        sources: List[Dict],
        query: str,
        threshold: float = 0.25,
    ) -> List[Dict]:
        if not sources:
            return sources

        # Fast keyword prefilter first — eliminates obvious junk without model call
        query_words = {w.lower() for w in query.split() if len(w) > 3}
        if query_words:
            keyword_filtered = []
            for src in sources:
                title = (src.get("title", "") or "").lower()
                summary = (src.get("summary", "") or "")[:100].lower()
                tags = " ".join(src.get("tags", []) or []).lower()
                combined = f"{title} {summary} {tags}"
                if any(w in combined for w in query_words):
                    keyword_filtered.append(src)
            # Only run expensive semantic filter if keyword filter kept too many
            if len(keyword_filtered) <= 25:
                return keyword_filtered
            sources = keyword_filtered

        # Semantic filter only when needed (>25 sources remaining)
        try:
            from src.integrations.shared_model import get_shared_model
            from sentence_transformers import util

            model = get_shared_model()
            query_emb = model.encode(query, convert_to_tensor=True)
            # Encode only titles (4x faster than title+summary)
            texts = [s.get('title', '') for s in sources]
            doc_embs = model.encode(texts, convert_to_tensor=True)
            sims = util.cos_sim(query_emb, doc_embs)[0]

            filtered = []
            for src, sim in zip(sources, sims):
                score = float(sim)
                if score >= threshold:
                    filtered.append(src)
                else:
                    logger.debug(
                        f"Semantic pre-filter dropped: sim={score:.3f} "
                        f"title='{src.get('title','')[:50]}'"
                    )
            dropped = len(sources) - len(filtered)
            if dropped > 0:
                logger.info(f"Semantic pre-filter: dropped {dropped} irrelevant sources")
            return filtered
        except Exception as e:
            logger.warning(f"Semantic pre-filter failed (non-fatal): {e}")
            return sources
    def _ensure_links(self, src: Dict) -> Dict:
        if not src.get("links"):
            fmt = (src.get("formats") or ["html"])[0]
            src["links"] = [{
                "type": fmt, "url": src.get("url", ""),
                "format": fmt, "size_bytes": None, "access_method": "direct",
            }]
        return src
    def _normalize_formats(self, src: Dict) -> Dict:
        formats = set(src.get("formats") or [])
        alias_map = {"repository": "github", "repo": "github",
                     "question": "html", "code": "html"}
        for link in src.get("links") or []:
            if link.get("format"):
                formats.add(link["format"])
        src["formats"] = [alias_map.get(f, f) for f in formats]
        return src
    def _generate_cache_key(self, request: DiscoveryRequest) -> str:
        payload = {
            "topic": request.topic.lower().strip(),
            "difficulty": request.difficulty,
            "formats": sorted(f.value for f in request.formats),
            "language": request.language,
            "max_results": request.max_results,
        }
        return "req:" + hashlib.sha256(
            json.dumps(payload, sort_keys=True).encode()
        ).hexdigest()
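    # The resulting key is "req:" plus a SHA-256 hex digest of the canonicalised request;
    # sort_keys=True keeps the digest stable regardless of field ordering, so identical
    # requests map to the same cache entry.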
    # Cache write: embedding vectors are stripped first and regenerated on demand.
    async def _cache_result(self, key: str, sources: List[Source]) -> None:
        try:
            def _strip_embeddings(source_dict: dict) -> dict:
                """Remove embedding vectors from cache — they're regenerated on demand."""
                source_dict.pop("embedding", None)
                return source_dict

            await self.redis.set_json(
                key,
                {
                    "sources": [
                        _strip_embeddings(s.model_dump(mode="json"))
                        for s in sources
                    ],
                    "coverage_intelligence": self.coverage_intelligence,
                    "timestamp": time.time(),
                },
                ttl=settings.CACHE_TTL_REQUEST,
            )
        except Exception as e:
            logger.error(f"Cache write failed: {e}")
    def _is_stale(self, cached: Dict) -> bool:
        age = time.time() - cached.get("timestamp", 0)
        return age > settings.CACHE_TTL_REQUEST