Spaces:
Running
Running
feat: hybrid RAG pipeline upgrade
Browse files- NewsAPI adapter: real-time news search (80k+ sources, fallback to DDG)
- Jina Reranker API: cloud reranker replacing slow CPU self-hosted (~1s vs 42s)
- Jina Reader: full article extraction from live search URLs
- Intent classifier v5: keyword pre-check layer (0ms for 80% of queries)
- Smart caching: 4-layer Redis cache with intent-aware TTLs
- Query orchestrator: NewsAPI-first live search with DDG fallback + Jina enhancement
- Redis adapter: layered cache methods (get/set intent, live, translation, response)
- Config: NEWSAPI_*, JINA_RERANKER_* settings added
- .env +23 -2
- src/api/dependencies.py +8 -0
- src/core/config.py +11 -0
- src/core/orchestrator/query_orchestrator.py +76 -54
- src/core/use_cases/rag_chat_use_case.py +47 -43
- src/infrastructure/adapters/bge_reranker_adapter.py +292 -64
- src/infrastructure/adapters/intent_classifier_v2.py +274 -152
- src/infrastructure/adapters/jina_reranker_adapter.py +161 -0
- src/infrastructure/adapters/newsapi_adapter.py +376 -0
- src/infrastructure/adapters/redis_adapter.py +223 -21
.env
CHANGED
|
@@ -29,7 +29,7 @@ POSTGRES_DB=rag_interactions
|
|
| 29 |
# --- Models configuration ---
|
| 30 |
EMBEDDING_MODEL=BAAI/bge-m3
|
| 31 |
VECTOR_SIZE=1024
|
| 32 |
-
RERANKER_MODEL=
|
| 33 |
|
| 34 |
# ==========================================
|
| 35 |
# LLM Provider — set LLM_PROVIDER to one of:
|
|
@@ -114,7 +114,28 @@ SEARXNG_MAX_RESULTS=10
|
|
| 114 |
# Free models: Llama 4, Qwen 3, DeepSeek, Gemma 3 and more
|
| 115 |
OPENROUTER_API_KEY=your-openrouter-api-key-here
|
| 116 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
# --- Jina AI Reader (Full Article Extraction) ---
|
| 118 |
# Get free key: https://jina.ai (1M tokens/month free)
|
| 119 |
# Without key: most news sites return 401 Unauthorized
|
| 120 |
-
JINA_API_KEY=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
# --- Models configuration ---
|
| 30 |
EMBEDDING_MODEL=BAAI/bge-m3
|
| 31 |
VECTOR_SIZE=1024
|
| 32 |
+
RERANKER_MODEL=jinaai/jina-reranker-v3
|
| 33 |
|
| 34 |
# ==========================================
|
| 35 |
# LLM Provider — set LLM_PROVIDER to one of:
|
|
|
|
| 114 |
# Free models: Llama 4, Qwen 3, DeepSeek, Gemma 3 and more
|
| 115 |
OPENROUTER_API_KEY=your-openrouter-api-key-here
|
| 116 |
|
| 117 |
+
# --- NewsAPI.org (Real-Time News Search) ---
|
| 118 |
+
# Get free key: https://newsapi.org/register (100 requests/day free)
|
| 119 |
+
# Paid tier: $449/month for production (250,000 requests/month)
|
| 120 |
+
NEWSAPI_KEY=74f434d6dafd4e0fb68b6f6c1252f8e0
|
| 121 |
+
NEWSAPI_ENABLED=true
|
| 122 |
+
NEWSAPI_TIMEOUT=2.0
|
| 123 |
+
NEWSAPI_MAX_RESULTS=20
|
| 124 |
+
|
| 125 |
# --- Jina AI Reader (Full Article Extraction) ---
|
| 126 |
# Get free key: https://jina.ai (1M tokens/month free)
|
| 127 |
# Without key: most news sites return 401 Unauthorized
|
| 128 |
+
JINA_API_KEY=jina_21658d5feda2467aad7b3bfc08a1b52a4KAI3aLzYhgeua81sPQSyyaYqoh_
|
| 129 |
+
JINA_RERANKER_ENABLED=true
|
| 130 |
+
JINA_RERANKER_MODEL=jina-reranker-v3
|
| 131 |
+
JINA_RERANKER_TIMEOUT=5.0
|
| 132 |
+
|
| 133 |
+
# --- ACLED Conflict Data (Structured conflict events for Ethiopia) ---
|
| 134 |
+
# Register at: https://acleddata.com/register
|
| 135 |
+
# Use your acleddata.com login credentials (email + password)
|
| 136 |
+
# No separate API key needed — OAuth token is generated automatically
|
| 137 |
+
ACLED_ENABLED=false
|
| 138 |
+
ACLED_EMAIL=your-acled-email@example.com
|
| 139 |
+
ACLED_PASSWORD=your-acled-password
|
| 140 |
+
ACLED_TIMEOUT=8.0
|
| 141 |
+
ACLED_MAX_RESULTS=20
|
src/api/dependencies.py
CHANGED
|
@@ -16,6 +16,7 @@ from src.infrastructure.adapters.clickhouse_adapter import ClickHouseAdapter
|
|
| 16 |
from src.infrastructure.adapters.postgres_adapter import PostgresAdapter
|
| 17 |
from src.infrastructure.adapters.redis_adapter import RedisAdapter
|
| 18 |
from src.infrastructure.adapters.duckduckgo_adapter import DuckDuckGoAdapter
|
|
|
|
| 19 |
|
| 20 |
# Hybrid Search Components
|
| 21 |
from src.core.orchestrator.query_orchestrator import QueryOrchestrator
|
|
@@ -45,8 +46,15 @@ duckduckgo_adapter = DuckDuckGoAdapter(
|
|
| 45 |
timeout=settings.LIVE_SEARCH_TIMEOUT,
|
| 46 |
max_results=settings.LIVE_SEARCH_MAX_RESULTS
|
| 47 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
query_orchestrator = QueryOrchestrator(
|
| 49 |
live_search_adapter=duckduckgo_adapter,
|
|
|
|
| 50 |
enable_hybrid=settings.ENABLE_HYBRID_SEARCH,
|
| 51 |
default_live_weight=settings.LIVE_SEARCH_WEIGHT,
|
| 52 |
default_db_weight=settings.DB_SEARCH_WEIGHT
|
|
|
|
| 16 |
from src.infrastructure.adapters.postgres_adapter import PostgresAdapter
|
| 17 |
from src.infrastructure.adapters.redis_adapter import RedisAdapter
|
| 18 |
from src.infrastructure.adapters.duckduckgo_adapter import DuckDuckGoAdapter
|
| 19 |
+
from src.infrastructure.adapters.newsapi_adapter import NewsAPIAdapter
|
| 20 |
|
| 21 |
# Hybrid Search Components
|
| 22 |
from src.core.orchestrator.query_orchestrator import QueryOrchestrator
|
|
|
|
| 46 |
timeout=settings.LIVE_SEARCH_TIMEOUT,
|
| 47 |
max_results=settings.LIVE_SEARCH_MAX_RESULTS
|
| 48 |
)
|
| 49 |
+
newsapi_adapter = NewsAPIAdapter(
|
| 50 |
+
api_key=settings.NEWSAPI_KEY,
|
| 51 |
+
timeout=settings.NEWSAPI_TIMEOUT,
|
| 52 |
+
max_results=settings.NEWSAPI_MAX_RESULTS
|
| 53 |
+
) if settings.NEWSAPI_ENABLED else None
|
| 54 |
+
|
| 55 |
query_orchestrator = QueryOrchestrator(
|
| 56 |
live_search_adapter=duckduckgo_adapter,
|
| 57 |
+
newsapi_adapter=newsapi_adapter,
|
| 58 |
enable_hybrid=settings.ENABLE_HYBRID_SEARCH,
|
| 59 |
default_live_weight=settings.LIVE_SEARCH_WEIGHT,
|
| 60 |
default_db_weight=settings.DB_SEARCH_WEIGHT
|
src/core/config.py
CHANGED
|
@@ -87,6 +87,17 @@ class Settings(BaseSettings):
|
|
| 87 |
JINA_READER_TIMEOUT: float = float(os.getenv("JINA_READER_TIMEOUT", "8.0"))
|
| 88 |
JINA_READER_MAX_CONCURRENT: int = int(os.getenv("JINA_READER_MAX_CONCURRENT", "10"))
|
| 89 |
JINA_API_KEY: str = os.getenv("JINA_API_KEY", "") # Get free key at https://jina.ai
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
# Cache Settings (TTL in seconds)
|
| 92 |
CACHE_RESPONSE_TTL: int = int(os.getenv("CACHE_RESPONSE_TTL", "300")) # 5 minutes
|
|
|
|
| 87 |
JINA_READER_TIMEOUT: float = float(os.getenv("JINA_READER_TIMEOUT", "8.0"))
|
| 88 |
JINA_READER_MAX_CONCURRENT: int = int(os.getenv("JINA_READER_MAX_CONCURRENT", "10"))
|
| 89 |
JINA_API_KEY: str = os.getenv("JINA_API_KEY", "") # Get free key at https://jina.ai
|
| 90 |
+
|
| 91 |
+
# Jina Reranker API
|
| 92 |
+
JINA_RERANKER_ENABLED: bool = os.getenv("JINA_RERANKER_ENABLED", "true").lower() == "true"
|
| 93 |
+
JINA_RERANKER_MODEL: str = os.getenv("JINA_RERANKER_MODEL", "jina-reranker-v3")
|
| 94 |
+
JINA_RERANKER_TIMEOUT: float = float(os.getenv("JINA_RERANKER_TIMEOUT", "5.0"))
|
| 95 |
+
|
| 96 |
+
# NewsAPI Settings (Real-Time News Search)
|
| 97 |
+
NEWSAPI_KEY: str = os.getenv("NEWSAPI_KEY", "") # Get free key at https://newsapi.org/register
|
| 98 |
+
NEWSAPI_ENABLED: bool = os.getenv("NEWSAPI_ENABLED", "true").lower() == "true"
|
| 99 |
+
NEWSAPI_TIMEOUT: float = float(os.getenv("NEWSAPI_TIMEOUT", "2.0"))
|
| 100 |
+
NEWSAPI_MAX_RESULTS: int = int(os.getenv("NEWSAPI_MAX_RESULTS", "20"))
|
| 101 |
|
| 102 |
# Cache Settings (TTL in seconds)
|
| 103 |
CACHE_RESPONSE_TTL: int = int(os.getenv("CACHE_RESPONSE_TTL", "300")) # 5 minutes
|
src/core/orchestrator/query_orchestrator.py
CHANGED
|
@@ -80,6 +80,7 @@ class QueryOrchestrator:
|
|
| 80 |
def __init__(
|
| 81 |
self,
|
| 82 |
live_search_adapter,
|
|
|
|
| 83 |
enable_hybrid: bool = True,
|
| 84 |
default_live_weight: float = 0.5,
|
| 85 |
default_db_weight: float = 0.5
|
|
@@ -89,11 +90,13 @@ class QueryOrchestrator:
|
|
| 89 |
|
| 90 |
Args:
|
| 91 |
live_search_adapter: DuckDuckGo adapter instance
|
|
|
|
| 92 |
enable_hybrid: Global flag to enable/disable hybrid search
|
| 93 |
default_live_weight: Default weight for live results
|
| 94 |
default_db_weight: Default weight for database results
|
| 95 |
"""
|
| 96 |
self.live_search = live_search_adapter
|
|
|
|
| 97 |
self.enable_hybrid = enable_hybrid
|
| 98 |
self.default_live_weight = default_live_weight
|
| 99 |
self.default_db_weight = default_db_weight
|
|
@@ -446,11 +449,12 @@ class QueryOrchestrator:
|
|
| 446 |
"""
|
| 447 |
Execute live search with Jina Reader enhancement.
|
| 448 |
|
| 449 |
-
|
| 450 |
-
1.
|
| 451 |
-
2.
|
| 452 |
-
3.
|
| 453 |
-
4.
|
|
|
|
| 454 |
|
| 455 |
Args:
|
| 456 |
query: Search query (English)
|
|
@@ -458,60 +462,78 @@ class QueryOrchestrator:
|
|
| 458 |
Returns:
|
| 459 |
List of enhanced live search results with full articles
|
| 460 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 461 |
try:
|
| 462 |
-
|
| 463 |
-
results = await self.live_search.search(query)
|
| 464 |
-
logger.info(f"Live search: {len(results)} results from DuckDuckGo")
|
| 465 |
|
| 466 |
-
|
| 467 |
-
|
|
|
|
|
|
|
| 468 |
|
| 469 |
-
# Step
|
| 470 |
-
|
|
|
|
|
|
|
|
|
|
| 471 |
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 475 |
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
jina = get_jina_reader_adapter(
|
| 481 |
-
timeout=settings.JINA_READER_TIMEOUT,
|
| 482 |
-
max_concurrent=settings.JINA_READER_MAX_CONCURRENT
|
| 483 |
-
)
|
| 484 |
-
|
| 485 |
-
# Step 4: Extract full articles (replaces snippets)
|
| 486 |
-
enhanced_results = await jina.enhance_search_results(
|
| 487 |
-
results,
|
| 488 |
-
fallback_to_snippet=True # Keep snippet if Jina fails
|
| 489 |
-
)
|
| 490 |
-
|
| 491 |
-
# Log enhancement stats
|
| 492 |
-
full_articles = sum(1 for r in enhanced_results if r.get("full_article"))
|
| 493 |
-
snippets = len(enhanced_results) - full_articles
|
| 494 |
-
total_chars = sum(
|
| 495 |
-
r.get("content_length", 0)
|
| 496 |
-
for r in enhanced_results
|
| 497 |
-
if r.get("full_article")
|
| 498 |
-
)
|
| 499 |
-
|
| 500 |
-
logger.info(
|
| 501 |
-
f"Jina enhancement: {full_articles} full articles ({total_chars:,} chars), "
|
| 502 |
-
f"{snippets} snippets (fallback)"
|
| 503 |
-
)
|
| 504 |
-
|
| 505 |
-
return enhanced_results
|
| 506 |
-
|
| 507 |
-
except ImportError:
|
| 508 |
-
logger.warning("Jina Reader not available - using snippets only")
|
| 509 |
-
return results
|
| 510 |
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
|
|
|
|
|
|
| 514 |
|
| 515 |
except Exception as e:
|
| 516 |
-
logger.
|
| 517 |
-
|
|
|
|
| 80 |
def __init__(
|
| 81 |
self,
|
| 82 |
live_search_adapter,
|
| 83 |
+
newsapi_adapter=None,
|
| 84 |
enable_hybrid: bool = True,
|
| 85 |
default_live_weight: float = 0.5,
|
| 86 |
default_db_weight: float = 0.5
|
|
|
|
| 90 |
|
| 91 |
Args:
|
| 92 |
live_search_adapter: DuckDuckGo adapter instance
|
| 93 |
+
newsapi_adapter: NewsAPI adapter instance (optional, for temporal queries)
|
| 94 |
enable_hybrid: Global flag to enable/disable hybrid search
|
| 95 |
default_live_weight: Default weight for live results
|
| 96 |
default_db_weight: Default weight for database results
|
| 97 |
"""
|
| 98 |
self.live_search = live_search_adapter
|
| 99 |
+
self.newsapi = newsapi_adapter
|
| 100 |
self.enable_hybrid = enable_hybrid
|
| 101 |
self.default_live_weight = default_live_weight
|
| 102 |
self.default_db_weight = default_db_weight
|
|
|
|
| 449 |
"""
|
| 450 |
Execute live search with Jina Reader enhancement.
|
| 451 |
|
| 452 |
+
Strategy:
|
| 453 |
+
1. Try NewsAPI first (if available and temporal query)
|
| 454 |
+
2. Fallback to DuckDuckGo
|
| 455 |
+
3. Extract full articles using Jina Reader (parallel)
|
| 456 |
+
4. Replace snippets with full content (14,000+ chars)
|
| 457 |
+
5. Fallback to snippets if extraction fails
|
| 458 |
|
| 459 |
Args:
|
| 460 |
query: Search query (English)
|
|
|
|
| 462 |
Returns:
|
| 463 |
List of enhanced live search results with full articles
|
| 464 |
"""
|
| 465 |
+
results = []
|
| 466 |
+
|
| 467 |
+
# Try NewsAPI first (best for temporal queries)
|
| 468 |
+
if self.newsapi and self.newsapi.is_available():
|
| 469 |
+
try:
|
| 470 |
+
logger.info(f"Live search: trying NewsAPI first for '{query}'")
|
| 471 |
+
newsapi_results = await self.newsapi.search(query)
|
| 472 |
+
|
| 473 |
+
if newsapi_results:
|
| 474 |
+
logger.info(f"NewsAPI: {len(newsapi_results)} results")
|
| 475 |
+
results.extend(newsapi_results)
|
| 476 |
+
else:
|
| 477 |
+
logger.info("NewsAPI: no results, falling back to DuckDuckGo")
|
| 478 |
+
except Exception as e:
|
| 479 |
+
logger.warning(f"NewsAPI failed: {e}, falling back to DuckDuckGo")
|
| 480 |
+
|
| 481 |
+
# Fallback to DuckDuckGo (or primary if NewsAPI not available)
|
| 482 |
+
if not results:
|
| 483 |
+
try:
|
| 484 |
+
logger.info(f"Live search: using DuckDuckGo for '{query}'")
|
| 485 |
+
results = await self.live_search.search(query)
|
| 486 |
+
logger.info(f"DuckDuckGo: {len(results)} results")
|
| 487 |
+
except Exception as e:
|
| 488 |
+
logger.error(f"DuckDuckGo search error: {e}")
|
| 489 |
+
return []
|
| 490 |
+
|
| 491 |
+
if not results:
|
| 492 |
+
logger.warning("No live search results from any source")
|
| 493 |
+
return results
|
| 494 |
+
|
| 495 |
+
# Step 2: Check if Jina Reader is enabled
|
| 496 |
+
from src.core.config import settings
|
| 497 |
+
|
| 498 |
+
if not settings.ENABLE_JINA_READER:
|
| 499 |
+
logger.info("Jina Reader disabled - using snippets only")
|
| 500 |
+
return results
|
| 501 |
+
|
| 502 |
+
# Step 3: Try to enhance with Jina Reader
|
| 503 |
try:
|
| 504 |
+
from src.infrastructure.adapters.jina_reader_adapter import get_jina_reader_adapter
|
|
|
|
|
|
|
| 505 |
|
| 506 |
+
jina = get_jina_reader_adapter(
|
| 507 |
+
timeout=settings.JINA_READER_TIMEOUT,
|
| 508 |
+
max_concurrent=settings.JINA_READER_MAX_CONCURRENT
|
| 509 |
+
)
|
| 510 |
|
| 511 |
+
# Step 4: Extract full articles (replaces snippets)
|
| 512 |
+
enhanced_results = await jina.enhance_search_results(
|
| 513 |
+
results,
|
| 514 |
+
fallback_to_snippet=True # Keep snippet if Jina fails
|
| 515 |
+
)
|
| 516 |
|
| 517 |
+
# Log enhancement stats
|
| 518 |
+
full_articles = sum(1 for r in enhanced_results if r.get("full_article"))
|
| 519 |
+
snippets = len(enhanced_results) - full_articles
|
| 520 |
+
total_chars = sum(
|
| 521 |
+
r.get("content_length", 0)
|
| 522 |
+
for r in enhanced_results
|
| 523 |
+
if r.get("full_article")
|
| 524 |
+
)
|
| 525 |
|
| 526 |
+
logger.info(
|
| 527 |
+
f"Jina enhancement: {full_articles} full articles ({total_chars:,} chars), "
|
| 528 |
+
f"{snippets} snippets (fallback)"
|
| 529 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 530 |
|
| 531 |
+
return enhanced_results
|
| 532 |
+
|
| 533 |
+
except ImportError:
|
| 534 |
+
logger.warning("Jina Reader not available - using snippets only")
|
| 535 |
+
return results
|
| 536 |
|
| 537 |
except Exception as e:
|
| 538 |
+
logger.warning(f"Jina Reader enhancement failed: {e} - using snippets")
|
| 539 |
+
return results
|
src/core/use_cases/rag_chat_use_case.py
CHANGED
|
@@ -388,15 +388,19 @@ JSON:"""
|
|
| 388 |
logger.info(f"[RAG] Hybrid search enabled - checking intent and strategy")
|
| 389 |
|
| 390 |
# Classify intent using v2 (production-grade) or v1 (fallback)
|
| 391 |
-
# Check Redis cache first to avoid
|
| 392 |
intent_result = None
|
| 393 |
intent_cache_key = f"intent_v2:{query[:80].lower().strip()}"
|
| 394 |
|
| 395 |
if self.cache:
|
| 396 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 397 |
if cached_intent:
|
| 398 |
-
logger.info(f"[RAG] Intent cache HIT — skipping
|
| 399 |
-
# Reconstruct a minimal intent result from cache
|
| 400 |
class _CachedIntent:
|
| 401 |
def __init__(self, d):
|
| 402 |
self.intent = d["intent"]
|
|
@@ -405,7 +409,7 @@ JSON:"""
|
|
| 405 |
self.inference_time_ms = 0.0
|
| 406 |
intent_result = _CachedIntent(cached_intent)
|
| 407 |
intent = "NEWS" if intent_result.intent != "OTHER" else "OTHER"
|
| 408 |
-
logger.info(f"[RAG] Intent (cached): {intent_result.intent} (
|
| 409 |
|
| 410 |
if intent_result is None:
|
| 411 |
if self.use_v2_classifier and self.intent_classifier_v2:
|
|
@@ -419,13 +423,17 @@ JSON:"""
|
|
| 419 |
f"time={intent_result.inference_time_ms:.1f}ms"
|
| 420 |
)
|
| 421 |
|
| 422 |
-
# Cache intent result for 1 hour
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 423 |
if self.cache:
|
| 424 |
-
self.cache
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
}, expiration=3600)
|
| 429 |
else:
|
| 430 |
intent = self.intent_classifier.classify(query)
|
| 431 |
intent_result = None
|
|
@@ -689,24 +697,15 @@ JSON:"""
|
|
| 689 |
return "".join([f"{msg.role}: {msg.content}\n" for msg in past_messages])
|
| 690 |
|
| 691 |
def _get_cache_keys(self, query: str) -> Dict[str, str]:
|
| 692 |
-
"""
|
| 693 |
-
Generate cache keys for different caching layers.
|
| 694 |
-
|
| 695 |
-
Returns dict with keys: response, live, translation, intent
|
| 696 |
-
"""
|
| 697 |
if not self.cache:
|
| 698 |
return {}
|
| 699 |
-
|
| 700 |
query_hash = self.cache.generate_exact_hash(query)
|
| 701 |
-
query_prefix_hash = self.cache.generate_exact_hash(query[:50])
|
| 702 |
-
|
| 703 |
-
from src.core.config import settings
|
| 704 |
-
|
| 705 |
return {
|
| 706 |
-
"response":
|
| 707 |
-
"live":
|
| 708 |
"translation": f"translation:{query_hash}",
|
| 709 |
-
"intent":
|
| 710 |
}
|
| 711 |
|
| 712 |
async def execute_chat(self, request: ChatRequest) -> Dict[str, Any]:
|
|
@@ -718,17 +717,21 @@ JSON:"""
|
|
| 718 |
logger.info(f"[RAG] Generated new session_id: {request.session_id}")
|
| 719 |
session_id = request.session_id
|
| 720 |
|
| 721 |
-
# ── Layer 1: Full Response Cache
|
| 722 |
-
|
| 723 |
-
|
| 724 |
-
|
| 725 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 726 |
if cached_result:
|
| 727 |
-
|
| 728 |
self.chat_history_db.save_interaction(
|
| 729 |
-
session_id,
|
| 730 |
-
request.query,
|
| 731 |
-
cached_result["answer"],
|
| 732 |
[s.get("doc_id") for s in cached_result.get("sources", [])]
|
| 733 |
)
|
| 734 |
cached_result["debug"] = cached_result.get("debug", {})
|
|
@@ -736,7 +739,7 @@ JSON:"""
|
|
| 736 |
cached_result["debug"]["cache_layer"] = "response"
|
| 737 |
return cached_result
|
| 738 |
|
| 739 |
-
|
| 740 |
history_text = self._get_history_text(session_id)
|
| 741 |
|
| 742 |
context_text, final_sources = await self._build_context(
|
|
@@ -868,15 +871,16 @@ Answer:"""
|
|
| 868 |
}
|
| 869 |
}
|
| 870 |
|
| 871 |
-
# ── Cache the full response
|
| 872 |
-
if self.cache
|
| 873 |
-
|
| 874 |
-
self.cache
|
| 875 |
-
|
| 876 |
-
|
| 877 |
-
|
| 878 |
-
|
| 879 |
-
|
|
|
|
| 880 |
|
| 881 |
return result
|
| 882 |
|
|
|
|
| 388 |
logger.info(f"[RAG] Hybrid search enabled - checking intent and strategy")
|
| 389 |
|
| 390 |
# Classify intent using v2 (production-grade) or v1 (fallback)
|
| 391 |
+
# Check Redis cache first to avoid repeated LLM calls on same query
|
| 392 |
intent_result = None
|
| 393 |
intent_cache_key = f"intent_v2:{query[:80].lower().strip()}"
|
| 394 |
|
| 395 |
if self.cache:
|
| 396 |
+
# Use new layered cache method if available
|
| 397 |
+
if hasattr(self.cache, 'get_intent'):
|
| 398 |
+
cached_intent = self.cache.get_intent(query)
|
| 399 |
+
else:
|
| 400 |
+
cached_intent = self.cache.get(intent_cache_key)
|
| 401 |
+
|
| 402 |
if cached_intent:
|
| 403 |
+
logger.info(f"[RAG] Intent cache HIT — skipping LLM inference")
|
|
|
|
| 404 |
class _CachedIntent:
|
| 405 |
def __init__(self, d):
|
| 406 |
self.intent = d["intent"]
|
|
|
|
| 409 |
self.inference_time_ms = 0.0
|
| 410 |
intent_result = _CachedIntent(cached_intent)
|
| 411 |
intent = "NEWS" if intent_result.intent != "OTHER" else "OTHER"
|
| 412 |
+
logger.info(f"[RAG] Intent (cached): {intent_result.intent} (conf={intent_result.confidence:.2f})")
|
| 413 |
|
| 414 |
if intent_result is None:
|
| 415 |
if self.use_v2_classifier and self.intent_classifier_v2:
|
|
|
|
| 423 |
f"time={intent_result.inference_time_ms:.1f}ms"
|
| 424 |
)
|
| 425 |
|
| 426 |
+
# Cache intent result for 1 hour
|
| 427 |
+
intent_data = {
|
| 428 |
+
"intent": intent_result.intent,
|
| 429 |
+
"confidence": intent_result.confidence,
|
| 430 |
+
"method": intent_result.method,
|
| 431 |
+
}
|
| 432 |
if self.cache:
|
| 433 |
+
if hasattr(self.cache, 'set_intent'):
|
| 434 |
+
self.cache.set_intent(query, intent_data)
|
| 435 |
+
else:
|
| 436 |
+
self.cache.set(intent_cache_key, intent_data, expiration=3600)
|
|
|
|
| 437 |
else:
|
| 438 |
intent = self.intent_classifier.classify(query)
|
| 439 |
intent_result = None
|
|
|
|
| 697 |
return "".join([f"{msg.role}: {msg.content}\n" for msg in past_messages])
|
| 698 |
|
| 699 |
def _get_cache_keys(self, query: str) -> Dict[str, str]:
|
| 700 |
+
"""Generate cache keys — kept for backward compat, new code uses RedisAdapter methods directly."""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 701 |
if not self.cache:
|
| 702 |
return {}
|
|
|
|
| 703 |
query_hash = self.cache.generate_exact_hash(query)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 704 |
return {
|
| 705 |
+
"response": f"rag_response:{query_hash}",
|
| 706 |
+
"live": f"live_search:{query_hash}",
|
| 707 |
"translation": f"translation:{query_hash}",
|
| 708 |
+
"intent": f"intent_v2:{query_hash}",
|
| 709 |
}
|
| 710 |
|
| 711 |
async def execute_chat(self, request: ChatRequest) -> Dict[str, Any]:
|
|
|
|
| 717 |
logger.info(f"[RAG] Generated new session_id: {request.session_id}")
|
| 718 |
session_id = request.session_id
|
| 719 |
|
| 720 |
+
# ── Layer 1: Full Response Cache ──────────────────────────────────────
|
| 721 |
+
if self.cache:
|
| 722 |
+
# Use new layered cache method if available
|
| 723 |
+
if hasattr(self.cache, 'get_response'):
|
| 724 |
+
cached_result = self.cache.get_response(request.query)
|
| 725 |
+
else:
|
| 726 |
+
cache_keys = self._get_cache_keys(request.query)
|
| 727 |
+
cached_result = self.cache.get(cache_keys.get("response", "")) if cache_keys else None
|
| 728 |
+
|
| 729 |
if cached_result:
|
| 730 |
+
logger.info("[RAG] Cache HIT — returning cached response")
|
| 731 |
self.chat_history_db.save_interaction(
|
| 732 |
+
session_id,
|
| 733 |
+
request.query,
|
| 734 |
+
cached_result["answer"],
|
| 735 |
[s.get("doc_id") for s in cached_result.get("sources", [])]
|
| 736 |
)
|
| 737 |
cached_result["debug"] = cached_result.get("debug", {})
|
|
|
|
| 739 |
cached_result["debug"]["cache_layer"] = "response"
|
| 740 |
return cached_result
|
| 741 |
|
| 742 |
+
logger.info("[RAG] Cache MISS — running full RAG pipeline")
|
| 743 |
history_text = self._get_history_text(session_id)
|
| 744 |
|
| 745 |
context_text, final_sources = await self._build_context(
|
|
|
|
| 871 |
}
|
| 872 |
}
|
| 873 |
|
| 874 |
+
# ── Cache the full response with intent-aware TTL ─────────────────────
|
| 875 |
+
if self.cache:
|
| 876 |
+
detected_intent = result.get("debug", {}).get("intent", "NEWS_GENERAL")
|
| 877 |
+
if hasattr(self.cache, 'set_response'):
|
| 878 |
+
self.cache.set_response(request.query, result, intent=detected_intent)
|
| 879 |
+
else:
|
| 880 |
+
cache_keys = self._get_cache_keys(request.query)
|
| 881 |
+
if cache_keys.get("response"):
|
| 882 |
+
from src.core.config import settings
|
| 883 |
+
self.cache.set(cache_keys["response"], result, expiration=settings.CACHE_RESPONSE_TTL)
|
| 884 |
|
| 885 |
return result
|
| 886 |
|
src/infrastructure/adapters/bge_reranker_adapter.py
CHANGED
|
@@ -1,67 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import logging
|
| 2 |
import threading
|
| 3 |
-
from typing import List, Dict, Any
|
|
|
|
| 4 |
from src.core.config import settings
|
| 5 |
from src.core.ports.reranker_port import RerankerPort
|
| 6 |
|
| 7 |
logger = logging.getLogger(__name__)
|
| 8 |
|
|
|
|
| 9 |
try:
|
| 10 |
-
import transformers.utils.import_utils
|
| 11 |
-
if not hasattr(
|
| 12 |
-
|
| 13 |
except Exception:
|
| 14 |
pass
|
| 15 |
|
| 16 |
-
#
|
| 17 |
-
# Fallback to sentence-transformers CrossEncoder if FlagEmbedding is unavailable
|
| 18 |
try:
|
| 19 |
from FlagEmbedding import FlagReranker
|
| 20 |
HAS_FLAG_RERANKER = True
|
| 21 |
except ImportError:
|
| 22 |
HAS_FLAG_RERANKER = False
|
| 23 |
-
logger.warning("FlagEmbedding not available for FlagReranker — trying CrossEncoder fallback.")
|
| 24 |
|
|
|
|
| 25 |
try:
|
| 26 |
from sentence_transformers import CrossEncoder
|
| 27 |
HAS_CROSS_ENCODER = True
|
| 28 |
except ImportError:
|
| 29 |
HAS_CROSS_ENCODER = False
|
| 30 |
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
class BgeRerankerAdapter(RerankerPort):
|
| 36 |
"""
|
| 37 |
-
|
| 38 |
|
| 39 |
-
|
| 40 |
-
-
|
| 41 |
-
- Natively multilingual: Arabic, Amharic, Somali, Swahili, French, English
|
| 42 |
-
- Significantly better than ms-marco-TinyBERT for non-English content
|
| 43 |
-
- Uses FlagReranker (FlagEmbedding) as primary, CrossEncoder as fallback
|
| 44 |
|
| 45 |
-
|
| 46 |
"""
|
| 47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
def __init__(self):
|
| 49 |
-
self.model = None
|
| 50 |
self.model_name = settings.RERANKER_MODEL
|
|
|
|
| 51 |
self._lock = threading.Lock()
|
| 52 |
self._load_failed = False
|
| 53 |
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
return
|
| 57 |
with self._lock:
|
| 58 |
-
if self.
|
| 59 |
return
|
| 60 |
-
logger.info(f"Loading
|
| 61 |
try:
|
| 62 |
if HAS_FLAG_RERANKER and "bge-reranker" in self.model_name.lower():
|
| 63 |
-
# Patch
|
| 64 |
-
# Different transformers versions on HF Spaces may lack different methods
|
| 65 |
try:
|
| 66 |
from transformers import XLMRobertaTokenizer, PreTrainedTokenizer
|
| 67 |
for method_name in [
|
|
@@ -69,7 +229,6 @@ class BgeRerankerAdapter(RerankerPort):
|
|
| 69 |
"build_inputs_with_special_tokens",
|
| 70 |
"create_token_type_ids_from_sequences",
|
| 71 |
"get_special_tokens_mask",
|
| 72 |
-
"special_tokens_pattern",
|
| 73 |
"convert_tokens_to_string",
|
| 74 |
]:
|
| 75 |
if not hasattr(XLMRobertaTokenizer, method_name):
|
|
@@ -79,73 +238,133 @@ class BgeRerankerAdapter(RerankerPort):
|
|
| 79 |
except Exception as patch_err:
|
| 80 |
logger.debug(f"Tokenizer patch skipped: {patch_err}")
|
| 81 |
|
| 82 |
-
|
| 83 |
-
self.
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
logger.info(f"✅ Loaded {self.model_name} via FlagReranker (multilingual, fp16)")
|
| 91 |
-
except Exception as flag_err:
|
| 92 |
-
logger.warning(f"FlagReranker failed ({flag_err}) — falling back to CrossEncoder")
|
| 93 |
-
if HAS_CROSS_ENCODER:
|
| 94 |
-
self.model = CrossEncoder(self.model_name)
|
| 95 |
-
self._use_flag = False
|
| 96 |
-
logger.info(f"✅ Loaded {self.model_name} via CrossEncoder (fallback)")
|
| 97 |
-
else:
|
| 98 |
-
raise
|
| 99 |
|
| 100 |
elif HAS_CROSS_ENCODER:
|
| 101 |
-
self.
|
| 102 |
self._use_flag = False
|
| 103 |
-
logger.info(f"✅
|
|
|
|
| 104 |
else:
|
| 105 |
-
logger.error("No
|
| 106 |
self._load_failed = True
|
|
|
|
| 107 |
except Exception as e:
|
| 108 |
-
logger.error(f"Failed to load reranker
|
| 109 |
self._load_failed = True
|
| 110 |
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
if not docs:
|
| 113 |
return []
|
| 114 |
|
| 115 |
-
|
| 116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
return sorted(docs, key=lambda x: x.get("score", 0), reverse=True)[:top_n]
|
| 122 |
|
| 123 |
-
# Build (query, content) pairs.
|
| 124 |
-
# PERFORMANCE: Truncate content to 512 chars (~128 tokens) before scoring.
|
| 125 |
-
# The reranker only needs the opening paragraph to judge topical relevance.
|
| 126 |
-
# Full articles waste ~60% of inference time on boilerplate text.
|
| 127 |
-
MAX_CONTENT_CHARS = 512
|
| 128 |
pairs = []
|
| 129 |
valid_docs = []
|
| 130 |
for doc in docs:
|
| 131 |
content = doc.get("content", "").strip()
|
| 132 |
if content:
|
| 133 |
-
|
| 134 |
-
pairs.append([query, truncated])
|
| 135 |
valid_docs.append(doc)
|
| 136 |
|
| 137 |
if not pairs:
|
| 138 |
return []
|
| 139 |
|
| 140 |
try:
|
| 141 |
-
if
|
| 142 |
-
|
| 143 |
-
# content is truncated (much smaller tensors per pair)
|
| 144 |
-
scores = self.model.compute_score(pairs, batch_size=64)
|
| 145 |
if isinstance(scores, float):
|
| 146 |
scores = [scores]
|
| 147 |
else:
|
| 148 |
-
scores = self.
|
| 149 |
if isinstance(scores, float):
|
| 150 |
scores = [scores]
|
| 151 |
|
|
@@ -153,8 +372,17 @@ class BgeRerankerAdapter(RerankerPort):
|
|
| 153 |
doc["rerank_score"] = float(scores[i])
|
| 154 |
|
| 155 |
valid_docs.sort(key=lambda x: x["rerank_score"], reverse=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
return valid_docs[:top_n]
|
| 157 |
|
| 158 |
except Exception as e:
|
| 159 |
-
logger.error(f"
|
| 160 |
return sorted(docs, key=lambda x: x.get("score", 0), reverse=True)[:top_n]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Reranker Adapter — supports BGE-Reranker-v2-m3 AND Jina-Reranker-v3
|
| 3 |
+
|
| 4 |
+
Auto-detects which model to load based on RERANKER_MODEL setting:
|
| 5 |
+
- "BAAI/bge-reranker-v2-m3" → FlagReranker (pointwise cross-encoder)
|
| 6 |
+
- "jinaai/jina-reranker-v3" → Jina v3 listwise reranker
|
| 7 |
+
|
| 8 |
+
Jina v3 advantages over BGE for this project:
|
| 9 |
+
- Listwise: sees ALL docs at once → better cross-doc comparison
|
| 10 |
+
- 131K context window → reads full Jina-extracted articles (not just 512 chars)
|
| 11 |
+
- +9.6% better on English news (BEIR 61.94 vs 56.51)
|
| 12 |
+
- Better Arabic ranking (78.69 nDCG)
|
| 13 |
+
- Same size (0.6B), same memory, same cost (free, self-hosted)
|
| 14 |
+
|
| 15 |
+
Thread-safe lazy loading — model loads once on first rerank call.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
import logging
|
| 19 |
import threading
|
| 20 |
+
from typing import List, Dict, Any, Optional
|
| 21 |
+
|
| 22 |
from src.core.config import settings
|
| 23 |
from src.core.ports.reranker_port import RerankerPort
|
| 24 |
|
| 25 |
logger = logging.getLogger(__name__)
|
| 26 |
|
| 27 |
+
# ── Patch transformers compatibility issue ────────────────────────────────────
|
| 28 |
try:
|
| 29 |
+
import transformers.utils.import_utils as _tui
|
| 30 |
+
if not hasattr(_tui, "is_torch_fx_available"):
|
| 31 |
+
_tui.is_torch_fx_available = lambda: False
|
| 32 |
except Exception:
|
| 33 |
pass
|
| 34 |
|
| 35 |
+
# ── Try FlagEmbedding (for BGE) ───────────────────────────────────────────────
|
|
|
|
| 36 |
try:
|
| 37 |
from FlagEmbedding import FlagReranker
|
| 38 |
HAS_FLAG_RERANKER = True
|
| 39 |
except ImportError:
|
| 40 |
HAS_FLAG_RERANKER = False
|
|
|
|
| 41 |
|
| 42 |
+
# ── Try sentence-transformers CrossEncoder (BGE fallback) ────────────────────
|
| 43 |
try:
|
| 44 |
from sentence_transformers import CrossEncoder
|
| 45 |
HAS_CROSS_ENCODER = True
|
| 46 |
except ImportError:
|
| 47 |
HAS_CROSS_ENCODER = False
|
| 48 |
|
| 49 |
+
# ── Try transformers (for Jina v3) ────────────────────────────────────────────
|
| 50 |
+
try:
|
| 51 |
+
import torch
|
| 52 |
+
from transformers import AutoModel
|
| 53 |
+
HAS_TRANSFORMERS = True
|
| 54 |
+
except ImportError:
|
| 55 |
+
HAS_TRANSFORMERS = False
|
| 56 |
+
logger.warning("transformers/torch not available — Jina v3 reranker disabled.")
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 60 |
+
# JINA V3 RERANKER
|
| 61 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 62 |
+
|
| 63 |
+
class JinaV3Reranker:
|
| 64 |
+
"""
|
| 65 |
+
Jina-Reranker-v3 self-hosted reranker.
|
| 66 |
+
|
| 67 |
+
Key differences from BGE pointwise:
|
| 68 |
+
- Listwise: processes all docs in one forward pass
|
| 69 |
+
- 131K context window: reads full articles, not just first 512 chars
|
| 70 |
+
- Built on Qwen3-0.6B backbone with causal self-attention
|
| 71 |
+
- State-of-the-art BEIR: 61.94 nDCG@10 (vs BGE's 56.51)
|
| 72 |
+
|
| 73 |
+
Scoring: uses sigmoid(logits) for normalized 0-1 scores.
|
| 74 |
+
"""
|
| 75 |
+
|
| 76 |
+
def __init__(self, model_name: str):
|
| 77 |
+
self.model_name = model_name
|
| 78 |
+
self._model = None
|
| 79 |
+
self._lock = threading.Lock()
|
| 80 |
+
self._load_failed = False
|
| 81 |
+
self._device = "cpu"
|
| 82 |
+
|
| 83 |
+
def _load(self):
|
| 84 |
+
if self._model is not None or self._load_failed:
|
| 85 |
+
return
|
| 86 |
+
with self._lock:
|
| 87 |
+
if self._model is not None or self._load_failed:
|
| 88 |
+
return
|
| 89 |
+
if not HAS_TRANSFORMERS:
|
| 90 |
+
logger.error("transformers not installed — cannot load Jina v3")
|
| 91 |
+
self._load_failed = True
|
| 92 |
+
return
|
| 93 |
+
try:
|
| 94 |
+
logger.info(f"Loading Jina v3 reranker: {self.model_name}")
|
| 95 |
+
self._device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 96 |
+
|
| 97 |
+
# Jina v3 uses AutoModel (NOT AutoModelForSequenceClassification)
|
| 98 |
+
# It has a built-in .rerank() method that returns relevance_score directly
|
| 99 |
+
from transformers import AutoModel
|
| 100 |
+
self._model = AutoModel.from_pretrained(
|
| 101 |
+
self.model_name,
|
| 102 |
+
trust_remote_code=True,
|
| 103 |
+
dtype="auto",
|
| 104 |
+
)
|
| 105 |
+
self._model.eval()
|
| 106 |
+
logger.info(
|
| 107 |
+
f"✅ Jina v3 reranker loaded on {self._device} "
|
| 108 |
+
f"(model={self.model_name})"
|
| 109 |
+
)
|
| 110 |
+
except Exception as e:
|
| 111 |
+
logger.error(f"Failed to load Jina v3 reranker: {e}", exc_info=True)
|
| 112 |
+
self._load_failed = True
|
| 113 |
+
|
| 114 |
+
def compute_scores(
|
| 115 |
+
self,
|
| 116 |
+
query: str,
|
| 117 |
+
docs: List[str],
|
| 118 |
+
max_length: int = 1024,
|
| 119 |
+
) -> List[float]:
|
| 120 |
+
"""
|
| 121 |
+
Score all (query, doc) pairs using Jina v3's built-in .rerank() method.
|
| 122 |
+
|
| 123 |
+
Returns scores in original doc order (not sorted).
|
| 124 |
+
"""
|
| 125 |
+
if not docs:
|
| 126 |
+
return []
|
| 127 |
+
|
| 128 |
+
self._load()
|
| 129 |
+
if self._model is None:
|
| 130 |
+
return [0.5] * len(docs)
|
| 131 |
+
|
| 132 |
+
try:
|
| 133 |
+
# Jina v3's .rerank() returns list of dicts:
|
| 134 |
+
# [{"document": str, "relevance_score": float, "index": int}, ...]
|
| 135 |
+
# Results are sorted by relevance_score descending — we need to
|
| 136 |
+
# restore original order using the "index" field.
|
| 137 |
+
results = self._model.rerank(query, docs)
|
| 138 |
+
|
| 139 |
+
# Restore original order
|
| 140 |
+
scores = [0.0] * len(docs)
|
| 141 |
+
for r in results:
|
| 142 |
+
original_idx = r["index"]
|
| 143 |
+
scores[original_idx] = float(r["relevance_score"])
|
| 144 |
+
|
| 145 |
+
return scores
|
| 146 |
+
|
| 147 |
+
except Exception as e:
|
| 148 |
+
logger.error(f"Jina v3 rerank() failed: {e}")
|
| 149 |
+
return [0.0] * len(docs)
|
| 150 |
+
|
| 151 |
+
@property
|
| 152 |
+
def is_loaded(self) -> bool:
|
| 153 |
+
return self._model is not None
|
| 154 |
+
|
| 155 |
|
| 156 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 157 |
+
# UNIFIED RERANKER ADAPTER
|
| 158 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 159 |
|
| 160 |
class BgeRerankerAdapter(RerankerPort):
|
| 161 |
"""
|
| 162 |
+
Unified reranker adapter — auto-selects BGE or Jina v3 based on config.
|
| 163 |
|
| 164 |
+
RERANKER_MODEL=jinaai/jina-reranker-v3 → Jina v3 (recommended)
|
| 165 |
+
RERANKER_MODEL=BAAI/bge-reranker-v2-m3 → BGE (legacy)
|
|
|
|
|
|
|
|
|
|
| 166 |
|
| 167 |
+
Both are self-hosted, free, ~0.6B parameters, ~1.2GB on disk.
|
| 168 |
"""
|
| 169 |
|
| 170 |
+
# Max content chars to send to reranker
|
| 171 |
+
# Jina v3: 1024 tokens ≈ 4096 chars — reads much more than BGE's 512 chars
|
| 172 |
+
MAX_CONTENT_CHARS_JINA = 4096
|
| 173 |
+
MAX_CONTENT_CHARS_BGE = 512
|
| 174 |
+
|
| 175 |
def __init__(self):
|
|
|
|
| 176 |
self.model_name = settings.RERANKER_MODEL
|
| 177 |
+
self._is_jina_v3 = "jina-reranker-v3" in self.model_name.lower()
|
| 178 |
self._lock = threading.Lock()
|
| 179 |
self._load_failed = False
|
| 180 |
|
| 181 |
+
# Check if Jina API reranker is enabled (takes priority over self-hosted)
|
| 182 |
+
self._jina_api = None
|
| 183 |
+
if getattr(settings, 'JINA_RERANKER_ENABLED', False) and getattr(settings, 'JINA_API_KEY', ''):
|
| 184 |
+
try:
|
| 185 |
+
from src.infrastructure.adapters.jina_reranker_adapter import JinaRerankerAPIAdapter
|
| 186 |
+
jina_key = settings.JINA_API_KEY
|
| 187 |
+
if jina_key and jina_key not in ("", "your-jina-api-key-here"):
|
| 188 |
+
self._jina_api = JinaRerankerAPIAdapter(
|
| 189 |
+
api_key=jina_key,
|
| 190 |
+
model=getattr(settings, 'JINA_RERANKER_MODEL', 'jina-reranker-v3'),
|
| 191 |
+
timeout=getattr(settings, 'JINA_RERANKER_TIMEOUT', 5.0),
|
| 192 |
+
)
|
| 193 |
+
logger.info("Reranker configured: Jina API (cloud, fast)")
|
| 194 |
+
except Exception as e:
|
| 195 |
+
logger.warning(f"Jina API reranker init failed: {e}")
|
| 196 |
+
|
| 197 |
+
# Jina v3 self-hosted path
|
| 198 |
+
if self._is_jina_v3 and not self._jina_api:
|
| 199 |
+
self._jina = JinaV3Reranker(self.model_name)
|
| 200 |
+
self._bge_model = None
|
| 201 |
+
self._use_flag = False
|
| 202 |
+
logger.info(f"Reranker configured: Jina v3 self-hosted ({self.model_name})")
|
| 203 |
+
elif not self._jina_api:
|
| 204 |
+
# BGE path
|
| 205 |
+
self._jina = None
|
| 206 |
+
self._bge_model = None
|
| 207 |
+
self._use_flag = False
|
| 208 |
+
logger.info(f"Reranker configured: BGE ({self.model_name})")
|
| 209 |
+
else:
|
| 210 |
+
self._jina = None
|
| 211 |
+
self._bge_model = None
|
| 212 |
+
self._use_flag = False
|
| 213 |
+
|
| 214 |
+
def _load_bge(self):
|
| 215 |
+
"""Lazy-load BGE reranker (thread-safe)."""
|
| 216 |
+
if self._bge_model is not None or self._load_failed:
|
| 217 |
return
|
| 218 |
with self._lock:
|
| 219 |
+
if self._bge_model is not None or self._load_failed:
|
| 220 |
return
|
| 221 |
+
logger.info(f"Loading BGE reranker: {self.model_name}")
|
| 222 |
try:
|
| 223 |
if HAS_FLAG_RERANKER and "bge-reranker" in self.model_name.lower():
|
| 224 |
+
# Patch XLMRobertaTokenizer for older transformers versions
|
|
|
|
| 225 |
try:
|
| 226 |
from transformers import XLMRobertaTokenizer, PreTrainedTokenizer
|
| 227 |
for method_name in [
|
|
|
|
| 229 |
"build_inputs_with_special_tokens",
|
| 230 |
"create_token_type_ids_from_sequences",
|
| 231 |
"get_special_tokens_mask",
|
|
|
|
| 232 |
"convert_tokens_to_string",
|
| 233 |
]:
|
| 234 |
if not hasattr(XLMRobertaTokenizer, method_name):
|
|
|
|
| 238 |
except Exception as patch_err:
|
| 239 |
logger.debug(f"Tokenizer patch skipped: {patch_err}")
|
| 240 |
|
| 241 |
+
self._bge_model = FlagReranker(
|
| 242 |
+
self.model_name,
|
| 243 |
+
use_fp16=True,
|
| 244 |
+
normalize=True,
|
| 245 |
+
trust_remote_code=True,
|
| 246 |
+
)
|
| 247 |
+
self._use_flag = True
|
| 248 |
+
logger.info(f"✅ BGE loaded via FlagReranker (fp16, multilingual)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
|
| 250 |
elif HAS_CROSS_ENCODER:
|
| 251 |
+
self._bge_model = CrossEncoder(self.model_name)
|
| 252 |
self._use_flag = False
|
| 253 |
+
logger.info(f"✅ BGE loaded via CrossEncoder (fallback)")
|
| 254 |
+
|
| 255 |
else:
|
| 256 |
+
logger.error("No BGE backend available (FlagEmbedding or sentence-transformers required)")
|
| 257 |
self._load_failed = True
|
| 258 |
+
|
| 259 |
except Exception as e:
|
| 260 |
+
logger.error(f"Failed to load BGE reranker '{self.model_name}': {e}", exc_info=True)
|
| 261 |
self._load_failed = True
|
| 262 |
|
| 263 |
+
# ── Public interface ──────────────────────────────────────────────────────
|
| 264 |
+
|
| 265 |
+
def rerank(
|
| 266 |
+
self,
|
| 267 |
+
query: str,
|
| 268 |
+
docs: List[Dict[str, Any]],
|
| 269 |
+
top_n: int = 5,
|
| 270 |
+
) -> List[Dict[str, Any]]:
|
| 271 |
+
"""
|
| 272 |
+
Rerank documents by relevance to query.
|
| 273 |
+
|
| 274 |
+
Priority: Jina API (cloud) > Jina v3 self-hosted > BGE
|
| 275 |
+
Jina v3 path: uses full article content (up to 4096 chars)
|
| 276 |
+
BGE path: uses first 512 chars only
|
| 277 |
+
|
| 278 |
+
Returns top_n docs sorted by rerank_score descending.
|
| 279 |
+
"""
|
| 280 |
if not docs:
|
| 281 |
return []
|
| 282 |
|
| 283 |
+
# Priority: Jina API > Jina v3 self-hosted > BGE
|
| 284 |
+
if self._jina_api and self._jina_api.is_available():
|
| 285 |
+
return self._jina_api.rerank(query, docs, top_n)
|
| 286 |
+
elif self._is_jina_v3 and self._jina:
|
| 287 |
+
return self._rerank_jina(query, docs, top_n)
|
| 288 |
+
else:
|
| 289 |
+
return self._rerank_bge(query, docs, top_n)
|
| 290 |
+
|
| 291 |
+
def _rerank_jina(
|
| 292 |
+
self,
|
| 293 |
+
query: str,
|
| 294 |
+
docs: List[Dict[str, Any]],
|
| 295 |
+
top_n: int,
|
| 296 |
+
) -> List[Dict[str, Any]]:
|
| 297 |
+
"""Rerank using Jina v3 — reads full article content."""
|
| 298 |
+
# Ensure model is loaded
|
| 299 |
+
self._jina._load()
|
| 300 |
+
|
| 301 |
+
if self._jina._load_failed or not self._jina.is_loaded:
|
| 302 |
+
logger.warning("Jina v3 unavailable — falling back to vector score ordering")
|
| 303 |
+
return sorted(docs, key=lambda x: x.get("score", 0), reverse=True)[:top_n]
|
| 304 |
+
|
| 305 |
+
# Build content list — use full content up to 4096 chars
|
| 306 |
+
# This is the key advantage: Jina reads 8x more content than BGE
|
| 307 |
+
valid_docs = []
|
| 308 |
+
doc_texts = []
|
| 309 |
+
for doc in docs:
|
| 310 |
+
content = doc.get("content", "").strip()
|
| 311 |
+
if content:
|
| 312 |
+
doc_texts.append(content[:self.MAX_CONTENT_CHARS_JINA])
|
| 313 |
+
valid_docs.append(doc)
|
| 314 |
+
|
| 315 |
+
if not doc_texts:
|
| 316 |
+
return []
|
| 317 |
+
|
| 318 |
+
try:
|
| 319 |
+
scores = self._jina.compute_scores(query, doc_texts)
|
| 320 |
|
| 321 |
+
for i, doc in enumerate(valid_docs):
|
| 322 |
+
doc["rerank_score"] = scores[i]
|
| 323 |
+
|
| 324 |
+
valid_docs.sort(key=lambda x: x["rerank_score"], reverse=True)
|
| 325 |
+
|
| 326 |
+
logger.info(
|
| 327 |
+
f"[Reranker] Jina v3: {len(valid_docs)} docs → top {top_n} "
|
| 328 |
+
f"(max_score={valid_docs[0]['rerank_score']:.3f})"
|
| 329 |
+
)
|
| 330 |
+
return valid_docs[:top_n]
|
| 331 |
+
|
| 332 |
+
except Exception as e:
|
| 333 |
+
logger.error(f"Jina v3 reranking failed: {e} — falling back to vector score")
|
| 334 |
+
return sorted(docs, key=lambda x: x.get("score", 0), reverse=True)[:top_n]
|
| 335 |
+
|
| 336 |
+
def _rerank_bge(
|
| 337 |
+
self,
|
| 338 |
+
query: str,
|
| 339 |
+
docs: List[Dict[str, Any]],
|
| 340 |
+
top_n: int,
|
| 341 |
+
) -> List[Dict[str, Any]]:
|
| 342 |
+
"""Rerank using BGE — reads first 512 chars only."""
|
| 343 |
+
if self._bge_model is None:
|
| 344 |
+
self._load_bge()
|
| 345 |
+
|
| 346 |
+
if self._bge_model is None:
|
| 347 |
+
logger.warning("BGE unavailable — falling back to vector score ordering")
|
| 348 |
return sorted(docs, key=lambda x: x.get("score", 0), reverse=True)[:top_n]
|
| 349 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
pairs = []
|
| 351 |
valid_docs = []
|
| 352 |
for doc in docs:
|
| 353 |
content = doc.get("content", "").strip()
|
| 354 |
if content:
|
| 355 |
+
pairs.append([query, content[:self.MAX_CONTENT_CHARS_BGE]])
|
|
|
|
| 356 |
valid_docs.append(doc)
|
| 357 |
|
| 358 |
if not pairs:
|
| 359 |
return []
|
| 360 |
|
| 361 |
try:
|
| 362 |
+
if self._use_flag:
|
| 363 |
+
scores = self._bge_model.compute_score(pairs, batch_size=64)
|
|
|
|
|
|
|
| 364 |
if isinstance(scores, float):
|
| 365 |
scores = [scores]
|
| 366 |
else:
|
| 367 |
+
scores = self._bge_model.predict(pairs)
|
| 368 |
if isinstance(scores, float):
|
| 369 |
scores = [scores]
|
| 370 |
|
|
|
|
| 372 |
doc["rerank_score"] = float(scores[i])
|
| 373 |
|
| 374 |
valid_docs.sort(key=lambda x: x["rerank_score"], reverse=True)
|
| 375 |
+
|
| 376 |
+
logger.info(
|
| 377 |
+
f"[Reranker] BGE: {len(valid_docs)} docs → top {top_n} "
|
| 378 |
+
f"(max_score={valid_docs[0]['rerank_score']:.3f})"
|
| 379 |
+
)
|
| 380 |
return valid_docs[:top_n]
|
| 381 |
|
| 382 |
except Exception as e:
|
| 383 |
+
logger.error(f"BGE reranking failed: {e} — falling back to vector score")
|
| 384 |
return sorted(docs, key=lambda x: x.get("score", 0), reverse=True)[:top_n]
|
| 385 |
+
|
| 386 |
+
@property
|
| 387 |
+
def model_type(self) -> str:
|
| 388 |
+
return "jina_v3" if self._is_jina_v3 else "bge"
|
src/infrastructure/adapters/intent_classifier_v2.py
CHANGED
|
@@ -1,53 +1,213 @@
|
|
| 1 |
"""
|
| 2 |
-
Intent Classifier
|
| 3 |
|
| 4 |
Architecture:
|
| 5 |
-
Layer
|
|
|
|
|
|
|
| 6 |
Layer 2: Groq llama-3.1-8b-instant — 14,400 free RPD, ~50ms (PRIMARY)
|
| 7 |
Layer 3: Gemini Flash fallback — 1,500 free RPD, ~200ms (FALLBACK 1)
|
| 8 |
Layer 4: OpenRouter free router — free models pool, ~300ms (FALLBACK 2)
|
| 9 |
Layer 5: HuggingFace Inference API — ~300 RPH, ~2s (FALLBACK 3)
|
| 10 |
Layer 6: Safe default — NEWS_GENERAL, 0ms (ALWAYS WORKS)
|
| 11 |
|
| 12 |
-
|
| 13 |
-
-
|
| 14 |
-
-
|
| 15 |
-
-
|
| 16 |
-
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
-
|
| 21 |
-
-
|
| 22 |
-
-
|
| 23 |
-
- HuggingFace: ~300 RPH free — last resort (slower but always available)
|
| 24 |
-
- Default: NEWS_GENERAL — never fails, safe for user experience
|
| 25 |
"""
|
| 26 |
|
| 27 |
import logging
|
|
|
|
| 28 |
import time
|
| 29 |
-
import threading
|
| 30 |
import httpx
|
| 31 |
from dataclasses import dataclass
|
| 32 |
-
from typing import Any, Dict, Optional
|
| 33 |
|
| 34 |
logger = logging.getLogger(__name__)
|
| 35 |
|
| 36 |
|
| 37 |
# ═══════════════════════════════════════════════════════════════════════════════
|
| 38 |
-
# LAYER
|
| 39 |
# ═══════════════════════════════════════════════════════════════════════════════
|
| 40 |
|
| 41 |
_INSTANT_OTHER = {
|
| 42 |
"hi", "hello", "hey", "thanks", "thank you", "bye", "goodbye",
|
| 43 |
"ok", "okay", "yes", "no", "sure", "cool", "nice",
|
| 44 |
"lol", "lmao", "haha", "omg", "wtf", "wow",
|
| 45 |
-
".", "..", "...", "?", "!", "test",
|
| 46 |
}
|
| 47 |
|
| 48 |
|
| 49 |
# ═══════════════════════════════════════════════════════════════════════════════
|
| 50 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
# ═══════════════════════════════════════════════════════════════════════════════
|
| 52 |
|
| 53 |
_CLASSIFY_PROMPT = """You are an intent classifier for ARKI AI, a news assistant focused on Ethiopia and Africa.
|
|
@@ -81,10 +241,10 @@ Category:"""
|
|
| 81 |
class IntentResult:
|
| 82 |
intent: str # NEWS_TEMPORAL | NEWS_HISTORICAL | NEWS_GENERAL | OTHER
|
| 83 |
confidence: float # 0.0 – 1.0
|
| 84 |
-
method: str # instant | llm_groq | llm_gemini | llm_openrouter | llm_hf | default
|
| 85 |
inference_time_ms: float
|
| 86 |
-
query_complexity: str # vague | simple | medium | complex
|
| 87 |
-
sub_type: str # general | conflict | humanitarian | identity | off_topic
|
| 88 |
should_use_live: bool
|
| 89 |
should_use_db: bool
|
| 90 |
metadata: Dict[str, Any]
|
|
@@ -109,22 +269,23 @@ class IntentResult:
|
|
| 109 |
|
| 110 |
class IntentClassifierV2:
|
| 111 |
"""
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
"""
|
| 117 |
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
OPENROUTER_URL = "https://openrouter.ai/api/v1/chat/completions"
|
| 125 |
-
OPENROUTER_MODEL = "openrouter/auto" # Auto-selects best available free model
|
| 126 |
-
|
| 127 |
-
HF_URL = "https://api-inference.huggingface.co/models/meta-llama/Llama-3.2-3B-Instruct/v1/chat/completions"
|
| 128 |
|
| 129 |
VALID_INTENTS = {"NEWS_TEMPORAL", "NEWS_HISTORICAL", "NEWS_GENERAL", "OTHER"}
|
| 130 |
|
|
@@ -134,49 +295,42 @@ class IntentClassifierV2:
|
|
| 134 |
self._openrouter_key: Optional[str] = None
|
| 135 |
self._hf_token: Optional[str] = None
|
| 136 |
self._client = httpx.Client(timeout=5.0)
|
| 137 |
-
self._metrics = {
|
| 138 |
"total": 0,
|
| 139 |
"by_intent": {},
|
| 140 |
"by_method": {},
|
| 141 |
"total_ms": 0.0,
|
|
|
|
|
|
|
| 142 |
}
|
| 143 |
self._load_keys()
|
| 144 |
|
| 145 |
def _load_keys(self):
|
| 146 |
-
"""Load API keys from settings."""
|
| 147 |
try:
|
| 148 |
from src.core.config import settings
|
| 149 |
-
|
| 150 |
key = settings.GROQ_API_KEY
|
| 151 |
if key and key not in ("", "your-groq-api-key-here"):
|
| 152 |
self._groq_key = key
|
| 153 |
-
|
| 154 |
gem = settings.GEMINI_API_KEY
|
| 155 |
if gem and gem not in ("", "your-gemini-api-key-here"):
|
| 156 |
self._gemini_key = gem
|
| 157 |
-
|
| 158 |
-
# OpenRouter key (add OPENROUTER_API_KEY to .env)
|
| 159 |
try:
|
| 160 |
or_key = getattr(settings, "OPENROUTER_API_KEY", "")
|
| 161 |
if or_key and or_key not in ("", "your-openrouter-api-key-here"):
|
| 162 |
self._openrouter_key = or_key
|
| 163 |
except Exception:
|
| 164 |
pass
|
| 165 |
-
|
| 166 |
-
# HuggingFace token
|
| 167 |
hf = settings.HF_TOKEN
|
| 168 |
if hf and hf not in ("", "your-hf-token-here"):
|
| 169 |
self._hf_token = hf
|
| 170 |
|
| 171 |
-
providers = []
|
| 172 |
-
if self._groq_key:
|
| 173 |
-
if self._gemini_key:
|
| 174 |
-
if self._openrouter_key:
|
| 175 |
-
if self._hf_token:
|
| 176 |
providers.append("Default")
|
| 177 |
-
|
| 178 |
-
logger.info(f"✅ Intent classifier providers: {' → '.join(providers)}")
|
| 179 |
-
|
| 180 |
except Exception as e:
|
| 181 |
logger.error(f"Intent classifier: failed to load keys: {e}")
|
| 182 |
|
|
@@ -188,63 +342,62 @@ class IntentClassifierV2:
|
|
| 188 |
ql = q.lower()
|
| 189 |
complexity = self._complexity(q)
|
| 190 |
|
| 191 |
-
# ── Layer
|
| 192 |
if ql in _INSTANT_OTHER:
|
| 193 |
return self._result("OTHER", 1.0, "instant", t0, complexity, "identity")
|
| 194 |
|
| 195 |
-
# ── Layer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
if self._groq_key:
|
| 197 |
intent = self._call_openai_compat(
|
| 198 |
-
url=self.GROQ_URL,
|
| 199 |
-
|
| 200 |
-
model=self.GROQ_MODEL,
|
| 201 |
-
query=q,
|
| 202 |
-
provider="groq",
|
| 203 |
)
|
| 204 |
if intent:
|
| 205 |
return self._result(intent, 0.97, "llm_groq", t0, complexity,
|
| 206 |
self._sub_type(q, intent))
|
| 207 |
|
| 208 |
-
# ── Layer 3: Gemini Flash (FALLBACK 1) ────────────────────────────────
|
| 209 |
if self._gemini_key:
|
| 210 |
intent = self._call_gemini(q)
|
| 211 |
if intent:
|
| 212 |
return self._result(intent, 0.95, "llm_gemini", t0, complexity,
|
| 213 |
self._sub_type(q, intent))
|
| 214 |
|
| 215 |
-
# ── Layer 4: OpenRouter free router (FALLBACK 2) ─────────────────────
|
| 216 |
if self._openrouter_key:
|
| 217 |
intent = self._call_openai_compat(
|
| 218 |
-
url=self.OPENROUTER_URL,
|
| 219 |
-
|
| 220 |
-
model=self.OPENROUTER_MODEL,
|
| 221 |
-
query=q,
|
| 222 |
-
provider="openrouter",
|
| 223 |
extra_headers={
|
| 224 |
"HTTP-Referer": "https://arki-ai.com",
|
| 225 |
"X-Title": "ARKI AI Intent Classifier",
|
| 226 |
-
}
|
| 227 |
)
|
| 228 |
if intent:
|
| 229 |
return self._result(intent, 0.93, "llm_openrouter", t0, complexity,
|
| 230 |
self._sub_type(q, intent))
|
| 231 |
|
| 232 |
-
# ── Layer 5: HuggingFace Inference API (FALLBACK 3) ───────────────────
|
| 233 |
if self._hf_token:
|
| 234 |
intent = self._call_openai_compat(
|
| 235 |
-
url=self.HF_URL,
|
| 236 |
-
api_key=self._hf_token,
|
| 237 |
model="meta-llama/Llama-3.2-3B-Instruct",
|
| 238 |
-
query=q,
|
| 239 |
-
provider="huggingface",
|
| 240 |
-
timeout=8.0, # HF is slower
|
| 241 |
)
|
| 242 |
if intent:
|
| 243 |
return self._result(intent, 0.90, "llm_hf", t0, complexity,
|
| 244 |
self._sub_type(q, intent))
|
| 245 |
|
| 246 |
# ── Layer 6: Safe default ─────────────────────────────────────────────
|
| 247 |
-
logger.warning(f"Intent
|
| 248 |
return self._result("NEWS_GENERAL", 0.50, "default", t0, complexity, "general")
|
| 249 |
|
| 250 |
# ── Provider calls ────────────────────────────────────────────────────────
|
|
@@ -259,143 +412,110 @@ class IntentClassifierV2:
|
|
| 259 |
extra_headers: Optional[Dict] = None,
|
| 260 |
timeout: float = 4.0,
|
| 261 |
) -> Optional[str]:
|
| 262 |
-
"""
|
| 263 |
-
Generic OpenAI-compatible API call.
|
| 264 |
-
Works for: Groq, OpenRouter, HuggingFace (all use same format).
|
| 265 |
-
"""
|
| 266 |
try:
|
| 267 |
-
headers = {
|
| 268 |
-
"Authorization": f"Bearer {api_key}",
|
| 269 |
-
"Content-Type": "application/json",
|
| 270 |
-
}
|
| 271 |
if extra_headers:
|
| 272 |
headers.update(extra_headers)
|
| 273 |
-
|
| 274 |
response = self._client.post(
|
| 275 |
-
url,
|
| 276 |
-
headers=headers,
|
| 277 |
json={
|
| 278 |
"model": model,
|
| 279 |
-
"messages": [
|
| 280 |
-
{"role": "user", "content": _CLASSIFY_PROMPT.format(query=query)}
|
| 281 |
-
],
|
| 282 |
"max_tokens": 20,
|
| 283 |
"temperature": 0.0,
|
| 284 |
},
|
| 285 |
timeout=timeout,
|
| 286 |
)
|
| 287 |
-
|
| 288 |
if response.status_code == 200:
|
| 289 |
content = (
|
| 290 |
-
response.json()
|
| 291 |
-
.get("
|
| 292 |
-
.get("message", {})
|
| 293 |
-
.get("content", "")
|
| 294 |
-
.strip()
|
| 295 |
)
|
| 296 |
intent = self._parse_intent(content)
|
| 297 |
if intent:
|
| 298 |
-
logger.debug(f"{provider}: '{query[:40]}' → {intent}")
|
| 299 |
return intent
|
| 300 |
-
logger.warning(f"{provider}: unexpected response: '{content}'")
|
| 301 |
-
|
| 302 |
elif response.status_code == 429:
|
| 303 |
-
logger.warning(f"Intent
|
| 304 |
elif response.status_code == 503:
|
| 305 |
-
logger.warning(f"Intent
|
| 306 |
else:
|
| 307 |
-
logger.warning(f"Intent
|
| 308 |
-
|
| 309 |
except httpx.TimeoutException:
|
| 310 |
-
logger.warning(f"Intent
|
| 311 |
except Exception as e:
|
| 312 |
-
logger.error(f"Intent
|
| 313 |
-
|
| 314 |
return None
|
| 315 |
|
| 316 |
def _call_gemini(self, query: str) -> Optional[str]:
|
| 317 |
-
"""Gemini has a different API format."""
|
| 318 |
try:
|
| 319 |
url = f"{self.GEMINI_URL}?key={self._gemini_key}"
|
| 320 |
response = self._client.post(
|
| 321 |
url,
|
| 322 |
json={
|
| 323 |
-
"contents": [
|
| 324 |
-
|
| 325 |
-
],
|
| 326 |
-
"generationConfig": {
|
| 327 |
-
"maxOutputTokens": 20,
|
| 328 |
-
"temperature": 0.0,
|
| 329 |
-
},
|
| 330 |
},
|
| 331 |
timeout=4.0,
|
| 332 |
)
|
| 333 |
-
|
| 334 |
if response.status_code == 200:
|
| 335 |
content = (
|
| 336 |
-
response.json()
|
| 337 |
-
.get("
|
| 338 |
-
.get("
|
| 339 |
-
.get("parts", [{}])[0]
|
| 340 |
-
.get("text", "")
|
| 341 |
-
.strip()
|
| 342 |
)
|
| 343 |
intent = self._parse_intent(content)
|
| 344 |
if intent:
|
| 345 |
-
logger.debug(f"gemini: '{query[:40]}' → {intent}")
|
| 346 |
return intent
|
| 347 |
-
|
| 348 |
elif response.status_code == 429:
|
| 349 |
-
logger.warning("Intent
|
| 350 |
else:
|
| 351 |
-
logger.warning(f"Intent
|
| 352 |
-
|
| 353 |
except httpx.TimeoutException:
|
| 354 |
-
logger.warning("Intent
|
| 355 |
except Exception as e:
|
| 356 |
-
logger.error(f"Intent
|
| 357 |
-
|
| 358 |
return None
|
| 359 |
|
| 360 |
# ── Helpers ───────────────────────────────────────────────────────────────
|
| 361 |
|
| 362 |
def _parse_intent(self, raw: str) -> Optional[str]:
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
return cleaned
|
| 368 |
-
|
| 369 |
-
# Partial match (LLM sometimes adds extra words)
|
| 370 |
for intent in self.VALID_INTENTS:
|
| 371 |
if intent in cleaned:
|
| 372 |
return intent
|
| 373 |
-
|
| 374 |
return None
|
| 375 |
|
| 376 |
def _sub_type(self, query: str, intent: str) -> str:
|
| 377 |
-
"""Infer sub-type from query content for downstream routing."""
|
| 378 |
if intent == "OTHER":
|
| 379 |
ql = query.lower()
|
| 380 |
-
if
|
| 381 |
return "identity"
|
| 382 |
-
if
|
| 383 |
return "creative"
|
| 384 |
return "off_topic"
|
| 385 |
-
|
| 386 |
ql = query.lower()
|
| 387 |
-
if any(w in ql for w in ("clash", "attack", "killed", "battle", "fano", "tplf", "military")):
|
| 388 |
return "conflict"
|
| 389 |
-
if any(w in ql for w in ("displaced", "refugee", "aid", "humanitarian", "famine")):
|
| 390 |
return "humanitarian"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 391 |
return "general"
|
| 392 |
|
| 393 |
def _complexity(self, query: str) -> str:
|
| 394 |
n = len(query.split())
|
| 395 |
-
if n == 0:
|
| 396 |
-
if n == 1:
|
| 397 |
-
if n <= 4:
|
| 398 |
-
if n <= 12:
|
| 399 |
return "complex"
|
| 400 |
|
| 401 |
def _result(
|
|
@@ -411,14 +531,12 @@ class IntentClassifierV2:
|
|
| 411 |
ms = (time.time() - t0) * 1000
|
| 412 |
self._metrics["total"] += 1
|
| 413 |
self._metrics["by_intent"][intent] = self._metrics["by_intent"].get(intent, 0) + 1
|
| 414 |
-
self._metrics["by_method"][method]
|
| 415 |
self._metrics["total_ms"] += ms
|
| 416 |
-
|
| 417 |
logger.debug(
|
| 418 |
-
f"Intent
|
| 419 |
f"sub={sub_type} complexity={complexity} time={ms:.1f}ms"
|
| 420 |
)
|
| 421 |
-
|
| 422 |
return IntentResult(
|
| 423 |
intent=intent,
|
| 424 |
confidence=confidence,
|
|
@@ -433,7 +551,12 @@ class IntentClassifierV2:
|
|
| 433 |
|
| 434 |
def get_metrics(self) -> Dict[str, Any]:
|
| 435 |
total = self._metrics["total"] or 1
|
| 436 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 437 |
|
| 438 |
|
| 439 |
# ═══════════════════════════════════════════════════════════════════════════════
|
|
@@ -445,7 +568,6 @@ intent_classifier_v2 = IntentClassifierV2()
|
|
| 445 |
|
| 446 |
class IntentClassifier:
|
| 447 |
"""Backward-compatible binary wrapper (NEWS / OTHER)."""
|
| 448 |
-
|
| 449 |
def __init__(self):
|
| 450 |
self._v2 = intent_classifier_v2
|
| 451 |
|
|
|
|
| 1 |
"""
|
| 2 |
+
Intent Classifier v5 — Fast Keyword Pre-Check + LLM Fallback Chain
|
| 3 |
|
| 4 |
Architecture:
|
| 5 |
+
Layer 0: Instant exact match (0ms) — greetings, single-char, test
|
| 6 |
+
Layer 1: Fast keyword rules (0ms) — temporal/historical/other patterns
|
| 7 |
+
↳ Catches 80%+ of queries instantly, no API call needed
|
| 8 |
Layer 2: Groq llama-3.1-8b-instant — 14,400 free RPD, ~50ms (PRIMARY)
|
| 9 |
Layer 3: Gemini Flash fallback — 1,500 free RPD, ~200ms (FALLBACK 1)
|
| 10 |
Layer 4: OpenRouter free router — free models pool, ~300ms (FALLBACK 2)
|
| 11 |
Layer 5: HuggingFace Inference API — ~300 RPH, ~2s (FALLBACK 3)
|
| 12 |
Layer 6: Safe default — NEWS_GENERAL, 0ms (ALWAYS WORKS)
|
| 13 |
|
| 14 |
+
Layer 1 keyword rules cover:
|
| 15 |
+
- Temporal: "today", "now", "breaking", "latest", "just happened", etc.
|
| 16 |
+
- Historical: "history of", "background", "what caused", "explain", etc.
|
| 17 |
+
- Other: greetings, identity questions, math, creative writing
|
| 18 |
+
- Ethiopia-specific: "Abiy", "TPLF", "Fano", "Tigray" → NEWS_GENERAL fast path
|
| 19 |
+
|
| 20 |
+
Why this matters:
|
| 21 |
+
- Saves Groq API quota (14,400 RPD is finite)
|
| 22 |
+
- Reduces latency from ~50ms → 0ms for common queries
|
| 23 |
+
- Works offline / when all LLM providers are down
|
| 24 |
+
- Handles Amharic/Arabic/Somali temporal words natively
|
|
|
|
|
|
|
| 25 |
"""
|
| 26 |
|
| 27 |
import logging
|
| 28 |
+
import re
|
| 29 |
import time
|
|
|
|
| 30 |
import httpx
|
| 31 |
from dataclasses import dataclass
|
| 32 |
+
from typing import Any, Dict, Optional, Tuple
|
| 33 |
|
| 34 |
logger = logging.getLogger(__name__)
|
| 35 |
|
| 36 |
|
| 37 |
# ═══════════════════════════════════════════════════════════════════════════════
|
| 38 |
+
# LAYER 0: INSTANT EXACT MATCH — greetings, empty, test
|
| 39 |
# ═══════════════════════════════════════════════════════════════════════════════
|
| 40 |
|
| 41 |
_INSTANT_OTHER = {
|
| 42 |
"hi", "hello", "hey", "thanks", "thank you", "bye", "goodbye",
|
| 43 |
"ok", "okay", "yes", "no", "sure", "cool", "nice",
|
| 44 |
"lol", "lmao", "haha", "omg", "wtf", "wow",
|
| 45 |
+
".", "..", "...", "?", "!", "test", "ping",
|
| 46 |
}
|
| 47 |
|
| 48 |
|
| 49 |
# ═══════════════════════════════════════════════════════════════════════════════
|
| 50 |
+
# LAYER 1: FAST KEYWORD RULES
|
| 51 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 52 |
+
|
| 53 |
+
# ── Temporal signals → NEWS_TEMPORAL ─────────────────────────────────────────
|
| 54 |
+
# English
|
| 55 |
+
_TEMPORAL_EN = re.compile(
|
| 56 |
+
r"\b("
|
| 57 |
+
r"today|tonight|right now|just now|breaking|just happened|"
|
| 58 |
+
r"this morning|this afternoon|this evening|this hour|"
|
| 59 |
+
r"latest|current(ly)?|live|ongoing|unfolding|"
|
| 60 |
+
r"yesterday|last night|"
|
| 61 |
+
r"this week|this month|this year|"
|
| 62 |
+
r"recent(ly)?|new(ly)?|fresh|"
|
| 63 |
+
r"past (few )?(hours?|days?|weeks?)|"
|
| 64 |
+
r"in the (last|past) \d+|"
|
| 65 |
+
r"as of (today|now)|"
|
| 66 |
+
r"update[sd]?|news flash|alert"
|
| 67 |
+
r")\b",
|
| 68 |
+
re.IGNORECASE
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
# Amharic temporal words (common ones)
|
| 72 |
+
_TEMPORAL_AM = re.compile(
|
| 73 |
+
r"(ዛሬ|አሁን|ዘንድሮ|ቅርብ|አዲስ|ዜና|ዛሬ ምሽት|ዛሬ ጠዋት)",
|
| 74 |
+
re.UNICODE
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
# Arabic temporal words
|
| 78 |
+
_TEMPORAL_AR = re.compile(
|
| 79 |
+
r"(اليوم|الآن|عاجل|أخبار عاجلة|حديثاً|مؤخراً|هذا الأسبوع|هذا الشهر)",
|
| 80 |
+
re.UNICODE
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
# Somali temporal words
|
| 84 |
+
_TEMPORAL_SO = re.compile(r"(maanta|hadda|wararka|cusub)", re.IGNORECASE | re.UNICODE)
|
| 85 |
+
|
| 86 |
+
# Swahili temporal words
|
| 87 |
+
_TEMPORAL_SW = re.compile(r"(leo|sasa|habari za leo|mpya|hivi karibuni)", re.IGNORECASE | re.UNICODE)
|
| 88 |
+
|
| 89 |
+
# ── Historical signals → NEWS_HISTORICAL ─────────────────────────────────────
|
| 90 |
+
_HISTORICAL = re.compile(
|
| 91 |
+
r"\b("
|
| 92 |
+
r"history (of|behind)|historical(ly)?|"
|
| 93 |
+
r"background (of|on|to)|context (of|behind)|"
|
| 94 |
+
r"what caused|root cause|origin(s)? of|"
|
| 95 |
+
r"explain|overview|summary of|"
|
| 96 |
+
r"who (is|was|are|were)|what (is|was|are|were)|"
|
| 97 |
+
r"tell me about|describe|"
|
| 98 |
+
r"in \d{4}|since \d{4}|before \d{4}|"
|
| 99 |
+
r"decade(s)?|century|centuries|"
|
| 100 |
+
r"long.?term|over the years|traditionally|"
|
| 101 |
+
r"founded|established|created|formed"
|
| 102 |
+
r")\b",
|
| 103 |
+
re.IGNORECASE
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
# ── Other signals → OTHER ─────────────────────────────────────────────────────
|
| 107 |
+
_OTHER_IDENTITY = re.compile(
|
| 108 |
+
r"\b("
|
| 109 |
+
r"who are you|what are you|are you (an? )?ai|"
|
| 110 |
+
r"what (model|llm|ai) are you|"
|
| 111 |
+
r"who (made|built|created|trained) you|"
|
| 112 |
+
r"your (name|purpose|capabilities)|"
|
| 113 |
+
r"can you (help|do|write|make|create|generate)|"
|
| 114 |
+
r"how (do you|does this) work"
|
| 115 |
+
r")\b",
|
| 116 |
+
re.IGNORECASE
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
_OTHER_CREATIVE = re.compile(
|
| 120 |
+
r"\b("
|
| 121 |
+
r"write (a |an )?(poem|story|essay|letter|email|code|script)|"
|
| 122 |
+
r"make (a |an )?(joke|list|plan|recipe)|"
|
| 123 |
+
r"translate (this|to|into)|"
|
| 124 |
+
r"calculate|solve|compute|"
|
| 125 |
+
r"what is \d|how many|how much|"
|
| 126 |
+
r"recommend|suggest|give me (a |an )?(list|idea)"
|
| 127 |
+
r")\b",
|
| 128 |
+
re.IGNORECASE
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
# ── Ethiopia/Africa fast-path → NEWS_GENERAL (skip LLM entirely) ─────────────
|
| 132 |
+
_ETHIOPIA_ENTITIES = re.compile(
|
| 133 |
+
r"\b("
|
| 134 |
+
r"ethiopia(n)?|addis ababa|addis|"
|
| 135 |
+
r"tigray|amhara|oromia|oromo|afar|somali region|sidama|"
|
| 136 |
+
r"abiy ahmed?|abiy|"
|
| 137 |
+
r"tplf|fano|olf|oneg|endf|"
|
| 138 |
+
r"gerd|renaissance dam|nile dam|"
|
| 139 |
+
r"mekelle|bahir dar|gondar|hawassa|dire dawa|"
|
| 140 |
+
r"africa(n)?|horn of africa|east africa|"
|
| 141 |
+
r"sudan|somalia|eritrea|kenya|djibouti"
|
| 142 |
+
r")\b",
|
| 143 |
+
re.IGNORECASE
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
# ── Conflict/humanitarian fast-path → NEWS_GENERAL ───────────────────────────
|
| 147 |
+
_NEWS_TOPICS = re.compile(
|
| 148 |
+
r"\b("
|
| 149 |
+
r"conflict|war|fighting|clashes?|attack(s|ed)?|killed|casualties|"
|
| 150 |
+
r"peace (talks?|deal|agreement|process)|ceasefire|"
|
| 151 |
+
r"election(s)?|vote|voting|ballot|"
|
| 152 |
+
r"government|minister|president|prime minister|parliament|"
|
| 153 |
+
r"economy|economic|inflation|gdp|trade|investment|"
|
| 154 |
+
r"humanitarian|refugee(s)?|displaced|famine|drought|flood|"
|
| 155 |
+
r"protest(s|ers)?|demonstration|rally|"
|
| 156 |
+
r"military|troops|soldiers?|forces?|"
|
| 157 |
+
r"news|report(s|ed)?|update(s)?"
|
| 158 |
+
r")\b",
|
| 159 |
+
re.IGNORECASE
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def _fast_classify(query: str) -> Optional[Tuple[str, float, str]]:
|
| 164 |
+
"""
|
| 165 |
+
Layer 1: Fast keyword-based classification.
|
| 166 |
+
Returns (intent, confidence, reason) or None if uncertain.
|
| 167 |
+
|
| 168 |
+
Priority order:
|
| 169 |
+
1. OTHER (identity/creative) — highest priority, avoid wasting search
|
| 170 |
+
2. NEWS_TEMPORAL — temporal signals are unambiguous
|
| 171 |
+
3. NEWS_HISTORICAL — historical signals are fairly unambiguous
|
| 172 |
+
4. NEWS_GENERAL — Ethiopia/Africa entities or news topics
|
| 173 |
+
5. None — uncertain, let LLM decide
|
| 174 |
+
"""
|
| 175 |
+
q = query.strip()
|
| 176 |
+
ql = q.lower()
|
| 177 |
+
|
| 178 |
+
# ── 1. OTHER: identity questions ─────────────────────────────────────────
|
| 179 |
+
if _OTHER_IDENTITY.search(q):
|
| 180 |
+
return ("OTHER", 0.95, "identity_pattern")
|
| 181 |
+
|
| 182 |
+
# ── 2. OTHER: creative/off-topic ─────────────────────────────────────────
|
| 183 |
+
if _OTHER_CREATIVE.search(q):
|
| 184 |
+
return ("OTHER", 0.90, "creative_pattern")
|
| 185 |
+
|
| 186 |
+
# ── 3. NEWS_TEMPORAL: multilingual temporal signals ───────────────────────
|
| 187 |
+
if (_TEMPORAL_EN.search(q) or _TEMPORAL_AM.search(q) or
|
| 188 |
+
_TEMPORAL_AR.search(q) or _TEMPORAL_SO.search(q) or
|
| 189 |
+
_TEMPORAL_SW.search(q)):
|
| 190 |
+
return ("NEWS_TEMPORAL", 0.92, "temporal_keyword")
|
| 191 |
+
|
| 192 |
+
# ── 4. NEWS_HISTORICAL: historical/background signals ────────────────────
|
| 193 |
+
if _HISTORICAL.search(q):
|
| 194 |
+
# But if it also has temporal signals, temporal wins
|
| 195 |
+
return ("NEWS_HISTORICAL", 0.88, "historical_keyword")
|
| 196 |
+
|
| 197 |
+
# ── 5. NEWS_GENERAL: Ethiopia/Africa entities ────────────────────────────
|
| 198 |
+
if _ETHIOPIA_ENTITIES.search(q):
|
| 199 |
+
return ("NEWS_GENERAL", 0.85, "ethiopia_entity")
|
| 200 |
+
|
| 201 |
+
# ── 6. NEWS_GENERAL: news topic keywords ─────────────────────────────────
|
| 202 |
+
if _NEWS_TOPICS.search(q):
|
| 203 |
+
return ("NEWS_GENERAL", 0.80, "news_topic_keyword")
|
| 204 |
+
|
| 205 |
+
# ── 7. Uncertain — let LLM decide ────────────────────────────────────────
|
| 206 |
+
return None
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 210 |
+
# LLM CLASSIFICATION PROMPT
|
| 211 |
# ═══════════════════════════════════════════════════════════════════════════════
|
| 212 |
|
| 213 |
_CLASSIFY_PROMPT = """You are an intent classifier for ARKI AI, a news assistant focused on Ethiopia and Africa.
|
|
|
|
| 241 |
class IntentResult:
|
| 242 |
intent: str # NEWS_TEMPORAL | NEWS_HISTORICAL | NEWS_GENERAL | OTHER
|
| 243 |
confidence: float # 0.0 – 1.0
|
| 244 |
+
method: str # instant | keyword | llm_groq | llm_gemini | llm_openrouter | llm_hf | default
|
| 245 |
inference_time_ms: float
|
| 246 |
+
query_complexity: str # empty | vague | simple | medium | complex
|
| 247 |
+
sub_type: str # general | conflict | humanitarian | identity | creative | off_topic
|
| 248 |
should_use_live: bool
|
| 249 |
should_use_db: bool
|
| 250 |
metadata: Dict[str, Any]
|
|
|
|
| 269 |
|
| 270 |
class IntentClassifierV2:
|
| 271 |
"""
|
| 272 |
+
Intent classifier v5: Fast keyword pre-check + LLM fallback chain.
|
| 273 |
+
|
| 274 |
+
Layer 0: Instant exact match (0ms)
|
| 275 |
+
Layer 1: Keyword rules (0ms) — handles ~80% of queries
|
| 276 |
+
Layer 2: Groq 8B (50ms)
|
| 277 |
+
Layer 3: Gemini Flash (200ms)
|
| 278 |
+
Layer 4: OpenRouter (300ms)
|
| 279 |
+
Layer 5: HuggingFace (2s)
|
| 280 |
+
Layer 6: Default NEWS_GENERAL (0ms)
|
| 281 |
"""
|
| 282 |
|
| 283 |
+
GROQ_URL = "https://api.groq.com/openai/v1/chat/completions"
|
| 284 |
+
GROQ_MODEL = "llama-3.1-8b-instant"
|
| 285 |
+
GEMINI_URL = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent"
|
| 286 |
+
OPENROUTER_URL = "https://openrouter.ai/api/v1/chat/completions"
|
| 287 |
+
OPENROUTER_MODEL = "openrouter/auto"
|
| 288 |
+
HF_URL = "https://api-inference.huggingface.co/models/meta-llama/Llama-3.2-3B-Instruct/v1/chat/completions"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
|
| 290 |
VALID_INTENTS = {"NEWS_TEMPORAL", "NEWS_HISTORICAL", "NEWS_GENERAL", "OTHER"}
|
| 291 |
|
|
|
|
| 295 |
self._openrouter_key: Optional[str] = None
|
| 296 |
self._hf_token: Optional[str] = None
|
| 297 |
self._client = httpx.Client(timeout=5.0)
|
| 298 |
+
self._metrics: Dict[str, Any] = {
|
| 299 |
"total": 0,
|
| 300 |
"by_intent": {},
|
| 301 |
"by_method": {},
|
| 302 |
"total_ms": 0.0,
|
| 303 |
+
"keyword_hits": 0, # how many queries handled by keyword layer
|
| 304 |
+
"llm_calls": 0, # how many queries needed LLM
|
| 305 |
}
|
| 306 |
self._load_keys()
|
| 307 |
|
| 308 |
    def _load_keys(self):
        # Pull provider API keys from app settings, skipping empty strings and
        # the placeholder values shipped in the example .env file.
        try:
            from src.core.config import settings
            key = settings.GROQ_API_KEY
            if key and key not in ("", "your-groq-api-key-here"):
                self._groq_key = key
            gem = settings.GEMINI_API_KEY
            if gem and gem not in ("", "your-gemini-api-key-here"):
                self._gemini_key = gem
            # OPENROUTER_API_KEY may be absent from older settings objects, so it
            # gets its own guarded getattr lookup instead of direct attribute access.
            try:
                or_key = getattr(settings, "OPENROUTER_API_KEY", "")
                if or_key and or_key not in ("", "your-openrouter-api-key-here"):
                    self._openrouter_key = or_key
            except Exception:
                pass
            hf = settings.HF_TOKEN
            if hf and hf not in ("", "your-hf-token-here"):
                self._hf_token = hf
            # Log the effective fallback chain; the keyword layer and the safe
            # default are always present regardless of configured keys.
            providers = ["Keyword"]
            if self._groq_key: providers.append("Groq")
            if self._gemini_key: providers.append("Gemini")
            if self._openrouter_key: providers.append("OpenRouter")
            if self._hf_token: providers.append("HuggingFace")
            providers.append("Default")
            logger.info(f"✅ Intent classifier v5 providers: {' → '.join(providers)}")
        except Exception as e:
            # Best-effort: a settings failure leaves all keys None, which makes
            # classify() fall through to the keyword layer and the safe default.
            logger.error(f"Intent classifier: failed to load keys: {e}")
|
| 336 |
|
|
|
|
| 342 |
ql = q.lower()
|
| 343 |
complexity = self._complexity(q)
|
| 344 |
|
| 345 |
+
# ── Layer 0: Instant exact match ──────────────────────────────────────
|
| 346 |
if ql in _INSTANT_OTHER:
|
| 347 |
return self._result("OTHER", 1.0, "instant", t0, complexity, "identity")
|
| 348 |
|
| 349 |
+
# ── Layer 1: Fast keyword rules ───────────────────────────────────────
|
| 350 |
+
fast = _fast_classify(q)
|
| 351 |
+
if fast:
|
| 352 |
+
intent, confidence, reason = fast
|
| 353 |
+
self._metrics["keyword_hits"] += 1
|
| 354 |
+
logger.debug(f"[Intent] Keyword rule: '{q[:50]}' → {intent} ({reason})")
|
| 355 |
+
return self._result(intent, confidence, f"keyword:{reason}", t0, complexity,
|
| 356 |
+
self._sub_type(q, intent))
|
| 357 |
+
|
| 358 |
+
# ── Layers 2-5: LLM providers ─────────────────────────────────────────
|
| 359 |
+
self._metrics["llm_calls"] += 1
|
| 360 |
+
|
| 361 |
if self._groq_key:
|
| 362 |
intent = self._call_openai_compat(
|
| 363 |
+
url=self.GROQ_URL, api_key=self._groq_key,
|
| 364 |
+
model=self.GROQ_MODEL, query=q, provider="groq"
|
|
|
|
|
|
|
|
|
|
| 365 |
)
|
| 366 |
if intent:
|
| 367 |
return self._result(intent, 0.97, "llm_groq", t0, complexity,
|
| 368 |
self._sub_type(q, intent))
|
| 369 |
|
|
|
|
| 370 |
if self._gemini_key:
|
| 371 |
intent = self._call_gemini(q)
|
| 372 |
if intent:
|
| 373 |
return self._result(intent, 0.95, "llm_gemini", t0, complexity,
|
| 374 |
self._sub_type(q, intent))
|
| 375 |
|
|
|
|
| 376 |
if self._openrouter_key:
|
| 377 |
intent = self._call_openai_compat(
|
| 378 |
+
url=self.OPENROUTER_URL, api_key=self._openrouter_key,
|
| 379 |
+
model=self.OPENROUTER_MODEL, query=q, provider="openrouter",
|
|
|
|
|
|
|
|
|
|
| 380 |
extra_headers={
|
| 381 |
"HTTP-Referer": "https://arki-ai.com",
|
| 382 |
"X-Title": "ARKI AI Intent Classifier",
|
| 383 |
+
}
|
| 384 |
)
|
| 385 |
if intent:
|
| 386 |
return self._result(intent, 0.93, "llm_openrouter", t0, complexity,
|
| 387 |
self._sub_type(q, intent))
|
| 388 |
|
|
|
|
| 389 |
if self._hf_token:
|
| 390 |
intent = self._call_openai_compat(
|
| 391 |
+
url=self.HF_URL, api_key=self._hf_token,
|
|
|
|
| 392 |
model="meta-llama/Llama-3.2-3B-Instruct",
|
| 393 |
+
query=q, provider="huggingface", timeout=8.0
|
|
|
|
|
|
|
| 394 |
)
|
| 395 |
if intent:
|
| 396 |
return self._result(intent, 0.90, "llm_hf", t0, complexity,
|
| 397 |
self._sub_type(q, intent))
|
| 398 |
|
| 399 |
# ── Layer 6: Safe default ─────────────────────────────────────────────
|
| 400 |
+
logger.warning(f"[Intent] All providers failed for '{q[:50]}' — defaulting to NEWS_GENERAL")
|
| 401 |
return self._result("NEWS_GENERAL", 0.50, "default", t0, complexity, "general")
|
| 402 |
|
| 403 |
# ── Provider calls ────────────────────────────────────────────────────────
|
|
|
|
| 412 |
extra_headers: Optional[Dict] = None,
|
| 413 |
timeout: float = 4.0,
|
| 414 |
) -> Optional[str]:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 415 |
try:
|
| 416 |
+
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
|
|
|
|
|
|
|
|
|
|
| 417 |
if extra_headers:
|
| 418 |
headers.update(extra_headers)
|
|
|
|
| 419 |
response = self._client.post(
|
| 420 |
+
url, headers=headers,
|
|
|
|
| 421 |
json={
|
| 422 |
"model": model,
|
| 423 |
+
"messages": [{"role": "user", "content": _CLASSIFY_PROMPT.format(query=query)}],
|
|
|
|
|
|
|
| 424 |
"max_tokens": 20,
|
| 425 |
"temperature": 0.0,
|
| 426 |
},
|
| 427 |
timeout=timeout,
|
| 428 |
)
|
|
|
|
| 429 |
if response.status_code == 200:
|
| 430 |
content = (
|
| 431 |
+
response.json().get("choices", [{}])[0]
|
| 432 |
+
.get("message", {}).get("content", "").strip()
|
|
|
|
|
|
|
|
|
|
| 433 |
)
|
| 434 |
intent = self._parse_intent(content)
|
| 435 |
if intent:
|
| 436 |
+
logger.debug(f"[Intent] {provider}: '{query[:40]}' → {intent}")
|
| 437 |
return intent
|
| 438 |
+
logger.warning(f"[Intent] {provider}: unexpected response: '{content}'")
|
|
|
|
| 439 |
elif response.status_code == 429:
|
| 440 |
+
logger.warning(f"[Intent] {provider} rate limited")
|
| 441 |
elif response.status_code == 503:
|
| 442 |
+
logger.warning(f"[Intent] {provider} unavailable (503)")
|
| 443 |
else:
|
| 444 |
+
logger.warning(f"[Intent] {provider} returned {response.status_code}")
|
|
|
|
| 445 |
except httpx.TimeoutException:
|
| 446 |
+
logger.warning(f"[Intent] {provider} timeout ({timeout}s)")
|
| 447 |
except Exception as e:
|
| 448 |
+
logger.error(f"[Intent] {provider} error: {e}")
|
|
|
|
| 449 |
return None
|
| 450 |
|
| 451 |
    def _call_gemini(self, query: str) -> Optional[str]:
        """Classify *query* via Gemini Flash; return an intent label or None on any failure."""
        try:
            # Gemini authenticates via a ?key= query parameter, not a header.
            url = f"{self.GEMINI_URL}?key={self._gemini_key}"
            response = self._client.post(
                url,
                json={
                    "contents": [{"parts": [{"text": _CLASSIFY_PROMPT.format(query=query)}]}],
                    # The label is a single token — 20 tokens is plenty;
                    # temperature 0 keeps classification deterministic.
                    "generationConfig": {"maxOutputTokens": 20, "temperature": 0.0},
                },
                timeout=4.0,
            )
            if response.status_code == 200:
                # Response shape: candidates[0].content.parts[0].text —
                # every .get() has a default so malformed payloads yield "".
                content = (
                    response.json().get("candidates", [{}])[0]
                    .get("content", {}).get("parts", [{}])[0]
                    .get("text", "").strip()
                )
                intent = self._parse_intent(content)
                if intent:
                    logger.debug(f"[Intent] gemini: '{query[:40]}' → {intent}")
                    return intent
            elif response.status_code == 429:
                logger.warning("[Intent] Gemini rate limited")
            else:
                logger.warning(f"[Intent] Gemini returned {response.status_code}")
        except httpx.TimeoutException:
            logger.warning("[Intent] Gemini timeout (4s)")
        except Exception as e:
            logger.error(f"[Intent] Gemini error: {e}")
        # None tells classify() to try the next provider in the fallback chain.
        return None
|
| 481 |
|
| 482 |
# ── Helpers ───────────────────────────────────────────────────────────────
|
| 483 |
|
| 484 |
def _parse_intent(self, raw: str) -> Optional[str]:
|
| 485 |
+
cleaned = raw.strip().upper().replace(".", "").replace(":", "")
|
| 486 |
+
first_word = cleaned.split()[0] if cleaned.split() else ""
|
| 487 |
+
if first_word in self.VALID_INTENTS:
|
| 488 |
+
return first_word
|
|
|
|
|
|
|
|
|
|
| 489 |
for intent in self.VALID_INTENTS:
|
| 490 |
if intent in cleaned:
|
| 491 |
return intent
|
|
|
|
| 492 |
return None
|
| 493 |
|
| 494 |
def _sub_type(self, query: str, intent: str) -> str:
|
|
|
|
| 495 |
if intent == "OTHER":
|
| 496 |
ql = query.lower()
|
| 497 |
+
if _OTHER_IDENTITY.search(query):
|
| 498 |
return "identity"
|
| 499 |
+
if _OTHER_CREATIVE.search(query):
|
| 500 |
return "creative"
|
| 501 |
return "off_topic"
|
|
|
|
| 502 |
ql = query.lower()
|
| 503 |
+
if any(w in ql for w in ("clash", "attack", "killed", "battle", "fano", "tplf", "military", "conflict", "war")):
|
| 504 |
return "conflict"
|
| 505 |
+
if any(w in ql for w in ("displaced", "refugee", "aid", "humanitarian", "famine", "drought")):
|
| 506 |
return "humanitarian"
|
| 507 |
+
if any(w in ql for w in ("election", "vote", "government", "minister", "president", "parliament")):
|
| 508 |
+
return "political"
|
| 509 |
+
if any(w in ql for w in ("economy", "economic", "inflation", "trade", "investment", "gdp")):
|
| 510 |
+
return "economic"
|
| 511 |
return "general"
|
| 512 |
|
| 513 |
def _complexity(self, query: str) -> str:
|
| 514 |
n = len(query.split())
|
| 515 |
+
if n == 0: return "empty"
|
| 516 |
+
if n == 1: return "vague"
|
| 517 |
+
if n <= 4: return "simple"
|
| 518 |
+
if n <= 12: return "medium"
|
| 519 |
return "complex"
|
| 520 |
|
| 521 |
def _result(
|
|
|
|
| 531 |
ms = (time.time() - t0) * 1000
|
| 532 |
self._metrics["total"] += 1
|
| 533 |
self._metrics["by_intent"][intent] = self._metrics["by_intent"].get(intent, 0) + 1
|
| 534 |
+
self._metrics["by_method"][method] = self._metrics["by_method"].get(method, 0) + 1
|
| 535 |
self._metrics["total_ms"] += ms
|
|
|
|
| 536 |
logger.debug(
|
| 537 |
+
f"[Intent] {intent} conf={confidence:.2f} method={method} "
|
| 538 |
f"sub={sub_type} complexity={complexity} time={ms:.1f}ms"
|
| 539 |
)
|
|
|
|
| 540 |
return IntentResult(
|
| 541 |
intent=intent,
|
| 542 |
confidence=confidence,
|
|
|
|
| 551 |
|
| 552 |
def get_metrics(self) -> Dict[str, Any]:
|
| 553 |
total = self._metrics["total"] or 1
|
| 554 |
+
kw_pct = (self._metrics["keyword_hits"] / total) * 100
|
| 555 |
+
return {
|
| 556 |
+
**self._metrics,
|
| 557 |
+
"avg_ms": self._metrics["total_ms"] / total,
|
| 558 |
+
"keyword_hit_rate_pct": round(kw_pct, 1),
|
| 559 |
+
}
|
| 560 |
|
| 561 |
|
| 562 |
# ═══════════════════════════════════════════════════════════════════════════════
|
|
|
|
| 568 |
|
| 569 |
class IntentClassifier:
|
| 570 |
"""Backward-compatible binary wrapper (NEWS / OTHER)."""
|
|
|
|
| 571 |
def __init__(self):
|
| 572 |
self._v2 = intent_classifier_v2
|
| 573 |
|
src/infrastructure/adapters/jina_reranker_adapter.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Jina Reranker API Adapter
|
| 3 |
+
|
| 4 |
+
Calls Jina AI's cloud reranker API instead of running the model locally.
|
| 5 |
+
Same jina-reranker-v3 model, but runs on Jina's GPU servers.
|
| 6 |
+
|
| 7 |
+
Benefits over self-hosted:
|
| 8 |
+
- ~300ms latency (vs ~6s/doc on CPU)
|
| 9 |
+
- No model download or GPU needed
|
| 10 |
+
- Same API key as Jina Reader (unified token balance)
|
| 11 |
+
- Production-ready immediately
|
| 12 |
+
|
| 13 |
+
API: POST https://api.jina.ai/v1/rerank
|
| 14 |
+
Docs: https://jina.ai/reranker
|
| 15 |
+
Free: 1M tokens on signup (same key as Reader)
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import logging
|
| 19 |
+
import time
|
| 20 |
+
import httpx
|
| 21 |
+
from typing import List, Dict, Any, Optional
|
| 22 |
+
|
| 23 |
+
from src.core.ports.reranker_port import RerankerPort
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class JinaRerankerAPIAdapter(RerankerPort):
    """
    Reranker using Jina AI's cloud API.

    Sends all documents in ONE API call — Jina handles batching server-side.
    Falls back to score-based ordering if API fails.

    Token usage: query_tokens + sum(doc_tokens)
    Typical: ~1,400 tokens per call with 7 docs × 200 chars each
    """

    # Jina rerank endpoint (POST, bearer-token auth).
    API_URL = "https://api.jina.ai/v1/rerank"

    def __init__(
        self,
        api_key: str,
        model: str = "jina-reranker-v3",
        timeout: float = 5.0,
    ):
        """
        Args:
            api_key: Jina AI API key; empty or placeholder disables the adapter.
            model: Reranker model name passed to the API.
            timeout: Per-request timeout in seconds.
        """
        self.api_key = api_key
        self.model = model
        self.timeout = timeout
        # HTTP client is created lazily on first rerank() call.
        self._client: Optional[httpx.Client] = None

        if not api_key or api_key in ("", "your-jina-api-key-here"):
            # api_key becomes None — rerank() then uses the score-order fallback.
            logger.warning("Jina Reranker API: no API key — adapter disabled")
            self.api_key = None
        else:
            logger.info(f"Jina Reranker API ready (model={model}, timeout={timeout}s)")

    def _get_client(self) -> httpx.Client:
        # Lazy singleton client with auth headers baked in.
        if self._client is None:
            self._client = httpx.Client(
                timeout=self.timeout,
                headers={
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json",
                    "Accept": "application/json",
                }
            )
        return self._client

    def rerank(
        self,
        query: str,
        docs: List[Dict[str, Any]],
        top_n: int = 5,
    ) -> List[Dict[str, Any]]:
        """
        Rerank documents using Jina API.

        Sends all docs in one request — Jina returns them sorted by relevance.
        Falls back to vector score ordering if API unavailable.

        NOTE(review): mutates the input dicts in place (adds a "rerank_score"
        key and reorders valid_docs, which alias entries of `docs`) — callers
        that reuse doc dicts across calls should be aware.
        """
        if not docs:
            return []

        if not self.api_key:
            logger.warning("Jina Reranker API disabled — falling back to score ordering")
            return sorted(docs, key=lambda x: x.get("score", 0), reverse=True)[:top_n]

        # Extract text content — truncate to 2048 chars (Jina handles tokenization)
        MAX_CHARS = 2048
        valid_docs = []
        doc_texts = []
        for doc in docs:
            content = doc.get("content", "").strip()
            if content:
                # doc_texts[i] and valid_docs[i] stay index-aligned for the
                # "index" field in the API response below.
                doc_texts.append(content[:MAX_CHARS])
                valid_docs.append(doc)

        if not doc_texts:
            return []

        t0 = time.time()
        try:
            response = self._get_client().post(
                self.API_URL,
                json={
                    "model": self.model,
                    "query": query,
                    "documents": doc_texts,
                    "top_n": len(doc_texts),  # Get all scores, we'll slice ourselves
                    "return_documents": False,  # Save tokens — we already have the docs
                }
            )

            elapsed_ms = (time.time() - t0) * 1000

            if response.status_code == 200:
                data = response.json()
                results = data.get("results", [])
                usage = data.get("usage", {})

                # results = [{"index": int, "relevance_score": float}, ...]
                # Restore scores to original docs
                for r in results:
                    idx = r["index"]
                    if idx < len(valid_docs):
                        valid_docs[idx]["rerank_score"] = float(r["relevance_score"])

                # Sort by rerank_score descending
                valid_docs.sort(key=lambda x: x.get("rerank_score", 0), reverse=True)

                # valid_docs is non-empty here (guarded by the doc_texts check above).
                logger.info(
                    f"[JinaReranker] {len(valid_docs)} docs → top {top_n} "
                    f"in {elapsed_ms:.0f}ms "
                    f"(tokens={usage.get('total_tokens', '?')}, "
                    f"top_score={valid_docs[0].get('rerank_score', 0):.3f})"
                )
                return valid_docs[:top_n]

            elif response.status_code == 401:
                logger.error("Jina Reranker API: Invalid API key")
            elif response.status_code == 429:
                logger.warning("Jina Reranker API: Rate limit exceeded")
            elif response.status_code == 402:
                logger.warning("Jina Reranker API: Insufficient tokens — top up at jina.ai")
            else:
                logger.warning(
                    f"Jina Reranker API: HTTP {response.status_code} — {response.text[:200]}"
                )

        except httpx.TimeoutException:
            logger.warning(f"Jina Reranker API: timeout ({self.timeout}s)")
        except Exception as e:
            logger.error(f"Jina Reranker API error: {e}")

        # Fallback: sort by vector score
        logger.warning("Jina Reranker API failed — falling back to vector score ordering")
        return sorted(docs, key=lambda x: x.get("score", 0), reverse=True)[:top_n]

    def is_available(self) -> bool:
        # True only when a non-placeholder key was configured in __init__.
        return self.api_key is not None
|
src/infrastructure/adapters/newsapi_adapter.py
ADDED
|
@@ -0,0 +1,376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
NewsAPI.org Adapter
|
| 3 |
+
|
| 4 |
+
Provides real-time news from 80,000+ sources worldwide.
|
| 5 |
+
Best for temporal queries requiring fresh, breaking news.
|
| 6 |
+
|
| 7 |
+
Features:
|
| 8 |
+
- Real-time updates (15-minute refresh)
|
| 9 |
+
- 80,000+ sources including African outlets
|
| 10 |
+
- Structured data (title, description, content, source, publishedAt)
|
| 11 |
+
- Free tier: 100 requests/day
|
| 12 |
+
- Paid tier: $449/month for production
|
| 13 |
+
|
| 14 |
+
Get API key: https://newsapi.org/register
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import logging
|
| 18 |
+
import asyncio
|
| 19 |
+
from typing import List, Dict, Any, Optional
|
| 20 |
+
from datetime import datetime
|
| 21 |
+
import httpx
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class NewsAPIAdapter:
|
| 27 |
+
"""
|
| 28 |
+
Adapter for NewsAPI.org real-time news search.
|
| 29 |
+
|
| 30 |
+
Provides fresh news results to complement database search.
|
| 31 |
+
Designed to be fast (2s timeout) and resilient (graceful fallbacks).
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
BASE_URL = "https://newsapi.org/v2"
|
| 35 |
+
|
| 36 |
+
    def __init__(
        self,
        api_key: str,
        timeout: float = 2.0,
        max_results: int = 20
    ):
        """
        Initialize NewsAPI adapter.

        Args:
            api_key: NewsAPI.org API key
            timeout: Maximum time to wait for results (seconds)
            max_results: Maximum number of results to return
        """
        self.api_key = api_key
        self.timeout = timeout
        self.max_results = max_results
        # Async HTTP client; created lazily by _ensure_client() so construction
        # stays cheap and synchronous.
        self.client: Optional[httpx.AsyncClient] = None

        # The example .env placeholder counts as "not configured" — search()
        # will then short-circuit and return empty results.
        if not api_key or api_key == "your-newsapi-key-here":
            logger.warning("NewsAPI key not configured - adapter disabled")
            self.api_key = None
        else:
            logger.info(f"NewsAPI adapter initialized (timeout={timeout}s, max={max_results})")
|
| 60 |
+
|
| 61 |
+
async def _ensure_client(self):
|
| 62 |
+
"""Lazy initialization of HTTP client"""
|
| 63 |
+
if self.client is None:
|
| 64 |
+
self.client = httpx.AsyncClient(
|
| 65 |
+
timeout=self.timeout,
|
| 66 |
+
headers={
|
| 67 |
+
"X-Api-Key": self.api_key,
|
| 68 |
+
"User-Agent": "ARKI-AI-RAG/2.5 (Ethiopia News Assistant)"
|
| 69 |
+
}
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
async def search(
|
| 73 |
+
self,
|
| 74 |
+
query: str,
|
| 75 |
+
language: str = "en",
|
| 76 |
+
sort_by: str = "publishedAt",
|
| 77 |
+
from_date: Optional[str] = None,
|
| 78 |
+
max_results: Optional[int] = None
|
| 79 |
+
) -> List[Dict[str, Any]]:
|
| 80 |
+
"""
|
| 81 |
+
Search NewsAPI for the given query.
|
| 82 |
+
Automatically wraps multi-word queries in quotes for exact matching.
|
| 83 |
+
"""
|
| 84 |
+
if not self.api_key:
|
| 85 |
+
logger.warning("NewsAPI unavailable - returning empty results")
|
| 86 |
+
return []
|
| 87 |
+
|
| 88 |
+
await self._ensure_client()
|
| 89 |
+
|
| 90 |
+
max_results = max_results or self.max_results
|
| 91 |
+
|
| 92 |
+
# Wrap in quotes if multi-word and not already quoted — improves precision
|
| 93 |
+
search_q = query
|
| 94 |
+
words = query.strip().split()
|
| 95 |
+
if len(words) > 1 and not query.startswith('"'):
|
| 96 |
+
# Use AND logic: all key terms must appear
|
| 97 |
+
search_q = " AND ".join(f'"{w}"' for w in words[:3])
|
| 98 |
+
|
| 99 |
+
try:
|
| 100 |
+
url = f"{self.BASE_URL}/everything"
|
| 101 |
+
params = {
|
| 102 |
+
"q": search_q,
|
| 103 |
+
"language": language,
|
| 104 |
+
"sortBy": sort_by,
|
| 105 |
+
"pageSize": max_results
|
| 106 |
+
}
|
| 107 |
+
if from_date:
|
| 108 |
+
params["from"] = from_date
|
| 109 |
+
|
| 110 |
+
logger.info(f"[NewsAPI] Searching: '{search_q}' (lang={language})")
|
| 111 |
+
|
| 112 |
+
response = await self.client.get(url, params=params)
|
| 113 |
+
|
| 114 |
+
if response.status_code == 200:
|
| 115 |
+
data = response.json()
|
| 116 |
+
if data.get("status") != "ok":
|
| 117 |
+
logger.warning(f"NewsAPI error: {data.get('message', 'unknown')}")
|
| 118 |
+
return []
|
| 119 |
+
|
| 120 |
+
articles = data.get("articles", [])
|
| 121 |
+
results = []
|
| 122 |
+
for article in articles:
|
| 123 |
+
normalized = self._normalize_result(article)
|
| 124 |
+
if normalized:
|
| 125 |
+
results.append(normalized)
|
| 126 |
+
|
| 127 |
+
logger.info(
|
| 128 |
+
f"[NewsAPI] '{query[:50]}' → {len(results)} results "
|
| 129 |
+
f"(total available: {data.get('totalResults', 0)})"
|
| 130 |
+
)
|
| 131 |
+
return results
|
| 132 |
+
|
| 133 |
+
elif response.status_code == 401:
|
| 134 |
+
logger.error("NewsAPI: Invalid API key")
|
| 135 |
+
return []
|
| 136 |
+
elif response.status_code == 429:
|
| 137 |
+
logger.warning("NewsAPI: Rate limit exceeded (100 requests/day on free tier)")
|
| 138 |
+
return []
|
| 139 |
+
elif response.status_code == 426:
|
| 140 |
+
logger.warning("NewsAPI: Upgrade required (free tier limitations)")
|
| 141 |
+
return []
|
| 142 |
+
else:
|
| 143 |
+
logger.warning(f"NewsAPI returned status {response.status_code}: {response.text[:200]}")
|
| 144 |
+
return []
|
| 145 |
+
|
| 146 |
+
except asyncio.TimeoutError:
|
| 147 |
+
logger.warning(f"NewsAPI timeout ({self.timeout}s)")
|
| 148 |
+
return []
|
| 149 |
+
except Exception as e:
|
| 150 |
+
logger.error(f"NewsAPI search error: {e}")
|
| 151 |
+
return []
|
| 152 |
+
|
| 153 |
+
async def search_top_headlines(
|
| 154 |
+
self,
|
| 155 |
+
country: str = "us",
|
| 156 |
+
category: Optional[str] = None,
|
| 157 |
+
max_results: Optional[int] = None
|
| 158 |
+
) -> List[Dict[str, Any]]:
|
| 159 |
+
"""
|
| 160 |
+
Get top headlines from NewsAPI.
|
| 161 |
+
|
| 162 |
+
Args:
|
| 163 |
+
country: Country code (us, gb, etc.) - Note: Ethiopia (et) not supported
|
| 164 |
+
category: Category (business, entertainment, general, health, science, sports, technology)
|
| 165 |
+
max_results: Override default max_results
|
| 166 |
+
|
| 167 |
+
Returns:
|
| 168 |
+
List of normalized search results
|
| 169 |
+
"""
|
| 170 |
+
if not self.api_key:
|
| 171 |
+
logger.warning("NewsAPI unavailable - returning empty results")
|
| 172 |
+
return []
|
| 173 |
+
|
| 174 |
+
await self._ensure_client()
|
| 175 |
+
|
| 176 |
+
max_results = max_results or self.max_results
|
| 177 |
+
|
| 178 |
+
try:
|
| 179 |
+
url = f"{self.BASE_URL}/top-headlines"
|
| 180 |
+
params = {
|
| 181 |
+
"country": country,
|
| 182 |
+
"pageSize": max_results
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
if category:
|
| 186 |
+
params["category"] = category
|
| 187 |
+
|
| 188 |
+
logger.info(f"[NewsAPI] Fetching top headlines (country={country}, category={category})")
|
| 189 |
+
|
| 190 |
+
response = await self.client.get(url, params=params)
|
| 191 |
+
|
| 192 |
+
if response.status_code == 200:
|
| 193 |
+
data = response.json()
|
| 194 |
+
articles = data.get("articles", [])
|
| 195 |
+
|
| 196 |
+
results = []
|
| 197 |
+
for article in articles:
|
| 198 |
+
normalized = self._normalize_result(article)
|
| 199 |
+
if normalized:
|
| 200 |
+
results.append(normalized)
|
| 201 |
+
|
| 202 |
+
logger.info(f"[NewsAPI] Top headlines: {len(results)} results")
|
| 203 |
+
return results
|
| 204 |
+
|
| 205 |
+
else:
|
| 206 |
+
logger.warning(f"NewsAPI top headlines returned status {response.status_code}")
|
| 207 |
+
return []
|
| 208 |
+
|
| 209 |
+
except Exception as e:
|
| 210 |
+
logger.error(f"NewsAPI top headlines error: {e}")
|
| 211 |
+
return []
|
| 212 |
+
|
| 213 |
+
def _normalize_result(self, article: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
| 214 |
+
"""
|
| 215 |
+
Normalize NewsAPI result to common format.
|
| 216 |
+
|
| 217 |
+
Args:
|
| 218 |
+
article: Raw article from NewsAPI
|
| 219 |
+
|
| 220 |
+
Returns:
|
| 221 |
+
Normalized result dict or None if invalid
|
| 222 |
+
"""
|
| 223 |
+
try:
|
| 224 |
+
# Extract fields
|
| 225 |
+
title = article.get("title", "").strip()
|
| 226 |
+
url = article.get("url", "").strip()
|
| 227 |
+
description = article.get("description", "").strip()
|
| 228 |
+
content = article.get("content", "").strip()
|
| 229 |
+
source_name = article.get("source", {}).get("name", "").strip()
|
| 230 |
+
published_at = article.get("publishedAt", "")
|
| 231 |
+
author = article.get("author", "")
|
| 232 |
+
url_to_image = article.get("urlToImage", "")
|
| 233 |
+
|
| 234 |
+
# Validate required fields
|
| 235 |
+
if not title or not url:
|
| 236 |
+
logger.debug(f"Skipping invalid result: missing title or URL")
|
| 237 |
+
return None
|
| 238 |
+
|
| 239 |
+
# Combine description + content for better context
|
| 240 |
+
full_content = description
|
| 241 |
+
if content and content != description:
|
| 242 |
+
# NewsAPI truncates content with [+X chars]
|
| 243 |
+
# We'll use Jina Reader to get full article later
|
| 244 |
+
full_content = f"{description}\n\n{content}"
|
| 245 |
+
|
| 246 |
+
# Calculate freshness score
|
| 247 |
+
freshness_score = self._calculate_freshness(published_at)
|
| 248 |
+
|
| 249 |
+
return {
|
| 250 |
+
"title": title,
|
| 251 |
+
"url": url,
|
| 252 |
+
"content": full_content or title, # Use title if no content
|
| 253 |
+
"snippet": description,
|
| 254 |
+
"source": source_name or self._extract_domain(url),
|
| 255 |
+
"published_at": published_at,
|
| 256 |
+
"author": author,
|
| 257 |
+
"image_url": url_to_image,
|
| 258 |
+
"source_type": "live",
|
| 259 |
+
"is_live": True,
|
| 260 |
+
"freshness_score": freshness_score,
|
| 261 |
+
"language": "en", # NewsAPI returns language in query
|
| 262 |
+
"metadata": {
|
| 263 |
+
"title": title,
|
| 264 |
+
"url": url,
|
| 265 |
+
"source": source_name,
|
| 266 |
+
"published_at": published_at,
|
| 267 |
+
"author": author,
|
| 268 |
+
"search_engine": "newsapi"
|
| 269 |
+
}
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
except Exception as e:
|
| 273 |
+
logger.warning(f"Failed to normalize NewsAPI result: {e}")
|
| 274 |
+
return None
|
| 275 |
+
|
| 276 |
+
def _calculate_freshness(self, published_at: str) -> float:
|
| 277 |
+
"""
|
| 278 |
+
Calculate freshness score based on article age.
|
| 279 |
+
|
| 280 |
+
Args:
|
| 281 |
+
published_at: ISO format date string
|
| 282 |
+
|
| 283 |
+
Returns:
|
| 284 |
+
Freshness score (0.0 to 1.0)
|
| 285 |
+
"""
|
| 286 |
+
if not published_at:
|
| 287 |
+
return 0.8 # Unknown age, assume recent
|
| 288 |
+
|
| 289 |
+
try:
|
| 290 |
+
pub_date = datetime.fromisoformat(published_at.replace('Z', '+00:00'))
|
| 291 |
+
age = datetime.utcnow() - pub_date.replace(tzinfo=None)
|
| 292 |
+
age_minutes = age.total_seconds() / 60
|
| 293 |
+
|
| 294 |
+
# NewsAPI results are very fresh
|
| 295 |
+
if age_minutes < 10:
|
| 296 |
+
return 1.0 # < 10 min
|
| 297 |
+
elif age_minutes < 60:
|
| 298 |
+
return 0.98 # < 1 hour
|
| 299 |
+
elif age_minutes < 360:
|
| 300 |
+
return 0.95 # < 6 hours
|
| 301 |
+
elif age_minutes < 1440:
|
| 302 |
+
return 0.9 # < 24 hours
|
| 303 |
+
else:
|
| 304 |
+
return 0.85 # Older but still from live search
|
| 305 |
+
except:
|
| 306 |
+
return 0.8 # Default to recent
|
| 307 |
+
|
| 308 |
+
def _extract_domain(self, url: str) -> str:
|
| 309 |
+
"""
|
| 310 |
+
Extract domain name from URL.
|
| 311 |
+
|
| 312 |
+
Args:
|
| 313 |
+
url: Full URL
|
| 314 |
+
|
| 315 |
+
Returns:
|
| 316 |
+
Domain name (e.g., "bbc.com")
|
| 317 |
+
"""
|
| 318 |
+
try:
|
| 319 |
+
from urllib.parse import urlparse
|
| 320 |
+
parsed = urlparse(url)
|
| 321 |
+
domain = parsed.netloc
|
| 322 |
+
# Remove www. prefix
|
| 323 |
+
if domain.startswith("www."):
|
| 324 |
+
domain = domain[4:]
|
| 325 |
+
return domain
|
| 326 |
+
except:
|
| 327 |
+
return "unknown"
|
| 328 |
+
|
| 329 |
+
def is_available(self) -> bool:
|
| 330 |
+
"""
|
| 331 |
+
Check if NewsAPI is available.
|
| 332 |
+
|
| 333 |
+
Returns:
|
| 334 |
+
True if API key is configured, False otherwise
|
| 335 |
+
"""
|
| 336 |
+
return self.api_key is not None
|
| 337 |
+
|
| 338 |
+
async def close(self):
|
| 339 |
+
"""Close HTTP client"""
|
| 340 |
+
if self.client:
|
| 341 |
+
await self.client.aclose()
|
| 342 |
+
self.client = None
|
| 343 |
+
logger.debug("NewsAPI client closed")
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
# ═══════════════════════════════════════════════════════════════════════════
|
| 347 |
+
# SINGLETON INSTANCE
|
| 348 |
+
# ═══════════════════════════════════════════════════════════════════════════
|
| 349 |
+
|
| 350 |
+
_default_adapter = None


def get_newsapi_adapter(
    api_key: str,
    timeout: float = 2.0,
    max_results: int = 20
) -> NewsAPIAdapter:
    """
    Return the process-wide NewsAPI adapter, creating it on first call.

    NOTE: the configuration arguments only take effect on the first call;
    subsequent calls return the already-built singleton unchanged.

    Args:
        api_key: NewsAPI.org API key
        timeout: Search timeout in seconds
        max_results: Maximum results to return

    Returns:
        NewsAPIAdapter instance
    """
    global _default_adapter
    if _default_adapter is None:
        _default_adapter = NewsAPIAdapter(
            api_key=api_key, timeout=timeout, max_results=max_results
        )
    return _default_adapter
|
src/infrastructure/adapters/redis_adapter.py
CHANGED
|
@@ -1,20 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import json
|
| 2 |
import logging
|
| 3 |
-
from typing import Optional, Dict, Any
|
| 4 |
-
import redis
|
| 5 |
import hashlib
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
from src.core.ports.cache_port import CachePort
|
| 8 |
from src.core.config import settings
|
| 9 |
|
| 10 |
logger = logging.getLogger(__name__)
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
class RedisAdapter(CachePort):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
def __init__(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
try:
|
| 15 |
-
if hasattr(settings,
|
| 16 |
url = settings.REDIS_URL
|
| 17 |
-
# Upstash requires TLS
|
| 18 |
if url.startswith("redis://") and "upstash.io" in url:
|
| 19 |
url = "rediss://" + url[len("redis://"):]
|
| 20 |
self.client = redis.from_url(url, decode_responses=True)
|
|
@@ -24,43 +63,206 @@ class RedisAdapter(CachePort):
|
|
| 24 |
port=settings.REDIS_PORT,
|
| 25 |
db=settings.REDIS_DB,
|
| 26 |
password=settings.REDIS_PASSWORD or None,
|
| 27 |
-
decode_responses=True
|
| 28 |
)
|
| 29 |
self.client = redis.Redis(connection_pool=pool)
|
| 30 |
self.client.ping()
|
| 31 |
-
logger.info("Connected to Redis cache.")
|
| 32 |
except Exception as e:
|
| 33 |
-
logger.warning(f"
|
| 34 |
self.client = None
|
| 35 |
|
| 36 |
-
|
| 37 |
|
| 38 |
def get(self, key: str) -> Optional[Any]:
|
| 39 |
-
if not self.client:
|
|
|
|
| 40 |
try:
|
| 41 |
data = self.client.get(key)
|
| 42 |
return json.loads(data) if data else None
|
| 43 |
except Exception as e:
|
| 44 |
-
logger.
|
| 45 |
return None
|
| 46 |
|
| 47 |
def set(self, key: str, value: Any, expiration: int = 3600) -> bool:
|
| 48 |
-
if not self.client:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
try:
|
| 50 |
-
self.client.
|
| 51 |
return True
|
| 52 |
except Exception as e:
|
| 53 |
-
logger.
|
| 54 |
return False
|
| 55 |
|
| 56 |
def search_similar(self, query_vector: list, threshold: float = 0.95) -> Optional[Dict[str, Any]]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
"""
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
as an interim caching mechanism until Redis vector extensions are configured.
|
| 61 |
"""
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Redis Cache Adapter — Smart Layered Caching
|
| 3 |
+
|
| 4 |
+
Cache layers with different TTLs:
|
| 5 |
+
Layer 1 — Intent cache : 1 hour (same query = same intent)
|
| 6 |
+
Layer 2 — Live search cache : 10 min (DuckDuckGo/NewsAPI results)
|
| 7 |
+
Layer 3 — Translation cache : 1 hour (LLM translation is expensive)
|
| 8 |
+
Layer 4 — Full response cache: 5 min (complete RAG answer)
|
| 9 |
+
|
| 10 |
+
Key naming convention:
|
| 11 |
+
intent_v2:{query_hash} → IntentResult dict
|
| 12 |
+
live_search:{query_hash} → list of live results
|
| 13 |
+
translation:{query_hash} → translation + expanded query dict
|
| 14 |
+
rag_response:{query_hash} → full RAG response dict
|
| 15 |
+
|
| 16 |
+
All keys use SHA-256 of the normalized query (lowercase, stripped).
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
import json
|
| 20 |
import logging
|
|
|
|
|
|
|
| 21 |
import hashlib
|
| 22 |
+
import time
|
| 23 |
+
from typing import Optional, Dict, Any, List
|
| 24 |
+
|
| 25 |
+
import redis
|
| 26 |
|
| 27 |
from src.core.ports.cache_port import CachePort
|
| 28 |
from src.core.config import settings
|
| 29 |
|
| 30 |
logger = logging.getLogger(__name__)
|
| 31 |
|
| 32 |
+
# ── TTL constants (seconds) ───────────────────────────────────────────────────
|
| 33 |
+
TTL_INTENT = 3600 # 1 hour — intent rarely changes for same query
|
| 34 |
+
TTL_LIVE_SEARCH = 600 # 10 min — live news stays fresh enough
|
| 35 |
+
TTL_TRANSLATION = 3600 # 1 hour — translations don't change
|
| 36 |
+
TTL_RESPONSE = 300 # 5 min — full RAG response (temporal queries need freshness)
|
| 37 |
+
TTL_RESPONSE_HISTORICAL = 1800 # 30 min — historical answers change less often
|
| 38 |
+
|
| 39 |
+
|
| 40 |
class RedisAdapter(CachePort):
|
| 41 |
+
"""
|
| 42 |
+
Redis cache adapter with smart layered caching.
|
| 43 |
+
|
| 44 |
+
Falls back gracefully when Redis is unavailable — all methods
|
| 45 |
+
return None/False instead of raising exceptions.
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
def __init__(self):
    """Create the adapter and attempt a connection immediately."""
    # _connect() leaves self.client as None on failure, so every cache
    # method below degrades to a no-op instead of raising.
    self.client = None
    self._connect()
|
| 51 |
+
|
| 52 |
+
def _connect(self):
|
| 53 |
try:
|
| 54 |
+
if hasattr(settings, "REDIS_URL") and settings.REDIS_URL:
|
| 55 |
url = settings.REDIS_URL
|
| 56 |
+
# Upstash requires TLS
|
| 57 |
if url.startswith("redis://") and "upstash.io" in url:
|
| 58 |
url = "rediss://" + url[len("redis://"):]
|
| 59 |
self.client = redis.from_url(url, decode_responses=True)
|
|
|
|
| 63 |
port=settings.REDIS_PORT,
|
| 64 |
db=settings.REDIS_DB,
|
| 65 |
password=settings.REDIS_PASSWORD or None,
|
| 66 |
+
decode_responses=True,
|
| 67 |
)
|
| 68 |
self.client = redis.Redis(connection_pool=pool)
|
| 69 |
self.client.ping()
|
| 70 |
+
logger.info("✅ Connected to Redis cache.")
|
| 71 |
except Exception as e:
|
| 72 |
+
logger.warning(f"Redis unavailable: {e}. All cache operations will be no-ops.")
|
| 73 |
self.client = None
|
| 74 |
|
| 75 |
+
# ── CachePort interface ───────────────────────────────────────────────────
|
| 76 |
|
| 77 |
def get(self, key: str) -> Optional[Any]:
    """Fetch *key* from Redis and JSON-decode it; None on miss, error, or no client."""
    if not self.client:
        return None
    try:
        raw = self.client.get(key)
        # json.loads stays inside the try: a corrupt value is a cache miss.
        return json.loads(raw) if raw else None
    except Exception as e:
        logger.debug(f"Redis get error for key '{key}': {e}")
        return None
|
| 86 |
|
| 87 |
def set(self, key: str, value: Any, expiration: int = 3600) -> bool:
    """JSON-encode *value* and store it under *key* with a TTL via SETEX."""
    if not self.client:
        return False
    try:
        encoded = json.dumps(value, default=str)
        self.client.setex(key, expiration, encoded)
        return True
    except Exception as e:
        logger.debug(f"Redis set error for key '{key}': {e}")
        return False
|
| 96 |
+
|
| 97 |
+
def delete(self, key: str) -> bool:
    """Remove *key* from the cache; False when Redis is down or errors."""
    if not self.client:
        return False
    try:
        self.client.delete(key)
        return True
    except Exception as e:
        logger.debug(f"Redis delete error for key '{key}': {e}")
        return False
|
| 106 |
|
| 107 |
def search_similar(self, query_vector: list, threshold: float = 0.95) -> Optional[Dict[str, Any]]:
    """Vector similarity search — not implemented (requires RedisSearch module)."""
    # Deliberate no-op: always reports a cache miss. Kept to satisfy the
    # CachePort interface this adapter implements.
    return None
|
| 110 |
+
|
| 111 |
+
# ── Key generation ──────────────────────────────────────────────────────���─
|
| 112 |
+
|
| 113 |
+
def generate_exact_hash(self, text: str) -> str:
    """Return the SHA-256 hex digest of *text* lowercased and stripped."""
    return hashlib.sha256(text.lower().strip().encode("utf-8")).hexdigest()
|
| 117 |
+
|
| 118 |
+
def _make_key(self, prefix: str, query: str) -> str:
|
| 119 |
+
"""Build a namespaced cache key from query text."""
|
| 120 |
+
return f"{prefix}:{self.generate_exact_hash(query)}"
|
| 121 |
+
|
| 122 |
+
# ── Layer 1: Intent cache ─────────────────────────────────────────────────
|
| 123 |
+
|
| 124 |
+
def get_intent(self, query: str) -> Optional[Dict[str, Any]]:
    """
    Look up the cached intent classification for *query*.

    Returns a dict with keys intent/confidence/method, or None on miss.
    """
    cached = self.get(self._make_key("intent_v2", query))
    if not cached:
        return None
    logger.debug(f"[Cache] Intent HIT for '{query[:50]}'")
    return cached
|
| 134 |
+
|
| 135 |
+
def set_intent(self, query: str, intent_data: Dict[str, Any]) -> bool:
    """Store the intent classification for *query* (TTL: 1 hour)."""
    cache_key = self._make_key("intent_v2", query)
    ok = self.set(cache_key, intent_data, expiration=TTL_INTENT)
    if ok:
        logger.debug(f"[Cache] Intent SET for '{query[:50]}' (TTL={TTL_INTENT}s)")
    return ok
|
| 142 |
+
|
| 143 |
+
# ── Layer 2: Live search cache ────────────────────────────────────────────
|
| 144 |
+
|
| 145 |
+
def get_live_search(self, query: str) -> Optional[List[Dict[str, Any]]]:
    """
    Return cached live search results (DuckDuckGo + NewsAPI) for *query*,
    or None when nothing is cached.
    """
    payload = self.get(self._make_key("live_search", query))
    if not payload:
        return None
    stamp = payload.get("_cached_at", 0)
    elapsed = int(time.time()) - stamp if stamp else 0
    logger.info(f"[Cache] Live search HIT for '{query[:50]}' (age={elapsed}s)")
    return payload.get("results", [])
|
| 158 |
+
|
| 159 |
+
def set_live_search(self, query: str, results: List[Dict[str, Any]]) -> bool:
    """Cache live search results for 10 minutes, stamped with store time."""
    cache_key = self._make_key("live_search", query)
    wrapped = {
        "results": results,
        "_cached_at": int(time.time()),
        "_query": query[:100],
    }
    ok = self.set(cache_key, wrapped, expiration=TTL_LIVE_SEARCH)
    if ok:
        logger.info(f"[Cache] Live search SET for '{query[:50]}' ({len(results)} results, TTL={TTL_LIVE_SEARCH}s)")
    return ok
|
| 171 |
+
|
| 172 |
+
# ── Layer 3: Translation cache ────────────────────────────────────────────
|
| 173 |
+
|
| 174 |
+
def get_translation(self, query: str) -> Optional[Dict[str, Any]]:
    """
    Look up the cached translation + query expansion for *query*.

    Returns a dict (expanded_query, translations, days_back, …) or None on miss.
    """
    cached = self.get(self._make_key("translation", query))
    if not cached:
        return None
    logger.debug(f"[Cache] Translation HIT for '{query[:50]}'")
    return cached
|
| 184 |
+
|
| 185 |
+
def set_translation(self, query: str, translation_data: Dict[str, Any]) -> bool:
    """Store the translation/expansion result for *query* (TTL: 1 hour)."""
    cache_key = self._make_key("translation", query)
    ok = self.set(cache_key, translation_data, expiration=TTL_TRANSLATION)
    if ok:
        logger.debug(f"[Cache] Translation SET for '{query[:50]}' (TTL={TTL_TRANSLATION}s)")
    return ok
|
| 192 |
+
|
| 193 |
+
# ── Layer 4: Full response cache ──────────────────────────────────────────
|
| 194 |
+
|
| 195 |
+
def get_response(self, query: str) -> Optional[Dict[str, Any]]:
    """
    Look up the cached full RAG response for *query*.

    Returns the complete response dict (including _cached_at/_intent
    bookkeeping keys) or None on miss.
    """
    cached = self.get(self._make_key("rag_response", query))
    if not cached:
        return None
    stamp = cached.get("_cached_at", 0)
    elapsed = int(time.time()) - stamp if stamp else 0
    logger.info(f"[Cache] Response HIT for '{query[:50]}' (age={elapsed}s)")
    return cached
|
| 207 |
+
|
| 208 |
+
def set_response(
    self,
    query: str,
    response: Dict[str, Any],
    intent: str = "NEWS_GENERAL"
) -> bool:
    """
    Cache a full RAG response, choosing the TTL from the query intent:
    NEWS_HISTORICAL answers are stable (30 min); everything else — temporal
    and general news — stays short-lived (5 min).
    """
    cache_key = self._make_key("rag_response", query)
    ttl = TTL_RESPONSE_HISTORICAL if intent == "NEWS_HISTORICAL" else TTL_RESPONSE
    wrapped = dict(response)
    wrapped["_cached_at"] = int(time.time())
    wrapped["_intent"] = intent
    ok = self.set(cache_key, wrapped, expiration=ttl)
    if ok:
        logger.info(f"[Cache] Response SET for '{query[:50]}' (intent={intent}, TTL={ttl}s)")
    return ok
|
| 232 |
+
|
| 233 |
+
# ── Cache stats ───────────────────────────────────────────────────────────
|
| 234 |
+
|
| 235 |
+
def get_stats(self) -> Dict[str, Any]:
    """Return aggregate cache statistics from Redis INFO sections."""
    if not self.client:
        return {"status": "disconnected"}
    try:
        stats = self.client.info("stats")
        keyspace = self.client.info("keyspace")
        hits = stats.get("keyspace_hits", 0)
        misses = stats.get("keyspace_misses", 0)
        key_count = sum(
            db.get("keys", 0) for db in keyspace.values() if isinstance(db, dict)
        )
        return {
            "status": "connected",
            "hits": hits,
            "misses": misses,
            "hit_rate": round(hits / max(1, hits + misses) * 100, 1),
            "total_keys": key_count,
            "memory_used": self.client.info("memory").get("used_memory_human", "?"),
        }
    except Exception as e:
        return {"status": "error", "error": str(e)}
|
| 259 |
+
|
| 260 |
+
def is_available(self) -> bool:
    """Return True when a live Redis connection answers PING."""
    if not self.client:
        return False
    try:
        self.client.ping()
    except Exception:
        return False
    return True
|