"""
Knowledge Universe — Request Orchestration (Blend Mode — John's Per-Crawler Timeout Fix)
JOHN'S FIX:
_crawl_with_timeout() now uses settings.get_crawler_timeout(crawler_name)
instead of the global settings.CRAWLER_TIMEOUT for all crawlers.
RICK'S FIX (Enterprise Audit):
Added Feature 6: Difficulty Drift Detection. Calculates standard deviation of
difficulty across returned sources and injects `difficulty_coherence` into
the coverage_intelligence block.
"""
import asyncio
import hashlib
import json
import time
import math
import logging
from typing import List, Dict, Any, Set, Optional
from collections import defaultdict
from src.scoring.diversity_filter import PLATFORM_QUALITY_FLOORS
from src.api.models import DiscoveryRequest, Source
from src.cache.redis_manager import RedisManager
from src.crawlers.crawler_pool import CrawlerPool
from src.scoring.ranker import UnifiedRanker
from src.scoring.diversity_filter import DiversityFilter
from src.scoring.coverage_confidence import CoverageConfidenceScorer
from src.integrations.local_llm_reranker import LocalLLMReranker
from config.settings import get_settings
settings = get_settings()
logger = logging.getLogger(__name__)
_SOURCE_FIELDS: Set[str] = {
"id", "title", "authors", "quality_score", "pedagogical_fit", "difficulty",
"links", "formats", "summary", "prerequisites", "tags", "language",
"citation_count", "peer_reviewed", "open_access", "publication_date",
"last_updated", "views", "likes", "rating", "stars", "forks", "downloads",
"source_platform", "thumbnail_url", "url",
"duration_seconds", "file_size_bytes", "page_count", "kernel_type",
"dataset_rows", "dataset_cols", "license",
"decay_report", "_ranking_signals",
"retraction_status", "related_sources", # Added the two new enterprise features
}
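# Only keys in _SOURCE_FIELDS survive _sanitize_for_source() below; e.g. a raw
# crawler dict carrying extra internal keys (illustrative names) such as
# "_crawler_debug" is reduced to this whitelisted subset before Source(**...) is built.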
_DATASET_PLATFORMS: Set[str] = {
"huggingface", "kaggle", "crossref",
"paperswithcode", "semantic_scholar", "documentation",
"distill", "observablehq", "sketchfab", "freesound", "wolfram",
}
_PLATFORM_PRIMARY_FORMAT: Dict[str, str] = {
"github": "github",
"gharchive": "github",
"kaggle": "kaggle",
"youtube": "video",
"arxiv": "pdf",
"stackoverflow": "stackoverflow",
"wikipedia": "html",
"openlibrary": "epub",
"huggingface": "dataset",
"mit_ocw": "html",
"podcast": "podcast",
"common_crawl": "html",
"paperswithcode": "pdf",
"documentation": "html",
"distill": "html",
"observablehq": "sandbox",
"crossref": "pdf",
"sketchfab": "3d_model",
"freesound": "audio",
"wolfram": "simulation",
}
_MAX_PER_PLATFORM_DEFAULT = 2
_MAX_PER_PLATFORM_ADVANCED = 3  # difficulty >= 4: each platform may contribute up to 3 results
_FLOOR_MAX_QUALITY = 1.5
_FLOOR_MIN_DECAY = 0.70
_DIFFICULTY_HARD_CEILING = 2
_MIN_RESULTS_BEFORE_RELAX = 3
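# Illustration of the difficulty window these constants imply (integer levels):
# for request.difficulty = 3 and _DIFFICULTY_HARD_CEILING = 2, sources rated
# 1..5 pass the ceiling (Wikipedia bypasses it entirely); if fewer than
# _MIN_RESULTS_BEFORE_RELAX (3) survive, the window widens by one level to 0..6.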
def _sanitize_for_source(src: Dict[str, Any]) -> Dict[str, Any]:
clean = {k: v for k, v in src.items() if k in _SOURCE_FIELDS}
clean.setdefault("quality_score", 5.0)
clean.setdefault("pedagogical_fit", 0.5)
from src.api.models import SourceFormat
valid_formats = {f.value for f in SourceFormat}
if "formats" in clean and isinstance(clean["formats"], list):
clean["formats"] = [
f for f in clean["formats"]
if f in valid_formats
]
if not clean["formats"]:
platform = clean.get("source_platform", "")
fallback = {
"github": "github", "youtube": "video",
"arxiv": "pdf", "stackoverflow": "html",
"wikipedia": "html", "huggingface": "dataset",
"kaggle": "kaggle", "podcast": "podcast",
"mit_ocw": "html", "openlibrary": "epub",
}
clean["formats"] = [fallback.get(platform, "html")]
for date_field in ("publication_date", "last_updated"):
val = clean.get(date_field)
if isinstance(val, str) and len(val) == 10 and "T" not in val:
clean[date_field] = val + "T00:00:00"
return clean
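# Illustration of the date normalization above: a bare "2023-05-01" becomes
# "2023-05-01T00:00:00", while values that already carry a time component pass
# through unchanged, so the Source model's datetime fields should parse cleanly.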
class RequestOrchestrator:
def __init__(self, redis_manager: RedisManager):
self.redis = redis_manager
self.crawler_pool = CrawlerPool()
self.ranker = UnifiedRanker()
self.diversity_filter = DiversityFilter()
self.llm_reranker = LocalLLMReranker()
self.confidence_scorer = CoverageConfidenceScorer()
self.was_cache_hit = False
self.processing_time_ms = 0.0
self.coverage_intelligence: Dict = {}
self._in_flight_requests: Dict[str, asyncio.Future] = {}
async def handle_request(self, request: DiscoveryRequest) -> List[Source]:
start_time = time.time()
cache_key = self._generate_cache_key(request)
cached = await self.redis.get_json(cache_key)
if cached and not self._is_stale(cached):
self.was_cache_hit = True
self.processing_time_ms = (time.time() - start_time) * 1000
self.coverage_intelligence = cached.get("coverage_intelligence", {})
sources = []
for s in cached["sources"]:
try:
sources.append(Source(**_sanitize_for_source(s)))
except Exception as e:
logger.warning(
f"Cache hit: bad source skipped id={s.get('id','?')} "
f"platform={s.get('source_platform','?')} error={e}"
)
return sources
self.was_cache_hit = False
if cache_key in self._in_flight_requests:
return await self._in_flight_requests[cache_key]
loop = asyncio.get_running_loop()
future = loop.create_future()
self._in_flight_requests[cache_key] = future
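# Single-flight guard: identical requests that arrive while this future is
# pending await its result via the check above instead of re-running the
# crawl/rank pipeline.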
try:
sources = await self._execute_pipeline(request)
self.processing_time_ms = (time.time() - start_time) * 1000
await self._cache_result(cache_key, sources)
future.set_result(sources)
return sources
except Exception as exc:
future.set_exception(exc)
raise
finally:
self._in_flight_requests.pop(cache_key, None)
async def _execute_pipeline(self, request: DiscoveryRequest) -> List[Source]:
raw_sources = await self._parallel_crawl(request)
if not raw_sources:
self.coverage_intelligence = self.confidence_scorer._no_results_response(request.topic)
return []
seen_ids: Set[str] = set()
unique_raw: List[Dict[str, Any]] = []
for raw in raw_sources:
src_id = str(raw.get("id", ""))
if src_id and src_id in seen_ids:
continue
if src_id:
seen_ids.add(src_id)
unique_raw.append(raw)
normalized = []
for raw in unique_raw:
src = dict(raw)
src = self._enforce_platform_formats(src)
src = self._normalize_formats(src)
src = self._ensure_links(src)
normalized.append(src)
filtered = self._filter_by_requested_formats(normalized, request)
filtered = self._semantic_prefilter(filtered, request.topic, threshold=0.25)
if not filtered:
return []
# ── RANKING & ENRICHMENT PIPELINE ──
try:
# Feature 7: Pass customer context for half-life overrides
customer = getattr(request, "state", {}).get("customer")
scored = self.ranker.rank_batch(filtered, request, customer=customer)
except Exception as e:
logger.error(f"Ranking failed: {e}")
scored = filtered
# Feature 5: Retraction status enrichment
try:
from src.crawlers.retraction_checker import enrich_sources_with_retraction_status
scored = await enrich_sources_with_retraction_status(scored)
except ImportError:
pass
except Exception as _re:
logger.warning(f"Retraction enrichment non-fatal: {_re}")
# Feature 9: Citation graph enrichment
try:
from src.crawlers.citation_graph import enrich_sources_with_citation_graph
scored = await enrich_sources_with_citation_graph(scored)
except ImportError:
pass
except Exception as _ce:
logger.warning(f"Citation graph enrichment non-fatal: {_ce}")
pre_floor = len(scored)
scored = [
src for src in scored
if not (
src.get("quality_score", 0) < _FLOOR_MAX_QUALITY
and src.get("decay_report", {}).get("decay_score", 0) > _FLOOR_MIN_DECAY
)
]
if len(scored) < pre_floor:
logger.info(f"Hard floor removed {pre_floor - len(scored)} sources")
req_difficulty = request.difficulty
within_ceiling = [
src for src in scored
if abs(int(src.get("difficulty", req_difficulty)) - req_difficulty)
<= _DIFFICULTY_HARD_CEILING
or src.get("source_platform") == "wikipedia"
]
if len(within_ceiling) >= _MIN_RESULTS_BEFORE_RELAX:
scored = within_ceiling
else:
relaxed = _DIFFICULTY_HARD_CEILING + 1
scored = [
src for src in scored
if abs(int(src.get("difficulty", req_difficulty)) - req_difficulty)
<= relaxed
]
if not scored:
scored = filtered
def _diff_sort_key(src: Dict) -> float:
try:
gap = abs(int(src.get("difficulty", req_difficulty)) - req_difficulty)
except (ValueError, TypeError):
gap = 0
return src.get("quality_score", 0) - (gap * 2.0)
scored.sort(key=_diff_sort_key, reverse=True)
unique = await self.diversity_filter.filter(scored)
if not unique:
return []
if req_difficulty <= 2:
learning_platforms = [
"wikipedia", "youtube", "arxiv", "github", "stackoverflow",
"kaggle", "huggingface", "mit_ocw", "openlibrary",
"podcast", "common_crawl",
]
elif req_difficulty >= 4:
learning_platforms = [
"arxiv", "github", "stackoverflow", "youtube",
"wikipedia", "kaggle", "huggingface", "mit_ocw",
"openlibrary", "podcast", "common_crawl",
]
else:
learning_platforms = [
"youtube", "arxiv", "github", "stackoverflow",
"wikipedia", "kaggle", "huggingface", "mit_ocw",
"openlibrary", "podcast", "common_crawl",
]
buckets: Dict[str, List] = defaultdict(list)
for src in unique:
buckets[src.get("source_platform", "unknown")].append(src)
guaranteed: List[Dict[str, Any]] = []
used_obj_ids: Set[int] = set()
for platform in learning_platforms:
if buckets.get(platform):
floor = PLATFORM_QUALITY_FLOORS.get(platform, 0)
eligible = [
s for s in buckets[platform]
if s.get("quality_score", 0) >= floor
]
if eligible:
best = max(eligible, key=lambda x: x.get("quality_score", 0))
guaranteed.append(best)
used_obj_ids.add(id(best))
remaining = [
src for src in unique
if id(src) not in used_obj_ids
and src.get("quality_score", 0) >= (settings.MIN_QUALITY_SCORE - 2.0)
]
remaining.sort(key=lambda x: x.get("quality_score", 0), reverse=True)
platform_counts: Dict[str, int] = defaultdict(int)
capped: List[Dict[str, Any]] = []
for src in guaranteed:
capped.append(src)
platform_counts[src.get("source_platform", "unknown")] += 1
max_per_plat = (
_MAX_PER_PLATFORM_ADVANCED
if req_difficulty >= 4
else _MAX_PER_PLATFORM_DEFAULT
)
for src in remaining:
plat = src.get("source_platform", "unknown")
if platform_counts[plat] < max_per_plat:
capped.append(src)
platform_counts[plat] += 1
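# Example of the per-platform cap: at default difficulty a third arxiv result
# is skipped here even if it outscores everything left in `remaining`; at
# difficulty >= 4 each platform may contribute up to three results.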
combined = capped
# Step 8 — Rerank with shared embeddings
query_emb = None
doc_embs = None
try:
reranked, query_emb, doc_embs = self.llm_reranker.rerank_with_embeddings(
request.topic,
combined,
requested_difficulty=request.difficulty,
)
except Exception as e:
logger.warning(f"LLM rerank failed: {e}")
reranked = combined
# Re-enforce min_freshness AFTER reranking (reranker can reorder past the pre-filter)
if getattr(request, "min_freshness", None) is not None:
from src.scoring.decay_engine import KnowledgeDecayEngine
_de = KnowledgeDecayEngine()
reranked = [
src for src in reranked
if _de.compute_from_dict(src).freshness >= request.min_freshness
]
if not reranked:
logger.warning(
f"min_freshness={request.min_freshness} dropped ALL post-rerank results. "
f"Returning empty — client should broaden freshness threshold."
)
# Step 8b — Coverage confidence (fast path, shared embeddings)
try:
self.coverage_intelligence = self.confidence_scorer.compute_from_embeddings(
query=request.topic,
sources=reranked,
query_emb=query_emb,
doc_embs=doc_embs,
top_k=5,
)
except Exception as e:
logger.warning(f"Confidence scoring failed: {e}")
self.coverage_intelligence = self.confidence_scorer._unavailable_response()
results: List[Source] = []
for src in reranked[:request.max_results]:
try:
if "_ranking_signals" in src:
align = src["_ranking_signals"].get("difficulty_alignment", 5.0)
src["pedagogical_fit"] = round(align / 10.0, 3)
clean = _sanitize_for_source(src)
results.append(Source(**clean))
except Exception as e:
logger.error(f"Source({src.get('id')}) skipped: {e}")
# ── RICK'S FIX: Feature 6 - Difficulty Drift Detection ──
if results and isinstance(self.coverage_intelligence, dict):
difficulties = [src.difficulty for src in results if getattr(src, "difficulty", None) is not None]
if len(difficulties) > 1:
mean_diff = sum(difficulties) / len(difficulties)
variance = sum((x - mean_diff) ** 2 for x in difficulties) / len(difficulties)
std_dev = math.sqrt(variance)
# Map std_dev (typically 0.0 to 2.0) to a 1.0 to 0.0 coherence score
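# Worked example: difficulties [2, 3, 4] give std_dev ≈ 0.816 and hence
# coherence ≈ 0.592; a uniform set scores 1.0, and a spread such as [1, 5]
# (std_dev = 2.0) floors the score at 0.0.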
coherence = max(0.0, round(1.0 - (std_dev / 2.0), 3))
else:
coherence = 1.0
self.coverage_intelligence["difficulty_coherence"] = coherence
# ────────────────────────────────────────────────────────
return results
async def _parallel_crawl(self, request: DiscoveryRequest) -> List[Dict[str, Any]]:
crawlers = self.crawler_pool.get_active_crawlers()
tasks = [
asyncio.create_task(
self._crawl_with_timeout(crawler, request.topic, request.difficulty)
)
for crawler in crawlers
]
results = await asyncio.gather(*tasks, return_exceptions=True)
all_sources: List[Dict[str, Any]] = []
for result in results:
if isinstance(result, list):
all_sources.extend(result)
elif isinstance(result, Exception):
logger.debug(f"Crawler exception: {result}")
return all_sources
async def _crawl_with_timeout(
self, crawler, topic: str, difficulty: int
) -> List[Dict]:
crawler_name = crawler.__class__.__name__
timeout = settings.get_crawler_timeout(crawler_name)
start = time.time()
try:
result = await asyncio.wait_for(
crawler.crawl(topic, difficulty),
timeout=timeout,
)
elapsed = round((time.time() - start) * 1000)
if elapsed > 3000:
logger.info(
f"{crawler_name} slow: {elapsed}ms, "
f"{len(result)} results (timeout={timeout}s)"
)
return result
except asyncio.TimeoutError:
elapsed = round((time.time() - start) * 1000)
logger.info(
f"{crawler_name} TIMEOUT after {elapsed}ms "
f"(limit={timeout}s)"
)
return []
except Exception as e:
logger.debug(f"{crawler_name} failed: {e}")
return []
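# get_crawler_timeout() is defined in config.settings; a minimal sketch of how
# such a per-crawler override might be implemented there (field names
# hypothetical, the real settings class may differ):
#
#     CRAWLER_TIMEOUTS: Dict[str, float] = {"ArxivCrawler": 8.0}
#
#     def get_crawler_timeout(self, crawler_name: str) -> float:
#         return self.CRAWLER_TIMEOUTS.get(crawler_name, self.CRAWLER_TIMEOUT)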
def _enforce_platform_formats(self, src: Dict) -> Dict:
platform = src.get("source_platform", "")
formats = list(src.get("formats") or [])
if platform in ("github", "gharchive"):
formats = [f for f in formats if f not in ("dataset", "repo")]
if not formats or "github" not in formats:
formats = ["github"]
if platform not in _DATASET_PLATFORMS and "dataset" in formats:
formats = [f for f in formats if f != "dataset"]
primary = _PLATFORM_PRIMARY_FORMAT.get(platform)
if primary and primary not in formats:
formats = [primary] + formats
src["formats"] = formats if formats else [
_PLATFORM_PRIMARY_FORMAT.get(platform, "html")
]
return src
def _filter_by_requested_formats(
self, sources: List[Dict], request: DiscoveryRequest
) -> List[Dict]:
requested = {f.value for f in request.formats}
filtered = []
for src in sources:
formats = set(src.get("formats") or [])
if not formats:
platform = src.get("source_platform", "")
fallback = _PLATFORM_PRIMARY_FORMAT.get(platform, "html")
formats = {fallback}
src["formats"] = list(formats)
if formats & requested:
filtered.append(src)
return filtered
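# Example of the format intersection above: a request asking for
# {"pdf", "video"} keeps an arxiv source with formats ["pdf"] and drops a
# wikipedia source whose only format is "html".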
# Title-only fast prefilter: cheap keyword screen first, semantic encode only when needed.
def _semantic_prefilter(
self,
sources: List[Dict],
query: str,
threshold: float = 0.25,
) -> List[Dict]:
if not sources:
return sources
# Fast keyword prefilter first — eliminates obvious junk without model call
query_words = {w.lower() for w in query.split() if len(w) > 3}
if query_words:
keyword_filtered = []
for src in sources:
title = (src.get("title", "") or "").lower()
summary = (src.get("summary", "") or "")[:100].lower()
tags = " ".join(src.get("tags", []) or []).lower()
combined = f"{title} {summary} {tags}"
if any(w in combined for w in query_words):
keyword_filtered.append(src)
# Only run expensive semantic filter if keyword filter kept too many
if len(keyword_filtered) <= 25:
return keyword_filtered
sources = keyword_filtered
# Semantic filter only when needed (>25 sources remaining)
try:
from src.integrations.shared_model import get_shared_model
from sentence_transformers import util
model = get_shared_model()
query_emb = model.encode(query, convert_to_tensor=True)
# Encode only titles (4x faster than title+summary)
texts = [s.get('title', '') for s in sources]
doc_embs = model.encode(texts, convert_to_tensor=True)
sims = util.cos_sim(query_emb, doc_embs)[0]
filtered = []
for src, sim in zip(sources, sims):
score = float(sim)
if score >= threshold:
filtered.append(src)
else:
logger.debug(
f"Semantic pre-filter dropped: sim={score:.3f} "
f"title='{src.get('title','')[:50]}'"
)
dropped = len(sources) - len(filtered)
if dropped > 0:
logger.info(f"Semantic pre-filter: dropped {dropped} irrelevant sources")
return filtered
except Exception as e:
logger.warning(f"Semantic pre-filter failed (non-fatal): {e}")
return sources
def _ensure_links(self, src: Dict) -> Dict:
if not src.get("links"):
fmt = (src.get("formats") or ["html"])[0]
src["links"] = [{
"type": fmt, "url": src.get("url", ""),
"format": fmt, "size_bytes": None, "access_method": "direct",
}]
return src
def _normalize_formats(self, src: Dict) -> Dict:
formats = set(src.get("formats") or [])
alias_map = {"repository": "github", "repo": "github",
"question": "html", "code": "html"}
for link in src.get("links") or []:
if link.get("format"):
formats.add(link["format"])
src["formats"] = [alias_map.get(f, f) for f in formats]
return src
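# Example of the alias normalization above: a source reporting formats
# ["repository"] with a single link tagged "pdf" comes out with formats
# containing "github" and "pdf".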
def _generate_cache_key(self, request: DiscoveryRequest) -> str:
payload = {
"topic": request.topic.lower().strip(),
"difficulty": request.difficulty,
"formats": sorted(f.value for f in request.formats),
"language": request.language,
"max_results": request.max_results,
}
return "req:" + hashlib.sha256(
json.dumps(payload, sort_keys=True).encode()
).hexdigest()
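# The resulting key looks like "req:<64-char sha256 hex digest>"; sort_keys=True
# in json.dumps keeps it stable regardless of dict insertion order, so
# semantically identical requests hit the same cache entry.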
# Strip embedding vectors before caching to keep Redis payloads small; they are regenerated on demand.
async def _cache_result(self, key: str, sources: List[Source]) -> None:
try:
def _strip_embeddings(source_dict: dict) -> dict:
"""Remove embedding vectors from cache — they're regenerated on demand."""
source_dict.pop("embedding", None)
return source_dict
await self.redis.set_json(
key,
{
"sources": [
_strip_embeddings(s.model_dump(mode="json"))
for s in sources
],
"coverage_intelligence": self.coverage_intelligence,
"timestamp": time.time(),
},
ttl=settings.CACHE_TTL_REQUEST,
)
except Exception as e:
logger.error(f"Cache write failed: {e}")
def _is_stale(self, cached: Dict) -> bool:
age = time.time() - cached.get("timestamp", 0)
return age > settings.CACHE_TTL_REQUEST
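# Typical call pattern (a sketch; the real API wiring elsewhere may differ):
#
#     orchestrator = RequestOrchestrator(redis_manager)
#     sources = await orchestrator.handle_request(discovery_request)
#     meta = orchestrator.coverage_intelligence   # may include difficulty_coherence
#     took_ms = orchestrator.processing_time_ms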