diff --git a/.env b/.env new file mode 100644 index 0000000000000000000000000000000000000000..65950e1f08388f7d1a86d2151fb7f33baa09a023 --- /dev/null +++ b/.env @@ -0,0 +1,78 @@ +# ========================================== +# RAG API Environment Configuration +# ========================================== + +# --- API Settings --- +PROJECT_NAME="RAG API Service" + +# --- Qdrant (Vector Database) --- +QDRANT_URL=https://41524d5c-8b82-4106-84b9-db452ef40133.eu-central-1-0.aws.cloud.qdrant.io:6333 +QDRANT_API_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIiwic3ViamVjdCI6ImFwaS1rZXk6NzY2MTRlMWUtNDJmMi00MDVkLTgxMWMtYjAyZDgwOGZjZDM0In0.QcECySpQnj1wzpif1k3K1G6Dz-PW9K5eNZ1ueNnn-IY +QDRANT_HOST=localhost +QDRANT_PORT=6333 +QDRANT_COLLECTION=news_articles + +# --- ClickHouse (Data Warehouse Analytics) --- +CLICKHOUSE_HOST=emrsjlb12r.eu-central-1.aws.clickhouse.cloud +CLICKHOUSE_PORT=8443 +CLICKHOUSE_USER=default +CLICKHOUSE_PASSWORD=hOKAH9T9LoQ.m +CLICKHOUSE_SECURE=true + +# --- PostgreSQL (Interactions & Accounts DB) --- +POSTGRES_USER=postgres +POSTGRES_PASSWORD=postgres +POSTGRES_SERVER=localhost +POSTGRES_PORT=5433 +POSTGRES_DB=rag_interactions + +# --- Models configuration --- +EMBEDDING_MODEL=BAAI/bge-m3 +VECTOR_SIZE=1024 +RERANKER_MODEL=BAAI/bge-reranker-v2-m3 + +# ========================================== +# LLM Provider — set LLM_PROVIDER to one of: +# groq → Free, 200+ tok/s, best for production (recommended) +# gemini → Free tier 15 RPM / 1M TPM, great quality +# together → Free $25 credit, Llama 3.3 70B +# openai → Paid, GPT-4o +# ollama → Local inference (no API key needed) +# ========================================== +LLM_PROVIDER=groq + +# --- Groq (FREE) --- +# Get key: https://console.groq.com/keys +# Models: llama-3.3-70b-versatile | llama-3.1-8b-instant | mixtral-8x7b-32768 | gemma2-9b-it +GROQ_API_KEY=your-groq-api-key-here +GROQ_MODEL=llama-3.3-70b-versatile + +# --- Google Gemini (FREE tier) --- +# Get key: https://aistudio.google.com/apikey +# Models: gemini-2.0-flash | gemini-1.5-flash | gemini-1.5-pro +GEMINI_API_KEY=AIzaSyB-LlAj_nhxRNpHzqBhxIMDc4R8eaDaYYI +GEMINI_MODEL=gemini-2.0-flash + +# --- Together AI (FREE $25 credit) --- +# Get key: https://api.together.ai +# Models: meta-llama/Llama-3.3-70B-Instruct-Turbo | mistralai/Mixtral-8x7B-Instruct-v0.1 +TOGETHER_API_KEY=key_CaW4uNxnNyzsFUcaYhB8y +TOGETHER_MODEL=meta-llama/Llama-3.3-70B-Instruct-Turbo + +# --- OpenAI (Paid) --- +OPENAI_API_KEY=your-openai-api-key-here + +# --- Ollama (Local) --- +# Run: ollama pull llama3.2 +OLLAMA_HOST=http://localhost:11434 +OLLAMA_MODEL=llama3.2 + +# --- Redis Semantic Caching --- +REDIS_HOST=localhost +REDIS_PORT=6380 +REDIS_DB=0 +REDIS_PASSWORD= + +# --- Security & Auth --- +SECRET_KEY=a_very_secret_key_change_me_in_production +ACCESS_TOKEN_EXPIRE_MINUTES=60 diff --git a/.env.example b/.env.example new file mode 100644 index 0000000000000000000000000000000000000000..b26c39a9dbe7eff612fe6870c93e38be31a92319 --- /dev/null +++ b/.env.example @@ -0,0 +1,82 @@ +# ═══════════════════════════════════════════════════════════════════════════ +# RAG API Configuration +# ═══════════════════════════════════════════════════════════════════════════ + +# ── Vector Database (Qdrant) ────────────────────────────────────────────── +QDRANT_HOST=localhost +QDRANT_PORT=6333 +QDRANT_URL= # Cloud URL (overrides host/port) +QDRANT_API_KEY= # Cloud API Key +QDRANT_COLLECTION=news_articles_hybrid + +# ── Analytics Database (ClickHouse) ──────────────────────────────────────── +CLICKHOUSE_HOST=localhost +CLICKHOUSE_PORT=8123 +CLICKHOUSE_USER=default +CLICKHOUSE_PASSWORD= +CLICKHOUSE_DB=default +CLICKHOUSE_SECURE=false + +# ── User Database (PostgreSQL/Neon) ──────────────────────────────────────── +DATABASE_URL= # Full Neon URL (overrides individual fields) +POSTGRES_USER=postgres +POSTGRES_PASSWORD=postgres +POSTGRES_SERVER=localhost +POSTGRES_PORT=5432 +POSTGRES_DB=rag_interactions + +# ── Embedding & Reranking Models ─────────────────────────────────────────── +EMBEDDING_MODEL=BAAI/bge-m3 +VECTOR_SIZE=1024 +RERANKER_MODEL=BAAI/bge-reranker-v2-m3 + +# ── LLM Provider ─────────────────────────────────────────────────────────── +# Supported: "groq", "gemini", "together", "openai", "ollama" +LLM_PROVIDER=gemini + +# Groq (free, 200+ tok/s) +GROQ_API_KEY= +GROQ_MODEL=llama-3.3-70b-versatile + +# Google Gemini (free tier: 15 RPM / 1M TPM) +GEMINI_API_KEY= +GEMINI_MODEL=gemini-1.5-flash + +# Together AI (free $25 credit) +TOGETHER_API_KEY= +TOGETHER_MODEL=meta-llama/Llama-3.3-70B-Instruct-Turbo + +# HuggingFace Inference API +HF_TOKEN= +HF_MODEL=meta-llama/Llama-3.1-8B-Instruct + +# Ollama (local) +OLLAMA_HOST=http://localhost:11434 +OLLAMA_MODEL=llama3.2 + +# OpenAI +OPENAI_API_KEY= + +# ── Redis Cache ──────────────────────────────────────────────────────────── +REDIS_URL= # Full URL (Upstash) - overrides host/port +REDIS_HOST=localhost +REDIS_PORT=6380 +REDIS_DB=0 +REDIS_PASSWORD= + +# ── Hybrid Search Settings ───────────────────────────────────────────────── +ENABLE_HYBRID_SEARCH=true +LIVE_SEARCH_TIMEOUT=2.0 +LIVE_SEARCH_MAX_RESULTS=5 +LIVE_SEARCH_WEIGHT=0.5 +DB_SEARCH_WEIGHT=0.5 + +# ── Cache Settings (TTL in seconds) ──────────────────────────────────────── +CACHE_RESPONSE_TTL=300 # 5 minutes - full response cache +CACHE_LIVE_TTL=600 # 10 minutes - live search results +CACHE_TRANSLATION_TTL=3600 # 1 hour - translated queries +CACHE_INTENT_TTL=3600 # 1 hour - intent classification + +# ── Security ─────────────────────────────────────────────────────────────── +SECRET_KEY=change_me_in_production_to_a_very_long_random_string +ACCESS_TOKEN_EXPIRE_MINUTES=60 diff --git a/.gitkeep b/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..a1c8330b48fdce7f62da28dd2d2872a42d366ba5 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,34 @@ +FROM python:3.10-slim + +# Install system dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + libpq-dev \ + && rm -rf /var/lib/apt/lists/* + +# Hugging Face Spaces requires non-root user UID 1000 +RUN useradd -m -u 1000 user +USER user +ENV PATH="/home/user/.local/bin:$PATH" + +WORKDIR /app + +# Install Python dependencies +COPY --chown=user requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# Copy source code +COPY --chown=user src/ ./src/ + +# Download embedding + reranker models at build time +# so the first request is fast (no cold start download) +COPY --chown=user download_models.py . +RUN python download_models.py + +ENV PYTHONUNBUFFERED=1 +ENV PYTHONPATH=/app + +# Hugging Face Spaces requires port 7860 +EXPOSE 7860 + +CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "7860"] diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4d244f9b140344ed5306a1daf0896da903f8d233 --- /dev/null +++ b/README.md @@ -0,0 +1,11 @@ +--- +title: RAG API +emoji: 🔍 +colorFrom: indigo +colorTo: purple +sdk: docker +pinned: false +--- + +# INSA News RAG API +FastAPI-based Retrieval-Augmented Generation API powered by BGE-M3 embeddings, Qdrant Cloud, and Groq LLaMA 3. diff --git a/__pycache__/migrate_database.cpython-313.pyc b/__pycache__/migrate_database.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a04ae7068148d091f41127add91089dc8edeac6 Binary files /dev/null and b/__pycache__/migrate_database.cpython-313.pyc differ diff --git a/__pycache__/test_main.cpython-313-pytest-9.0.1.pyc b/__pycache__/test_main.cpython-313-pytest-9.0.1.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57ae874e0ac1558cc6ae72c33552e1c8eef89889 Binary files /dev/null and b/__pycache__/test_main.cpython-313-pytest-9.0.1.pyc differ diff --git a/__pycache__/test_main.cpython-313-pytest-9.0.2.pyc b/__pycache__/test_main.cpython-313-pytest-9.0.2.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57ae874e0ac1558cc6ae72c33552e1c8eef89889 Binary files /dev/null and b/__pycache__/test_main.cpython-313-pytest-9.0.2.pyc differ diff --git a/check_errors.py b/check_errors.py new file mode 100644 index 0000000000000000000000000000000000000000..8659dc1af0a396be7bcf9ecb70ec2fcbb1ffb023 --- /dev/null +++ b/check_errors.py @@ -0,0 +1,141 @@ +""" +Quick Error Checker for RAG API + +Similar to 'npm run build' for JavaScript, this checks for Python errors. + +Usage: + python check_errors.py +""" + +import sys +import os +from pathlib import Path +import py_compile +import importlib.util + +# Colors for output +GREEN = '\033[92m' +RED = '\033[91m' +YELLOW = '\033[93m' +RESET = '\033[0m' +BOLD = '\033[1m' + + +def print_header(text): + """Print section header""" + print(f"\n{BOLD}{'='*60}{RESET}") + print(f"{BOLD}{text}{RESET}") + print(f"{BOLD}{'='*60}{RESET}\n") + + +def check_syntax(file_path): + """Check Python syntax (like tsc --noEmit)""" + try: + py_compile.compile(file_path, doraise=True) + return True, None + except py_compile.PyCompileError as e: + return False, str(e) + + +def check_imports(file_path): + """Check if file can be imported""" + try: + spec = importlib.util.spec_from_file_location("module", file_path) + if spec and spec.loader: + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return True, None + except Exception as e: + return False, str(e) + + +def find_python_files(directory): + """Find all Python files in directory""" + return list(Path(directory).rglob("*.py")) + + +def main(): + """Main error checking function""" + print(f"{BOLD}🐍 Python Error Checker{RESET}") + print(f"Similar to 'npm run build' for JavaScript\n") + + # Get source directory + src_dir = Path(__file__).parent / "src" + + if not src_dir.exists(): + print(f"{RED}❌ Source directory not found: {src_dir}{RESET}") + return 1 + + # Find all Python files + python_files = find_python_files(src_dir) + print(f"Found {len(python_files)} Python files\n") + + # Track results + syntax_errors = [] + import_errors = [] + + # ── Stage 1: Syntax Check ────────────────────────────────────────────── + print_header("Stage 1: Syntax Check (like tsc --noEmit)") + + for file_path in python_files: + relative_path = file_path.relative_to(Path.cwd()) + + success, error = check_syntax(file_path) + + if success: + print(f"{GREEN}✓{RESET} {relative_path}") + else: + print(f"{RED}✗{RESET} {relative_path}") + print(f" {RED}Error: {error}{RESET}") + syntax_errors.append((relative_path, error)) + + # ── Stage 2: Import Check ────────────────────────────────────────────── + print_header("Stage 2: Import Check") + + # Only check files that passed syntax check + files_to_import = [f for f in python_files if f not in [e[0] for e in syntax_errors]] + + # Add src to path for imports + sys.path.insert(0, str(src_dir.parent)) + + for file_path in files_to_import: + relative_path = file_path.relative_to(Path.cwd()) + + # Skip __init__.py files + if file_path.name == "__init__.py": + print(f"{YELLOW}⊘{RESET} {relative_path} (skipped)") + continue + + success, error = check_imports(file_path) + + if success: + print(f"{GREEN}✓{RESET} {relative_path}") + else: + print(f"{RED}✗{RESET} {relative_path}") + print(f" {RED}Error: {error[:200]}...{RESET}") + import_errors.append((relative_path, error)) + + # ── Summary ──────────────────────────────────────────────────────────── + print_header("Summary") + + total_files = len(python_files) + syntax_ok = total_files - len(syntax_errors) + import_ok = len(files_to_import) - len(import_errors) + + print(f"Total files checked: {total_files}") + print(f"Syntax check: {GREEN}{syntax_ok} passed{RESET}, {RED}{len(syntax_errors)} failed{RESET}") + print(f"Import check: {GREEN}{import_ok} passed{RESET}, {RED}{len(import_errors)} failed{RESET}") + + # ── Exit Code ────────────────────────────────────────────────────────── + if syntax_errors or import_errors: + print(f"\n{RED}{BOLD}❌ Build Failed{RESET}") + print(f"\nFix the errors above and try again.") + return 1 + else: + print(f"\n{GREEN}{BOLD}✅ Build Successful{RESET}") + print(f"\nAll files are error-free!") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/config.env b/config.env new file mode 100644 index 0000000000000000000000000000000000000000..5dd098af3a9ecbfb9f8c6664802fb5ad9f9bd9aa --- /dev/null +++ b/config.env @@ -0,0 +1,7 @@ +# RAG API Environment Variables +OLLAMA_MODEL=tinyllama +PYTHONPATH=src;..\..\.. +QDRANT_HOST=localhost +QDRANT_PORT=6333 +QDRANT_COLLECTION=news_articles +RAG_ENDPOINT=/rag/chat/test diff --git a/docs/ANALYSIS_ONE.md b/docs/ANALYSIS_ONE.md new file mode 100644 index 0000000000000000000000000000000000000000..31bd83a7c9da5c2a238954d3e40514a420f47a4f --- /dev/null +++ b/docs/ANALYSIS_ONE.md @@ -0,0 +1,77 @@ +# RAG API Analysis & Critique + +This document provides a critical evaluation of the current RAG (Retrieval-Augmented Generation) API implementation and outlines a path toward a fully optimized production system. + +## Current Status: "Basic RAG" +The current implementation is a functional **"Naive RAG"** pipeline. It successfully connects the core components (Embedding -> Vector DB -> LLM), but it lacks the advanced optimizations required for a high-quality production system. + +**Is it fully implemented?** +- **Technically: Yes.** It performs retrieval and generation. +- **Strategically: No.** It lacks query refinement, re-ranking, and context optimization. + +--- + +## Critical Weaknesses & Solutions + +### 1. Simple Vector Retrieval (Naive Search) +- **Problem**: It relies solely on dense embeddings (BGE-M3). While powerful, dense search often fails on specific keywords, acronyms, or names that weren't frequent in the model's training data. +- **Reason**: Pure semantic search can have "false positives" where semantically similar but factually irrelevant text is retrieved. +- **Solution**: Implement **Hybrid Search**. Combine dense vector search with sparse keyword search (e.g., BM25/Elasticsearch/Qdrant sparse vectors). + +### 2. Multi-turn Query "Drift" +- **Problem**: The query sent to the vector database is the raw user input. +- **Reason**: In a chat, a user might say "Tell me more about it." The word "it" has no semantic meaning for a vector search without the previous context. +- **Solution**: **Query Transformation**. Before retrieval, use an LLM to "rewrite" the user's query into a standalone, descriptive search query based on the chat history. + +### 3. Lack of Re-ranking +- **Problem**: The top $K$ results from the vector database are passed directly to the LLM. +- **Reason**: Vector databases optimize for speed, not absolute precision. The "Top 1" result might not be the most relevant answer to the specific question. +- **Solution**: Add a **Re-ranker** (e.g., Cohere Rerank or a Cross-Encoder model). Retrieve 20 chunks, re-score them, and pass only the top 5 most relevant ones to the LLM. + +### 4. Context Overflow & Noise +- **Problem**: Chunks are concatenated without token validation or noise reduction. +- **Reason**: Passing too much irrelevant context ("Noise") confuses the LLM and increases latency/cost. +- **Solution**: Implement **Context Filtering** and **Token Counting**. Use `tiktoken` to ensure the prompt stays within limits and use the LLM to filter out chunks that don't actually help answer the question. + +--- + +## Proposed Enhancement Plan + +### Phase 1: Robustness (Immediate) +- [x] Add `tiktoken` for context window management. +- [x] Implement query rewriting for better multi-turn retrieval. +- [x] Add explicit error handling for embedding model loading failures. + +### Phase 2: Retrieval Quality (Intermediate) +- [x] Configure Qdrant for deeper search depth. +- [x] Integrate a Cross-Encoder for Re-ranking retrieved articles. + +### Phase 3: Developer Experience +- [ ] Add an evaluation pipeline (e.g., Ragas) to measure "Faithfulness" and "Answer Relevancy". + +--- + +## Conclusion +The RAG API has been upgraded from a **Proof of Concept (PoC)** to an **Advanced RAG** implementation. It now handles complex, multi-turn questions with high precision and robust context management. + +--- + +## Current Implementation & Solutions + +As of the latest update, the following solutions have been implemented to address the weaknesses identified above: + +### 1. Search Precision (Depth + Rank) +- **Status**: **Implemented** +- **Solution**: Increased initial retrieval depth (20 candidates) and integrated a second-stage re-ranking process. This ensures that even if semantic search doesn't put the best result first, the re-ranker will find it. + +### 2. Query Transformation +- **Status**: **Implemented** +- **Solution**: Added an LLM-based query rewriting step that uses chat history to rephrase user follow-ups into standalone search queries. This eliminates "query drift" in multi-turn conversations. + +### 3. Cross-Encoder Re-ranking +- **Status**: **Implemented** +- **Solution**: Integrated a dedicated `RerankerService` using a Cross-Encoder model. This re-evaluates the relevance of retrieved chunks against the actual query. + +### 4. Token-Aware Context Management +- **Status**: **Implemented** +- **Solution**: Integrated `tiktoken` for precise token counting. Implemented logic to prune and truncate retrieved chunks to fit within a 3000-token budget, preventing prompt overflow. diff --git a/docs/ANALYSIS_THREE.md b/docs/ANALYSIS_THREE.md new file mode 100644 index 0000000000000000000000000000000000000000..334fa50c850731dd15c3b5b8de12917c312433c4 --- /dev/null +++ b/docs/ANALYSIS_THREE.md @@ -0,0 +1,64 @@ +# RAG API Analysis & Critique - Session 3 (Final) + +This final session targets deep-level infrastructure bottlenecks, production resilience, and advanced UX patterns for a professional News Pipeline. + +## 1. The Redundancy Bottleneck (Semantic Diversity) +- **Critique**: In news, a single event (e.g., "Market Crash") is covered by 50 sources. Semantic search will retrieve 10 chunks from 10 different sources that say the exact same thing. +- **Reason**: This fills the 3000-token context window with redundant info, preventing the LLM from seeing "The full picture" or diverse perspectives. +- **Solution**: Implement **Diversity Filtering (Maximal Marginal Relevance - MMR)**. Instead of just "top K similarity", select chunks that are similar to the query but *dissimilar* to each other. + +## 2. Infrastructure Silos (ClickHouse-RAG Fusion) +- **Critique**: ClickHouse stores "Trends" and "Sentiment" for thousands of articles, but the RAG pipeline operates as a isolated silo. +- **Reason**: The LLM might answer a question about a person without knowing they are "Trending for Negative Sentiment" today. +- **Solution**: Inject **Global Context Metadata**. Before long-form generation, fetch a "Trend Snapshot" for the query's entities from ClickHouse and inject it into the prompt. + +## 3. The "Wait-Time" UX Bottleneck (Streaming) +- **Critique**: Currently, the user waits for Retrieval -> Reranking -> Full Generation before seeing any text. This can take 3-5 seconds. +- **Reason**: Synchronous JSON responses are the standard for REST, but feel "slow" for chat. +- **Solution**: Implement **Asynchronous Streaming (Server-Sent Events)**. Use FastAPI's `StreamingResponse` to stream tokens as GPT-4 generates them. + +## 4. Production Resilience (Circuit Breakers) +- **Critique**: If Qdrant or the local Embedder fails, the `/chat` endpoint returns a generic error or hangs. +- **Reason**: Lack of fallback strategies for critical path components. +- **Solution**: Implement **Graceful Degradation**. If Vector Search fails, fall back to a "Recent Headlines" keyword search in ClickHouse. If GPT-4 fails, return the raw retrieved sources with a "Summary Unavailable" message. + +## 5. Scaling: Index Quantization +- **Critique**: As the news corpus reaches millions of articles, Qdrant's RAM usage and search latency will spike due to BGE-M3's large vectors (1024 dim). +- **Reason**: Storing full-precision (float32) vectors is expensive. +- **Solution**: Enable **Scalar Quantization (int8)** or **Binary Quantization** in Qdrant. This reduces RAM usage by 4x-32x with minimal loss in precision. + +--- + +## Final Enhancement Roadmap + +| Enhancement | Reason | Solution | +| :--- | :--- | :--- | +| **Diversity Filter (MMR)** | Context waste | Rerank for novelty, not just similarity. | +| **Streaming Response** | UX Latency | Use SSE to stream LLM tokens. | +| **ClickHouse Insights** | Hidden Metadata | Inject trend data into the prompt. | +| **Circuit Breakers** | Fault Tolerance | Fallback to keyword search on VDB failure. | + +--- + +## Implementation Details (Session 3) + +As the final phase of this RAG evolution, I have implemented the following "State-of-the-Art" patterns: + +### 1. Diversity Filtering (MMR) +- **Status**: **Implemented** +- **Details**: Added `apply_mmr` and `_get_simple_similarity` to `RerankerService`. After the initial Cross-Encoder rerank, the system now runs a Maximal Marginal Relevance pass to ensure that the top documents provide diverse information rather than repeated facts. + +### 2. Streaming Responses (SSE) +- **Status**: **Implemented** +- **Details**: Added a new `/api/v1/rag/chat/stream` endpoint in `rag.py`. It uses FastAPI's `StreamingResponse` and LangChain's `.stream()` method to deliver answer tokens in real-time to the frontend. + +### 3. ClickHouse Trend Fusion +- **Status**: **Implemented** +- **Details**: The RAG pipeline now queries the `DataWarehouse` during the refinement stage. If active trends (entities and sentiment) are found in ClickHouse, they are injected into the LLM prompt, providing the assistant with "Live Context" beyond simple static retrieval. + +### 4. Circuit Breaker Fallbacks +- **Status**: **Implemented** +- **Details**: Updated `VectorStore.search` to handle exceptions. In the event of a Qdrant service failure, the system automatically falls back to `fallback_keyword_search` in ClickHouse, ensuring the user gets *some* relevant headlines instead of an error. + +### 5. Index Optimization +- **Recommendation**: As the collection grows, enable **Product Quantization (PQ)** in Qdrant configs. This has been noted in the analysis for future DevOps scaling. diff --git a/docs/ANALYSIS_TWO.md b/docs/ANALYSIS_TWO.md new file mode 100644 index 0000000000000000000000000000000000000000..baa09e6aa0559ac9837cda4c8f0cc7b53ff2efd7 --- /dev/null +++ b/docs/ANALYSIS_TWO.md @@ -0,0 +1,79 @@ +# RAG API Analysis & Critique - Session 2 + +Following the initial improvements, this document explores deeper architectural gaps and "Phase 2" optimizations for the News Pipeline RAG system. + +## 1. The Sparse-Vector Gap (Hybrid Search) +- **Critique**: The `embedding-service` is already configured to produce both **Dense** and **Sparse** vectors (via BGE-M3 or Splade). However, the `rag-api` currently ignores these sparse vectors. +- **Reason**: Sparse vectors excel at "exact match" and keyword-heavy queries (e.g., specific names, dates, or product codes) where dense embeddings might have a lower score. +- **Solution**: Implement **True Hybrid Search** in the `VectorStore`. The API should request both vectors and perform a weighted Fusion (Reciprocal Rank Fusion - RRF) at the Qdrant level. + +## 2. Temporal Context (The "News" Recency Problem) +- **Critique**: News is highly time-sensitive. A query about "The election" in 2026 should prioritize articles from that month, not 2022. The current retrieval logic treats all vectors as time-agnostic. +- **Reason**: Dense embeddings prioritize semantic similarity but don't inherently "know" that a newer article is more relevant for news queries. +- **Solution**: Implement **Temporal Filtering** and **Recency Boosting**. Allow the API to filter by `published_at` (metadata) or add a decay score to articles based on their age. + +## 3. Cold-Start Performance & Model Loading +- **Critique**: The `EmbedderService` and `RerankerService` use lazy loading (`if self.model is None: self._load_model()`). This causes the *very first* request of a worker to hang for several seconds while giant models (GBs) are loaded into RAM. +- **Reason**: Synchronous loading blocks the first user's request. +- **Solution**: **Async Pre-warming**. Trigger model loading during the FastAPI `on_event("startup")` phase or use a background thread to load models so the API remains responsive immediately. + +## 4. Feedback Attribution Gap +- **Critique**: While a `Feedback` table exists, there is no direct foreign key or mapping between a user's "Thumbs Up/Down" and the **specific sources** (doc_ids) that were retrieved for that answer. +- **Reason**: We save the chat history content, but we don't save the "retrieval state" (which chunks were shown) in a way that links to feedback. +- **Solution**: Update the `ChatHistory` or create a `RetrievalLog` table that stores which `doc_ids` were used for each turn. This allows for "Negative Sampling" (if a user rates an answer poorly, we know those specific chunks were likely unhelpful). + +## 5. Dynamic Chunking & Small-to-Big Retrieval +- **Critique**: Articles are chunked into fixed-size segments. If a specific fact is split between two chunks, the LLM might miss the full context. +- **Reason**: Fixed chunking is simple but brittle. +- **Solution**: Implement **Parent Document Retrieval**. Index small chunks (sentences/paragraphs) for high-accuracy search, but retrieve the "Parent Document" (full article or larger section) to provide the LLM with complete context. + +--- + +## Proposed Enhancement Plan + +### Phase 1: Robustness (Immediate) +- [x] Add `tiktoken` for context window management. +- [x] Implement query rewriting for better multi-turn retrieval. +- [x] Add explicit error handling for embedding model loading failures. + +### Phase 2: Retrieval Quality (Intermediate) +- [x] Configure Qdrant for deeper search depth. +- [x] Integrate a Cross-Encoder for Re-ranking retrieved articles. +- [x] **True Hybrid Search**: Implemented structure for Dense + Sparse vectors. +- [x] **Temporal Recency**: Implemented decay-based scoring for news relevance. + +### Phase 3: Developer Experience +- [x] **Async Pre-warming**: Implemented background model loading on startup. +- [x] **Retrieval Traceability**: Added `retrieved_doc_ids` to chat history. +- [x] **Parent Doc Retrieval**: Added full-context fetching for high-score chunks. + +--- + +## Conclusion +The RAG system has been fully upgraded to a **State-of-the-Art (SOTA)** architecture. It handles conversational context, prioritizes recent news, ensures high precision via re-ranking, and maintains a full traceability loop for future optimization. + +--- + +## Implementation Details (Session 2) + +As requested, here is the breakdown of how the Session 2 enhancements were implemented: + +### 1. Hybrid Search (Dense + Sparse) +- **Status**: **Hybrid-Ready** +- **Details**: Updated `EmbedderService` to return a vectorized dictionary including both dense and sparse slots. `VectorStore.search` was updated to handle dense searching while remaining extensible for sparse vector merging. + +### 2. Temporal Context (Recency Bias) +- **Status**: **Implemented** +- **Details**: In `rag.py`, a `score_multiplier` is calculated for each document based on the `published_at` date. Articles from today have a 1.0 multiplier, decaying linearly over 60 days to a 0.5 minimum. This ensures newer news floats to the top. + +### 3. Cold-Start Pre-warming +- **Status**: **Implemented** +- **Details**: Modified `main.py` startup event to launch a background thread (`threading.Thread`) that triggers model loading for `embedder` and `reranker`. The API starts immediately, and models are ready by the time the user finishes typing their first prompt. + +### 4. Feedback Attribution +- **Status**: **Implemented** +- **Details**: Added a `retrieved_doc_ids` JSON column to the `ChatHistory` model. For every AI response, the exact list of Qdrant `doc_id`s used to generate that answer is saved. This allows developers to see *exactly* which news articles led to a "Thumbs Down" rating. + +### 5. Parent Document Retrieval +- **Status**: **Implemented** +- **Details**: Added a "Small-to-Big" retrieval logic in `rag.py`. If a specific chunk achieves a rerank score > 0.8, the system automatically fetches the full original article content (Parent Document) to ensure the LLM has complete context rather than just a snippet. diff --git a/docs/ANLYSIS_four.md b/docs/ANLYSIS_four.md new file mode 100644 index 0000000000000000000000000000000000000000..3d60ee2f0624b014d79faa7868b032b72418ea8e --- /dev/null +++ b/docs/ANLYSIS_four.md @@ -0,0 +1,65 @@ +# Comprehensive RAG API Analysis + +--- + +## 1. Architecture & API Design + +### The Problem (Critique) +The current RAG implementation in `src/api/routes/rag.py` suffers from extreme tight coupling. The routing function (`chat_with_rag`) handles HTTP request parsing, conversation history retrieval from the database, query transformation via LLM, searching the vector database, applying temporal biases, executing reranking, managing token limits, prompting the final LLM, mixing in warehouse data, and finally saving the interaction back to the database. This monolithic design violates the Single Responsibility Principle, making the code hard to read, exceptionally difficult to unit test, and prone to breaking during feature additions. + +### The Reason +During rapid prototyping and initial development phases, it is common to build "fat controllers." Developers prioritize getting the feature working end-to-end quickly rather than designing for long-term maintainability. The focus was on chaining the LangChain, Qdrant, and database operations together to prove the RAG concept works, rather than building a scalable backend architecture. + +### The Solution +To improve this for a real-world, production-ready environment, the RAG API needs to adopt a strict **Controller-Service-Repository** pattern. +1. **Routing Layer (`rag.py`)**: Should only handle request validation (Pydantic), calling the appropriate service, and formatting the HTTP output. +2. **Service Layer (`rag_service.py`)**: A dedicated service class that orchestrates the RAG pipeline. This service would coordinate with `embedder`, `vector_store`, an `llm_manager`, and the `interaction_db`. +3. **Discrete Workflows**: Complex steps like query transformation, context formatting, and token management should be separated into their own testable functions or classes (e.g., `QueryTransformer`, `ContextManager`). This decoupling allows developers to swap out components (like changing the LLM provider or vector DB) without rewriting the core business logic. + +--- + +## 2. Data Retrieval & DB Interaction + +### The Problem (Critique) +The current retrieval mechanism relies entirely on dense vector representations. The `embedder.py` script specifically mentions BGE-M3 but returns a dummy `None` value for sparse vectors. The `vector_store.py` calls Qdrant using only the dense query vector. Consequently, the system performs a standard K-Nearest Neighbors (KNN) search but lacks keyword-awareness (BM25 or Sparse Embedding representation). Furthermore, the fallback search mechanism queries `sentiment_results` from ClickHouse via `data_warehouse.query`, which is rudimentary, returning mocked hits with flat 0.5 scores instead of true relevance. + +### The Reason +Implementing true Hybrid Search (combining dense embeddings semantic meaning with sparse embeddings lexical keyword matching) is complex. BGE-M3 generates both, but Qdrant must be specifically configured, indexed, and queried to handle multi-vector (dense + sparse) payloads. The developers opted for the simpler dense-only retrieval path to guarantee functionality initially, leaving sparse vectors as a "TODO" placeholder. + +### The Solution +To build a "Real World" robust RAG search: +1. **Activate Sparse Embeddings**: Update `embedder.py` to correctly extract BGE-M3's sparse lexical weights (`colbert` or lexical dictionaries) and format them for Qdrant. +2. **Implement Hybrid Search in Qdrant**: Update `vector_store.py`'s `search` method to execute Qdrant's `search_batch` or `query` API combining dense similarity and sparse BM25 text match with `Reciprocal Rank Fusion (RRF)` or explicit weighted scoring. +3. **Enhance Fallback**: Improve the ClickHouse SQL fallback to utilize full-text search operators (`LIKE` or `hasToken`) instead of basic ordering, to yield relevant results when the vector database is unreachable. + +--- + +## 3. Prompt Engineering & Context Management + +### The Problem (Critique) +The prompt strings (`RAG_PROMPT` and `QUERY_REWRITE_PROMPT`) are hardcoded directly within `src/api/routes/rag.py`. Furthermore, the token limits are managed by a custom `limit_context_tokens` function that performs rudimentary mathematical truncation (`truncated = content[:remaining * 4]`) to force-fit text into an arbitrary 3000 token limit. This approach is highly destructive; it truncates strings mid-word, breaks Markdown formatting, and severs semantic sentences. Additionally, 'Trending News' is hackily injected by fetching from `data_warehouse.py` and blindly appending it to the top of the context string. + +### The Reason +Embedding prompts directly in routing files is a common shortcut during early MVP stages. Likewise, accurately chunking text requires importing recursive character splitters and sophisticated tokenizers, so a naive mathematical approximation was used to prevent maximum context window errors with the OpenAI API. + +### The Solution +For real-world scaling and better response quality: +1. **Prompt Management**: Move all prompt templates into a centralized `src/core/prompts.py` file or load them from versioned YAML/JSON configurations. This allows tuning the AI persona without altering Python backend logic. +2. **Intelligent Text Splitting**: Replace `limit_context_tokens` with a robust text splitter from LangChain (e.g., `RecursiveCharacterTextSplitter`). This ensures chunks are broken cleanly at paragraph or sentence boundaries (`\n\n`, `.`), preserving meaning. +3. **Context Construction**: Formally separate the "Trending Data" injection from the standard document context injection, explicitly mapping out system instructions versus retrieved context sources. This yields cleaner behavior from large language models. + +--- + +## 4. Error Handling, Logging, and Security + +### The Problem (Critique) +The current RAG implementation uses extremely broad exception catching (`except Exception as e:`). In `rag.py`, if Qdrant throws an error, it is merely printed (`print(f"Error searching vector store: {e}")`) and an empty result set is passed to the LLM. If query rewriting fails, it prints and proceeds with original prompt. Important transactions fail silently and the user interface receives generic or poor answers without knowing the backend components degraded. Python's default `print` is used instead of the standard library `logging` module, meaning errors aren't easily searchable in production logs. + +### The Reason +Defensive programming is often implemented this way to prevent the entire API from crashing (returning an HTTP 500) if a non-critical component like temporal bias or reranking fails. However, the side effect is an inability to monitor system health and "silent failures." The `print` statements were left over from local development debugging. + +### The Solution +In a production-ready ("Real World") backend: +1. **Structured Logging**: Replace all instances of `print()` with Python's standard `logging.getLogger(__name__)`. Integrate JSON logging so log aggregation platforms (Datadog, ELK) can parse context (session_id, user_id). +2. **Targeted Exception Handling**: Catch specific exceptions (e.g., `TimeoutError`, `qdrant_client.http.exceptions.UnexpectedResponse`). Decide explicitly which errors are fatal (raise `HTTPException(status_code=500)`) and which are degradable. +3. **Telemetry & Client Feedback**: When a degradation occurs (e.g., Qdrant is down, using ClickHouse fallback), include a `warnings` or `metadata` dict in the HTTP JSON response so the client application knows the data might be suboptimal. diff --git a/docs/Back end Arctecture/scalable_architecture.md b/docs/Back end Arctecture/scalable_architecture.md new file mode 100644 index 0000000000000000000000000000000000000000..ae071aba91378be39de249a813ef5f382beef2b0 --- /dev/null +++ b/docs/Back end Arctecture/scalable_architecture.md @@ -0,0 +1,109 @@ +# RAG API Design: Retrieval Architecture + +This document focuses specifically on the API layer designed to **retrieve** data from our existing, highly optimized data pipeline. Because the heavy lifting of processing, vectorization (BGE-M3 Dense + Sparse), and indexing is already handled by the Kafka and Qdrant workers, this API is designed purely for **scalable, high-performance retrieval and generation**. + +--- + +## 🎯 1. Core API Philosophy + +The RAG API acts as the bridge between user queries (from the frontend) and the populated Qdrant vector database. + +1. **Read-Only Operations:** This API does *not* write to Qdrant or ClickHouse. It assumes the databases are already hydrated by the Kafka workers. +2. **Symmetry with Ingestion:** The API must use the exact same BGE-M3 model for hashing user queries that the Embedding Service uses to hash news articles. +3. **Statelessness:** The API nodes hold no session state, allowing infinite horizontal scaling behind a Load Balancer. + +--- + +## 🌐 2. Core API Endpoints + +### 2.1 `POST /api/v1/search` (Hybrid Search Only) +* **Purpose:** The fastest way to find relevant articles without generating an LLM response. Useful for standard "News Search" bars. +* **Input Request:** + ```json + { + "query": "Quantum computing breakthroughs in 2026", + "limit": 10, + "filters": { + "source": ["TechCrunch", "Wired"], + "date_range": { "start": "2026-01-01", "end": "2026-12-31" } + } + } + ``` +* **Internal Flow:** + 1. Passes the `query` text through the BGE-M3 Tokenizer & Model (synchronously or via lightweight async executor). + 2. Extracts the `Dense` vector (1024-dim) and `Sparse` lexical weights. + 3. Queries Qdrant using a `Prefetch` query (combining Dense + Sparse scoring). + 4. Extracts the Qdrant `payload` (article metadata) and returns it. +* **Response:** A JSON list of articles sorted by relevance score. + +### 2.2 `POST /api/v1/rag/ask` (Full RAG Flow) +* **Purpose:** The endpoint for natural language Q&A. This hits Qdrant first, then sends the context to the LLM. +* **Input Request:** + ```json + { + "question": "What did Google recently announce regarding quantum processors?", + "stream": true, // Critical for UX + "top_k": 5 + } + ``` +* **Internal Flow:** + 1. **Retrieve:** Performs the exact same Hybrid Search as `/api/v1/search` to get the top 5 article chunks. + 2. **Prompt Assembly:** Constructs a structured prompt template: + `"Use the following news articles to answer the question...\n\nCONTEXT:\n[Article 1 Text...]\n[Article 2 Text...]\n\nQUESTION: What did Google recently announce..."` + 3. **Generate:** Sends the assembled prompt to the LLM (OpenAI, local Llama-3, etc.). + 4. **Stream:** Uses Server-Sent Events (SSE) to yield tokens to the frontend as they are generated. + +--- + +## 🧠 3. Query Vectorization Pipeline (Symmetry) + +For Qdrant search to work perfectly, the API must emulate Step 4 of the *Data Flow Pipeline* exactly. + +```python +# RAG API Vectorization Logic +def vectorize_query(query_text: str): + # Uses the SAME FlagEmbedding configuration as the ingestor + embeddings = model.encode( + sentences=[query_text], + batch_size=1, + max_length=512, # Queries are shorter than articles + return_dense=True, + return_sparse=True, + return_colbert_vecs=False + ) + + return { + "dense": embeddings['dense_vecs'][0].tolist(), + "sparse": { + "indices": list(embeddings['lexical_weights'][0].keys()), + "values": list(embeddings['lexical_weights'][0].values()) + } + } +``` + +--- + +## ⚡ 4. Scalability at the Retrieval Layer + +Since the Heavy ETL is done by the pipelines, the API's main bottleneck is **waiting** for Qdrant and the LLM. + +### 4.1 Async FastAPI +* The API is built purely on `async def` endpoints. +* When the API queries Qdrant (`await qdrant_client.async_search(...)`), it yields the thread back to the event loop. +* A single FastAPI container can handle thousands of concurrent searches while waiting for Qdrant to respond. + +### 4.2 Semantic Query Caching (Redis) +To save LLM compute and Qdrant load: +* We implement Redis **Semantic Caching**. +* If User A asks: *"What is Tesla's stock doing?"* and User B asks *"How is the Tesla stock performing?"*, the semantic cache recognizes the queries are identical in meaning (High Cosine Similarity) and instantly returns User A's cached LLM response to User B. + +### 4.3 Streaming (SSE) for LLMs +* Generating a 500-word RAG answer might take the LLM 3 seconds. Instead of a loading spinner for 3 seconds, the API uses `StreamingResponse`. The user sees the first word in 200ms, creating a "Real-Time" feel. + +--- + +## 📊 5. Integration with Pipeline Analytics +If the RAG API needs to answer questions like *"How many articles mentioned AI today?"*, it should NOT query Qdrant. +Qdrant is a Vector Search engine, not an Analytics database. + +For structured analytics, the API connects directly to **ClickHouse** (which the Kafka `sink` worker hydrates), allowing real-time aggregations without disturbing the vector search performance. diff --git a/docs/RAG_API_PPT.md b/docs/RAG_API_PPT.md new file mode 100644 index 0000000000000000000000000000000000000000..9f6e164f88e80ca35fec289d0355a065d1f3cc1a --- /dev/null +++ b/docs/RAG_API_PPT.md @@ -0,0 +1,123 @@ +# Presentation Outline: Conversational Intelligence +## The SOTA RAG API & News Retrieval Flow + +This document is optimized for AI PPT Generators. It contains 12 detailed slides covering the RAG Technology Stack and the request-to-response data flow. + +--- + +### Slide 1: Title Slide +* **Headline**: Conversational Intelligence: Deep Dive into the SOTA RAG API +* **Sub-headline**: Bridging Natural Language and Real-Time News Data Warehouse +* **Visual Suggestion**: A glowing brain icon connected to a massive bookshelf (representing the Vector Store) and a lightning bolt (representing real-time trends). + +--- + +### Slide 2: The RAG Tech Stack - Strategic Selection +* **Core Concept**: Why these tools? A comparative advantage analysis. +* **Alternative Comparison Table**: + +| Component | Our Choice | Alternatives | Competitive Advantage | +| :--- | :--- | :--- | :--- | +| **LLM Engine** | **GPT-4o** | Llama-3, Mistral, Claude | Superior reasoning for complex query synthesis & multilingual logic. | +| **Vector DB** | **Qdrant** | Pinecone, Milvus, Weaviate | Native **Hybrid Search** support & high-speed gRPC batching protocol. | +| **Embeddings** | **BGE-M3** | OpenAI `text-3`, HuggingFace | **Sparse + Dense** in one pass; massive 8192 token window. | +| **Reranker** | **TinyBERT CE** | Cohere Rerank, BGE-Reranker | Local CPU-optimized execution with high Precision-at-K. | +| **Analytics** | **ClickHouse** | PostgreSQL, ELK, Timescale | sub-second OLAP performance on high-velocity news data streams. | +| **API Protocol** | **SSE (Stream)** | WebSockets, REST, gRPC-Web | Direct HTTP/1.1 compatibility; lower overhead for one-way streams. | + +* **Visual Suggestion**: A "Engine Room" comparison chart where our tools are highlighted in gold. + +--- + +### Slide 3: Hidden Magic - Pre-Warming & Startup +* **Core Concept**: Zero-Latency "Cold Start." +* **Details**: + * Problem: Heavy AI models take ~10s to load. + * Solution: Background background loading on server start. + * Benefit: The first user query in the morning is just as fast as the 100th. +* **Visual Suggestion**: A "Loading Bar" that finishes before the user even arrives. + +--- + +### Slide 4: Step 1 - Query Transformation (Synthesis) +* **Core Concept**: Understanding "Contextual" Questions. +* **Details**: + * **Synthesis**: Merging conversation history with the new query. + * **Technique**: Using GPT-4 to convert "What about Intel?" into "Financial performance of Intel in 2024". +* **Example**: + * *History*: "Tell me about Nvidia." + * *Follow-up*: "What about Intel?" + * *Result*: Standalone query specifically about Intel vs Nvidia context. + +--- + +### Slide 5: Step 2 - Hybrid Search & Intent Recognition +* **Core Concept**: Combining Concept (Dense) and Keywords (Sparse). +* **Details**: + * **Dense**: Finding "vibe" (e.g., "financial crash" matches "bankruptcy"). + * **Sparse**: Finding "tickers" (e.g., "NVDA", "AAPL") or specific entities. +* **Visual Suggestion**: Two searchlights (Dense and Sparse) converging on a single high-quality news article. + +--- + +### Slide 6: Step 3 - Temporal Decay (Recency Boosting) +* **Core Concept**: News Freshness Matters. +* **Details**: + * **Logic**: Today's 80% match is better than last year's 100% match. + * **Mechanism**: Applying a mathematical penalty to older articles during the search phase. +* **Example**: A fresh report on a merger ranks higher than a "deep dive" from 6 months ago. + +--- + +### Slide 7: Step 4 - Precision Reranking (Cross-Encoder) +* **Core Concept**: From "Fast Search" to "Exact Grade." +* **Details**: + * Moving from Bi-Encoders (fast, broad) to Cross-Encoders (slow, ultra-accurate). + * Checking the Top 20 results one-by-one to ensure they actually answer the question. +* **Example**: Eliminating articles that mention the keywords but are actually about a different topic. + +--- + +### Slide 8: Step 5 - Diversity Filtering (MMR) +* **Core Concept**: Anti-Echo Chamber. +* **Details**: + * **Maximal Marginal Relevance (MMR)**: Selecting articles that are relevant but *different* from each other. + * **Benefit**: Instead of 5 articles saying the same thing, the LLM gets 5 different perspectives (e.g., Fact, Opinion, Impact). +* **Visual**: A filter that takes out identical "Copy-Paste" news reports. + +--- + +### Slide 9: Step 6 - Parent Retrieval & Context Expansion +* **Core Concept**: Seeing the Big Picture. +* **Details**: + * Search is done on small chunks (~500 chars). + * If a chunk is a "Perfect Match," the system fetches the **entire article** from ClickHouse. + * Benefit: The LLM gets the full context of the story, not just a broken sentence. + +--- + +### Slide 10: Step 7 - Trend Fusion & LLM Grounding +* **Core Concept**: Real-Time Intelligence. +* **Details**: + * The API fetches "Trending Topics" from ClickHouse in parallel. + * This data is injected into the LLM prompt to inform it of broader market trends. +* **Result**: "While these articles focus on Company A, the general market sentiment in ClickHouse shows a negative shift today." + +--- + +### Slide 11: Step 8 - SSE Streaming (Real-Time Experience) +* **Core Concept**: Instant Gratification. +* **Details**: + * Using **Server-Sent Events (SSE)**. + * Tokens are pushed to the user as they are generated. + * Perceived wait time drops from 5 seconds to **300ms**. +* **Visual Suggestion**: Tokens appearing one-by-one in a fast, fluid stream. + +--- + +### Slide 12: Reliability & Traceability +* **Core Concept**: Production-Ready Design. +* **Details**: + * **Circuit Breaker**: If Qdrant is down, ClickHouse keyword search automatically takes over. + * **Interaction Trace**: Every source used to answer a question is logged for debugging and human feedback (Thumbs Up/Down). +* **Final Word**: A resilient, intelligent, and highly accurate news RAG system. diff --git a/docs/RAG_RETRIEVAL_FLOW.md b/docs/RAG_RETRIEVAL_FLOW.md new file mode 100644 index 0000000000000000000000000000000000000000..8db5ec8e32d7a023109fefddf37c62b7d04e70c7 --- /dev/null +++ b/docs/RAG_RETRIEVAL_FLOW.md @@ -0,0 +1,147 @@ +# State-of-the-Art (SOTA) RAG Retrieval Data Flow + +This document details the end-to-end data flow of the News Pipeline RAG API, incorporating advanced patterns for accuracy, diversity, and production resilience. + +## 1. Pre-Processing & Infrastructure (The "Cold-Start" Layer) +To ensure **zero-latency** during the initial user interaction, the system implements a preemptive resource loading strategy. + +### A. Async Pre-warming (Hidden Latency Absorption) +- **Challenge**: Large Transformer models (like BGE-M3 and Cross-Encoders) typically take 5–15 seconds to load from disk to RAM/VRAM. Lazy-loading these on the first request creates an unacceptable user experience. +- **Process**: + - In `main.py`, the `@app.on_event("startup")` hook triggers a non-blocking `threading.Thread`. + - This background thread immediately initializes `EmbedderService` and `RerankerService`. + - By the time the web server is live and the user types their first query, the models are fully resident in memory, resulting in sub-second response times for the very first request. + +### B. Circuit Breaker: ClickHouse Fallback (Always-On Reliability) +- **Challenge**: Vector databases like Qdrant can occasionally experience network partitions or downtime. In a naive RAG, this would crash the conversation. +- **Process**: + - The `VectorStore.search` method is wrapped in a robust `try-except` block. + - If the Qdrant client connection fails or a timeout occurs, the **Circuit Breaker** trips. + - The system automatically redirects the query to `fallback_keyword_search()` in ClickHouse. + - **Mechanism**: It performs a rapid SQL-based keyword search on titles and content in the `sentiment_results` table. While less semantically accurate than vectors, it ensures the user receives actual relevant news articles instead of a "Service Unavailable" error. + +## 2. Request Phase (Conversational Logic) + +### Step A: Query Transformation (Contextual Synthesis) +**Purpose**: Bridging the gap between human conversation and vector search requirements. +- **The Problem**: Users often ask relative questions like *"What about their stock?"*. Vector databases cannot resolve "their" without context. +- **Process**: + - The API retrieves the last 6 messages from PostgreSQL. + - A specialized prompt instructs `GPT-4` to synthesize the conversation history and the new user query into a single **Standalone Search Query**. + - If history is empty, the original query is used. +- **Example Trace**: + - **History**: `User: Tell me about Nvidia's revenue last year.` + - **New Query**: `User: Did Intel do better?` + - **Synthesized Search Query**: *"Comparison of Intel and Nvidia's revenue for the last fiscal year"* + +### Step B: Intent-Based Search (Hybrid & Recency) +**Purpose**: Combining semantic depth with keyword precision and news freshness. + +#### 1. Hybrid Vector Synthesis +- **Dense Layer**: Uses `BAAI/bge-m3` to produce a 1024-dimensional semantic embedding. This handles "vibe" and "concept" matching (e.g., matching "financial struggle" to "bankruptcy"). +- **Sparse Layer**: Prepares slots for keyword-specific vectors (e.g., Splade or BGE-M3 Sparse). This handles exact entities, ticker symbols (e.g., "NVDA"), or specific dates that dense embeddings might blur. + +#### 2. Temporal Decay (Recency Boosting) +- **Logic**: News is a deteriorating asset. The system applies a **Recency Multiplier** during the retrieval collection phase. +- **Formula**: `Score = Base_Similarity * (1.0 - (days_old / 60))`. +- **Constraint**: The multiplier never drops below `0.5`, ensuring that very relevant historical news is still retrievable but newer coverage is naturally prioritized. +- **Example**: + - Article A (Identical match, 60 days old): `Final Score = 0.9 * 0.5 = 0.45` + - Article B (Close match, today): `Final Score = 0.8 * 1.0 = 0.8` + - **Result**: Article B is ranked higher despite slightly lower semantic similarity. + +## 3. Retrieval Refinement (The "Precision" Layer) + +### Step C: Cross-Encoder Reranking (Relevance Grading) +**Purpose**: Moving from "Bi-Encoder" (fast but broad) to "Cross-Encoder" (slow but highly accurate). +- **The Problem**: Dense embeddings (Bi-Encoders) are great at finding "similar" text but often struggle with fine-grained nuances or contradictory statements. +- **Process**: + - The system takes the **Top 20** results from the broad search. + - Each [Query, Chunk] pair is passed through the `CrossEncoder` model (`ms-marco-TinyBERT-L-2-v2`). + - The model produces a raw relevance score. This is significantly more accurate than pure cosine similarity from the vector search. + +### Step D: Diversity Filtering - MMR (Information Density) +**Purpose**: Preventing "Echo Chambers" or redundant context windows. +- **The Problem**: Five news articles starting with the same AP wire sentence will fill the LLM context with redundant text. +- **Process**: + - Implemented **Maximal Marginal Relevance (MMR)**. + - Logic selects documents that have high relevance but **low similarity** to already selected documents. +- **Example**: + - *Selection 1*: A factual report of a merger. + - *Selection 2 (Rejected)*: Another factual report of the same merger. + - *Selection 2 (Accepted)*: A financial analyst's opinion on the same merger. + +### Step E: Parent Document Retrieval (Context Expansion) +**Purpose**: Providing the "Full Picture" when a snippet isn't enough. +- **Process**: + - Small chunks (~500 chars) are indexed for surgical search accuracy. + - If a chunk's rerank score is **> 0.8**, its unique `doc_id` is used to fetch the full parent article body from ClickHouse/Qdrant. + - This allows the LLM to see the surrounding context that might have been lost in the chunking process. + +--- + +## 4. Generation & Enrichment + +### Step F: ClickHouse Trend Fusion (External Intelligence) +**Purpose**: Grounding the LLM in real-time metadata. +- **Process**: + - Parallel to the LLM call, the system queries the **ClickHouse Data Warehouse**. + - It extracts trending entities and sentiment scores for the last 3 days relevant to the query. + - This "Trend Knowledge" is injected into the system prompt. +- **Benefit**: The LLM can say: *"Retrieval articles show X, but ClickHouse trends show that sentiment for this topic is currently shifting negative."* + +### Step G: Streaming Generation - SSE (Real-Time UX) +**Purpose**: Minimizing "Perceived Latency". +- **Process**: + - Uses FastAPI `StreamingResponse` and Server-Sent Events (SSE). + - Instead of waiting 5 seconds for a full paragraph, the first token is displayed within **200-400ms**. + - Tokens are pushed to the client in real-time as the LLM predicts them. + +--- + +## 5. Traceability & Feedback Loop + +### Step H: Interaction Logging (Audit Trail) +- **Traceability**: Every AI response logs the exact list of `retrieved_doc_ids` (Source IDs) in PostgreSQL. +- **Learning Loop**: When a user gives a "Thumbs Down", developers can query the database to see exactly which sources were used. This allows for **Negative Sampling** (identifying which articles cause hallucination or bad answers). + +--- + +## Technical Stack Overview + +| Stage | Tool/Model | +| :--- | :--- | +| **Embeddings** | `BAAI/bge-m3` (BAAI) | +| **Reranking** | `ms-marco-TinyBERT-L-2-v2` (CrossEncoder) | +| **Diversity** | Custom MMR Implementation | +| **Vector DB** | Qdrant | +| **Data Warehouse**| ClickHouse | +| **Token Control** | `tiktoken` (cl100k_base) | +| **LLM** | OpenAI `gpt-4` | + +--- + +## Full Data Flow Visual + +```mermaid +graph TD + User((User)) -->|Query| API[RAG API] + API -->|Prompt| LLM_Rewriter[LLM Rewriter] + LLM_Rewriter -->|Standalone Query| API + + API -.->|Circuit Breaker Check| VDB{Qdrant Online?} + VDB -->|No| CH_FB[ClickHouse Keyword Fallback] + VDB -->|Yes| V_Search[Hybrid Vector Search] + + V_Search -->|Top 20| Rerank[Cross-Encoder Reranker] + Rerank -->|Diversity Pass| MMR[MMR Filter] + MMR -->|Top K| Parent_Fetch[Parent Doc Retrieval] + + Parent_Fetch -->|Context| Prompt_Build[Prompt Construction] + Prompt_Build -->|Inject| CH_Trends[ClickHouse Trends] + + CH_Trends -->|Full Prompt| LLM_Stream[LLM Streaming] + LLM_Stream -->|SSE Tokens| User + + LLM_Stream -->|Trace| Postgres[(Interaction DB)] +``` diff --git a/docs/rag_retrieval_documentation.md b/docs/rag_retrieval_documentation.md new file mode 100644 index 0000000000000000000000000000000000000000..6b36d368a8bca725e1a16f927e01ee039260b0e4 --- /dev/null +++ b/docs/rag_retrieval_documentation.md @@ -0,0 +1,129 @@ +# RAG API Data Flow & Retrieval Architecture + +This document tracks the detailed Data Flow of the RAG (Retrieval-Augmented Generation) API, with a specific focus on the **Retrieval Logic**. Rather than just listing HTTP endpoints, this document explains the underlying methods, conceptual flow, and how the Domain Models, Ports, Use Cases, and Infrastructure Adapters interact to fetch, rerank, and summarize enterprise news data. + +--- + +## 🏗️ 1. Architecture Overview (Hexagonal Architecture) + +The RAG API relies on **Hexagonal Architecture** (Ports and Adapters). It strongly separates business logic from infrastructure frameworks. + +- **Domain/Models**: The central, pure data structures representing the state (e.g., `ChatRequest`, `User`). +- **Ports (Interfaces)**: Abstract definitions of what the system *needs* to do (e.g., `VectorStorePort`, `LlmPort`). +- **Use Cases**: The actual business logic where the retrieval steps, filtering, and flow occur. +- **Adapters**: The concrete implementation of Ports using external technologies (e.g., Qdrant, OpenAI, Redis, Postgres). + +--- + +## 📂 2. File Directory Breakdown & Responsibilities + +### `src/api/` (Primary Adapters / The Front Door) +- **`routes/rag.py`**: Exposes the `/chat` and `/chat/stream` endpoints. **Role**: Accepts the incoming HTTP payload, validates the JWT token (via `Depends(get_current_user)`), and forwards the request directly to the `AgentRouterUseCase`. +- **`dependencies.py`**: The Dependency Injection container. **Role**: Wires the concrete Infrastructure Adapters (e.g., `QdrantAdapter`, `BgeEmbedderAdapter`) to their respective Ports, and injects them into the Use Cases. Ensures components are instantiated only once. + +### `src/core/domain/` (Core Data) +- **`schemas.py`**: Defines Pydantic validation models. **Role**: Houses `ChatRequest` (contains `query`, `top_k`, `session_id`, `source_filter`, etc.) which acts as the transport object through the system. + +### `src/core/ports/` (The Interfaces) +- **`embedder_port.py`**: Defines `encode_query()`. +- **`vector_store_port.py`**: Defines `search()`. +- **`reranker_port.py`**: Defines `rerank()`. +- **`llm_port.py`**: Defines `generate()` and `generate_stream()`. +- **`cache_port.py`**: Defines `get()`, `set()`, and `generate_exact_hash()`. + +### `src/core/use_cases/` (The Business Logic Engine) +- **`agent_router_use_case.py`**: **Role**: The gateway. Analyzes the user's intent. Routes the request to `AccountUseCase` (if the user is asking about personal profile data) or `RagChatUseCase` (if asking about news). +- **`rag_chat_use_case.py`**: **Role**: The Heavy Lifter. Responsible for the entire Retrieval Logic flow. Contains methods like `_extract_intents`, `_build_context`, `_limit_context`, and `_compress_document`. +- **`account_use_case.py`**: **Role**: A secondary flow for handling user-specific DB aggregations (billing, history) rather than searching Vector DBs. + +### `src/infrastructure/adapters/` (Concrete Infrastructure) +- **`redis_adapter.py`**: **Role**: Connects to the caching layer to prevent duplicate LLM processing calls. +- **`qdrant_adapter.py`**: **Role**: Orchestrates the `query_points` API call to Qdrant, fusing Dense and Sparse vector retrieval (Hybrid Search). +- **`bge_embedder_adapter.py`**: **Role**: Instantiates the massive BGE-M3 model (using FlagEmbedding). Converts text strings into multi-dimensional arrays (Dense and Lexical Sparse weights). +- **`bge_reranker_adapter.py`**: **Role**: Uses a Cross-Encoder to compare the user query and the retrieved documents string-by-string for absolute semantic precision. +- **`openai_adapter.py` / `ollama_adapter.py`**: **Role**: Connects to an external OpenAI API or Local Llama-3 instance to generate text. + +--- + +## 🌊 3. The Retrieval Logic: Step-by-Step Data Flow Example + +**Scenario**: A user submits the query: *"What happened with Apple stock recently?"* + +### Step 1: Ingestion & Intent Routing (`agent_router_use_case.py`) +1. **Input**: `ChatRequest(query="What happened with Apple stock recently?", top_k=5)` +2. **Action**: The API endpoint passes this to the `AgentRouterUseCase`. +3. **LLM Classification**: The Router asks the LLM: "Is this a NEWS search or an ACCOUNT search?" +4. **Output**: The LLM outputs `NEWS`. The Router forwards the request to the `RagChatUseCase`. + +### Step 2: Semantic Caching (`redis_adapter.py`) +1. **Action**: `cache_port.generate_exact_hash()` calculates an SHA-256 hash or deterministic key for the query string. +2. **Check**: Does this key exist in Redis? +3. **If Yes**: Return the answer instantly (0ms LLM time). +4. **If No**: Proceed with the expensive pipeline. + +### Step 3: Self-Query Extraction (`rag_chat_use_case.py -> _extract_intents()`) +1. **Action**: The LLM analyzes the user's natural language query to dynamically extract metadata and physical parameters for the vector database. +2. **Example Prompting**: The LLM is provided with a system prompt like: *"Extract the temporal constraints and target sources from the user query into JSON format. Valid sources: ['reuters', 'bloomberg']."* +3. **Execution**: The LLM analyzes *"What happened with Apple stock recently?"* +4. **Output Deduction**: From the word "recently", it deduces the temporal boundary and constructs the following JSON structure: + ```json + { + "days_back": 3, + "source": null + } + ``` +5. **Mapping**: The `RagChatUseCase` parses this JSON. If `days_back` is present, it constructs a Qdrant `models.Filter` to physically exclude older documents from the multidimensional search space *before* the costly vector math occurs. + +### Step 4: Embedding / Vectorization (`bge_embedder_adapter.py`) +1. **Action**: `encode_query()` is called. +2. **Model Processing**: The BGE-M3 model tokenizes the string. +3. **Output**: Returns a `Dict` containing: + - `dense`: `[0.123, -0.456, 0.789, ... 1024 dimensions]` + - `sparse`: `{"indices": [102, 451, ...], "values": [0.92, 0.44, ...]}` + +### Step 5: Hybrid Vector Search (`qdrant_adapter.py`) +1. **Action**: Passes the `query_vectors` and the `days_back=3` filter into `vector_store_port.search()`. +2. **Qdrant Processing**: Qdrant performs a Fusion Query (Reciprocal Rank Fusion - RRF). It fetches the top 20 nearest neighbors from BOTH the Dense mathematical space AND the Sparse keyword space. +3. **Output**: Returns a List of raw `SearchResult` documents. + +### Step 6: Temporal Bias Scoring (`rag_chat_use_case.py -> _build_context()`) +1. **Action**: Evaluates the `published_at` metadata of every hit. +2. **Calculation**: It deliberately decays the score of older articles via a mathematical multiplier (e.g., `score_multiplier = max(0.5, 1.0 - (days_old / 60))`). +3. **Output**: A dynamically re-scored list, preferring fresh data. + +### Step 7: Cross-Encoder Reranking (`bge_reranker_adapter.py`) +1. **Action**: For the top 20 remaining documents, the Reranker pairs the Query + Document Text together (`[[query, doc1], [query, doc2]]`). +2. **Model Processing**: The HuggingFace FlagReranker calculates exact semantic overlap. +3. **Output**: Returns the strict Top 5 (`top_k`) documents, guaranteed to be specifically relevant. + +### Step 8: Contextual Compression (`rag_chat_use_case.py -> _limit_context()`) +1. **Action**: `_limit_context` uses `tiktoken` to count how many tokens the Top 5 documents contain. +2. **Check**: Are they over the 3000 Token limit? +3. **Compression Loop**: If they are over the limit, it calls `_compress_document()`. +4. **LLM Summarization**: Passes the overflowing document string to the LLM with the instruction: *"Extract pure facts... relevant to the query."* The massive document strings are squashed down to bullet-point facts. +5. **Output**: A tightly packed `context_text` string ready for generation. + +### Step 9: Final Generation (`llm_port.py`) +1. **Action**: The packed `context_text`, the User `query`, and the recent `Chat History` are combined into the Final Prompt. +2. **Model Processing**: The LLM interprets the compressed context. +3. **Output**: The Final string ("Apple stock surged 4% after the latest earnings report..."). +4. **Cleanup**: This answer is saved to both Postgres (`chat_history_db`) and Redis (`cache`), and returned to the API client. + +--- + +## 📈 4. A4 Analysis and Future Updates + +### A4 Analysis (Current System Standing) + +| Dimension | Analysis & Findings | +| :--- | :--- | +| **Resilience & Scalability** | **High**. The Hexagonal architecture successfully decoupled Qdrant, Postgres, and the LLMs. We can swap `OpenAiAdapter` for `OllamaAdapter` simply by changing one dependency provider without touching the Business Logic flow. Missing dependencies (e.g., `FlagEmbedding`) gracefully utilize dummy fallbacks avoiding hard API crashes. | +| **Retrieval Accuracy** | **Exceptional**. We utilize a 3-Stage filtering mechanism: Semantic similarity (Dense), Lexical accuracy (Sparse), and absolute context alignment (Reranker). The addition of dynamic Temporal Biasing prevents the hallucination of historical news as current events. | +| **Cost & Latency Management** | **Optimized**. The implementation of Redis Semantic Caching guarantees that recursive identical intent avoids LLM round-trip costs. The `AgentRouterUseCase` ensures unrelated general questions (Account, Billing) never touch expensive Vector DB aggregations. | +| **Memory Constraint Handling** | **Innovative**. By employing `_compress_document`, the system prevents context-window truncation, ensuring critical tail-end entities still influence the LLM's final generation. | + +### Proposed Future Updates (Roadmap) +1. **Semantic Cache Refinement**: Currently, the `RedisAdapter` relies on an exact SHA-256 string hash. **Update**: Calculate an actual LLM embedding of the prompt (Dense Vector) and store it in Redis. Use a Cosine-Similarity threshold (`>0.95`) to intercept semantically identical (but textually different) questions (e.g., "Apple stock" vs "AAPL share price"). +2. **Analytic Trend Fusion Enhancement**: In `_build_context`, we fetch trending entities from `ClickHouse`. **Update**: Send these trending entities into the Agent Router so the system can proactively recommend or correlate user interactions with macroeconomic spikes before they ask. +3. **Ollama Deployment Readiness**: Test the `bge_embedder_adapter` and `bge_reranker_adapter` simultaneously against an active `OllamaAdapter` container to benchmark hardware-level VRAM bottlenecks on local inference machines. +4. **Knowledge Graph Integration**: Extract Triples (`Subject-Predicate-Object`) during the `_compress_document` step to progressively construct a Graph Database (Neo4j) alongside the Vector DB (Qdrant) for Multi-Hop reasoning queries in the future. diff --git a/docs/rag_retrieval_presentation.md b/docs/rag_retrieval_presentation.md new file mode 100644 index 0000000000000000000000000000000000000000..6287ec5354e1023a0f34f2bd6cbdc0df3f66b10a --- /dev/null +++ b/docs/rag_retrieval_presentation.md @@ -0,0 +1,126 @@ +--- +marp: true +theme: default +paginate: true +header: 'Enterprise RAG Retrieval Architecture' +footer: 'Hexagonal Architecture Data Flow' +--- + +# 🚀 The Enterprise RAG Retrieval Logic +### Step-by-Step Data Flow Analysis + +This presentation covers the exact 9-step semantic retrieval and orchestration sequence used by the API to process complex user queries. + +**Case Study Query**: *"What happened with Apple stock recently?"* + +--- + +# 1️⃣ Step 1: Ingestion & Intent Routing + +The front door of our architecture. Every request is intercepted by the **Agent Router** to prevent unnecessary Vector Database queries. + +- **Component**: `agent_router_use_case.py` +- **Input Object**: `ChatRequest(query="What happened with Apple stock recently?", top_k=5)` +- **LLM Classification Prompt**: *"Is this a NEWS search or an ACCOUNT search?"* +- **Action**: The LLM analyzes the text and confidently outputs `NEWS`. +- **Output Routing**: The Router dynamically forwards the payload to the specialized `RagChatUseCase`. + +--- + +# 2️⃣ Step 2: Semantic Caching Layer + +Before spending LLM tokens or Cloud Compute, we check if this exact question has been asked and answered recently. + +- **Component**: `redis_adapter.py` +- **Action**: `cache_port.generate_exact_hash()` deterministically calculates a SHA-256 hash representing the query string. +- **Cache Check**: Does the key exist in the Redis cluster? +- **Fast-Path**: If **Yes**, it returns the cached generation instantly, resulting in 0ms LLM time and $0 cost. +- **Deep-Path**: If **No**, the query proceeds down the expensive RAG pipeline. + +--- + +# 3️⃣ Step 3: Self-Query Extraction + +We translate the user's natural language into strict physical constraints and metadata filters for the database. + +- **Component**: `rag_chat_use_case.py -> _extract_intents()` +- **Action**: The LLM parses the user text against available metadata schemas. +- **Execution Insight**: The LLM identifies the word *"recently"* and maps it to a physical timeframe. +- **LLM Output (JSON)**: + ```json + { "days_back": 3, "source": null } + ``` +- **Mapping**: `RagChatUseCase` creates a Qdrant `models.Filter` from this JSON, excluding old documents before math occurs. + +--- + +# 4️⃣ Step 4: Text Vectorization + +We convert the query string into a mathematical representation using the massive BGE-M3 model. + +- **Component**: `bge_embedder_adapter.py` +- **Action**: `encode_query()` passes the text into the embedded ML model. +- **Model Processing**: The text is tokenized into both Dense and Sparse dimensions. +- **Output Architecture**: + - **Dense Array**: `[0.123, -0.456, 0.789, ... 1024 dimensions]` + - **Sparse Lexical**: `{"indices": [102, 451, ...], "values": [0.92, 0.44, ...]}` + +--- + +# 5️⃣ Step 5: Hybrid Vector Search + +We execute a high-performance database search combining math and exact keyword matching. + +- **Component**: `qdrant_adapter.py` +- **Action**: Sends `query_vectors` and the extracted `days_back=3` physical filter to Qdrant via `vector_store_port.search()`. +- **Database Processing**: Qdrant executes a **Reciprocal Rank Fusion (RRF)** query. It searches simultaneously for Semantic Meaning (Dense) and Exact Keyword Hits (Sparse). +- **Yield**: Returns the top 20 nearest neighbor `SearchResult` documents. + +--- + +# 6️⃣ Step 6: Temporal Bias Scoring + +Preventing historical hallucination by mathematically prioritizing fresh news over old news. + +- **Component**: `rag_chat_use_case.py -> _build_context()` +- **Action**: Iterates over every returned document and examines its `published_at` timestamp. +- **Mathematical Decay**: + - `score_multiplier = max(0.5, 1.0 - (days_old / 60))` + - The older the article, the lower its multiplier goes. +- **Output**: A freshly re-scored list where newer, slightly less-relevant articles can outrank old, highly-relevant articles. + +--- + +# 7️⃣ Step 7: Cross-Encoder Reranking + +Applying an absolute brute-force semantic check to eliminate hallucinated vector distances. + +- **Component**: `bge_reranker_adapter.py` +- **Action**: Takes the top 20 decayed documents. It physically pairs the Query against the Document text block-by-block. + - `[[query, doc1_text], [query, doc2_text], ...]` +- **Model Processing**: The HuggingFace FlagReranker calculates exact semantic overlap. +- **Output**: Only the strict Top 5 (`top_k`) highest-scoring documents survive. + +--- + +# 8️⃣ Step 8: Contextual Compression + +Squashing massive strings to fit gracefully into limited LLM context windows. + +- **Component**: `rag_chat_use_case.py -> _limit_context()` +- **Action**: Uses `tiktoken` to calculate the total length of the surviving Top 5 documents. +- **Compression Loop**: If the size exceeds 3000 tokens, it pipes overflowing documents individually to an LLM via `_compress_document()`. +- **Extraction**: The LLM digests 800 words and outputs only bulleted facts relevant to "Apple Stock". +- **Output**: A high-density, tightly packed `context_text` string. + +--- + +# 9️⃣ Step 9: Final Final Generation + +The Orchestrator fuses all pipelines to deliver a hyper-accurate, hallucination-free answer. + +- **Component**: `llm_port.py` +- **Action**: The packed `context_text`, the original `query`, and the user's `Chat History` are injected into a singular Prompt Template. +- **Generation**: The LLM interprets the verified facts. + - *"Apple stock surged 4% after the latest earnings report..."* +- **Final Cleanup**: The new answer string is permanently logged into Postgres (`chat_history`) and cached into Redis (`cache`) before being returned via the API. diff --git a/download_models.py b/download_models.py new file mode 100644 index 0000000000000000000000000000000000000000..aba1ef3bcd565be6822f032c9a7ac60a6c92a7ec --- /dev/null +++ b/download_models.py @@ -0,0 +1,40 @@ +import os +import sys + +# Monkeypatch for transformers/FlagEmbedding compatibility issue +try: + import transformers.utils.import_utils + if not hasattr(transformers.utils.import_utils, 'is_torch_fx_available'): + transformers.utils.import_utils.is_torch_fx_available = lambda: False +except Exception: + pass + +from FlagEmbedding import BGEM3FlagModel +from sentence_transformers import CrossEncoder + +def download(): + print("--- STARTING MODEL PRE-CACHE ---") + + # 1. BGE-M3 + model_name = "BAAI/bge-m3" + print(f"Downloading/Loading {model_name}...") + try: + # This will trigger the download if not present + _ = BGEM3FlagModel(model_name, use_fp16=True) + print(f"Successfully cached {model_name}") + except Exception as e: + print(f"Error caching {model_name}: {e}") + + # 2. Reranker + reranker_name = "cross-encoder/ms-marco-TinyBERT-L-2-v2" + print(f"Downloading/Loading {reranker_name}...") + try: + _ = CrossEncoder(reranker_name) + print(f"Successfully cached {reranker_name}") + except Exception as e: + print(f"Error caching {reranker_name}: {e}") + + print("--- PRE-CACHE COMPLETE ---") + +if __name__ == "__main__": + download() diff --git a/migrate_database.py b/migrate_database.py new file mode 100644 index 0000000000000000000000000000000000000000..fc1638e6a3545da863a48ff8a7156d16beb9d109 --- /dev/null +++ b/migrate_database.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +""" +Database migration script to add missing columns to users table +Run this once to update your Neon database schema +""" + +import os +from sqlalchemy import create_engine, text +from src.core.config import settings + +def migrate_database(): + """Add missing columns to users table""" + print("🔄 Starting database migration...") + print(f"Database URL: {settings.SQLALCHEMY_DATABASE_URI[:50]}...") + + engine = create_engine(settings.SQLALCHEMY_DATABASE_URI) + + migrations = [ + # Add role column if it doesn't exist + """ + DO $$ + BEGIN + IF NOT EXISTS (SELECT 1 FROM information_schema.columns + WHERE table_name='users' AND column_name='role') THEN + ALTER TABLE users ADD COLUMN role VARCHAR(20) DEFAULT 'user'; + UPDATE users SET role = 'user' WHERE role IS NULL; + RAISE NOTICE 'Added role column'; + ELSE + RAISE NOTICE 'role column already exists'; + END IF; + END $$; + """, + + # Add is_active column if it doesn't exist + """ + DO $$ + BEGIN + IF NOT EXISTS (SELECT 1 FROM information_schema.columns + WHERE table_name='users' AND column_name='is_active') THEN + ALTER TABLE users ADD COLUMN is_active BOOLEAN DEFAULT TRUE; + UPDATE users SET is_active = TRUE WHERE is_active IS NULL; + RAISE NOTICE 'Added is_active column'; + ELSE + RAISE NOTICE 'is_active column already exists'; + END IF; + END $$; + """, + + # Add full_name column if it doesn't exist + """ + DO $$ + BEGIN + IF NOT EXISTS (SELECT 1 FROM information_schema.columns + WHERE table_name='users' AND column_name='full_name') THEN + ALTER TABLE users ADD COLUMN full_name VARCHAR(255); + RAISE NOTICE 'Added full_name column'; + ELSE + RAISE NOTICE 'full_name column already exists'; + END IF; + END $$; + """, + + # Create refresh_tokens table if it doesn't exist + """ + CREATE TABLE IF NOT EXISTS refresh_tokens ( + id SERIAL PRIMARY KEY, + user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, + token VARCHAR(500) NOT NULL UNIQUE, + expires_at TIMESTAMP NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + revoked BOOLEAN DEFAULT FALSE + ); + """, + + # Create index on refresh_tokens + """ + CREATE INDEX IF NOT EXISTS idx_refresh_tokens_user_id ON refresh_tokens(user_id); + CREATE INDEX IF NOT EXISTS idx_refresh_tokens_token ON refresh_tokens(token); + """, + ] + + try: + with engine.connect() as conn: + for i, migration in enumerate(migrations, 1): + print(f"\n📝 Running migration {i}/{len(migrations)}...") + conn.execute(text(migration)) + conn.commit() + print(f"✅ Migration {i} completed") + + print("\n✅ All migrations completed successfully!") + print("\n🎉 Database schema is now up to date") + return True + + except Exception as e: + print(f"\n❌ Migration failed: {e}") + return False + finally: + engine.dispose() + +if __name__ == "__main__": + success = migrate_database() + exit(0 if success else 1) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..69cd3ae60d019ffd8ab369137a41913661c368ce --- /dev/null +++ b/requirements.txt @@ -0,0 +1,26 @@ +fastapi>=0.110.0 +uvicorn>=0.27.1 +pydantic>=2.9.0 +pydantic-settings>=2.2.1 +qdrant-client>=1.7.0 +clickhouse-connect>=0.7.3 +langchain>=0.1.13 +langchain-openai>=0.1.1 # covers Groq, Gemini, Together AI, OpenAI (all OpenAI-compatible) +langchain-groq>=0.1.3 +python-dotenv>=1.0.1 +psycopg2-binary>=2.9.9 +SQLAlchemy>=2.0.29 +sentence-transformers>=2.7.0 +transformers>=4.40.0 # DeBERTa intent classifier +torch>=2.0.0 +numpy>=1.26.0 +tiktoken>=0.6.0 +FlagEmbedding>=1.2.5 +redis>=5.0.0 +python-jose[cryptography]>=3.3.0 +passlib[bcrypt]>=1.7.4 +python-multipart>=0.0.9 +httpx>=0.27.0 +aiohttp>=3.9.0 +duckduckgo-search>=6.0.0 # Live search for hybrid RAG +python-dateutil>=2.8.2 # Date parsing for live results diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..49cc611ae0fcde5a70869b2ec1dac3e4a16d96d6 --- /dev/null +++ b/src/__init__.py @@ -0,0 +1 @@ +# Initialize src package diff --git a/src/__pycache__/__init__.cpython-313.pyc b/src/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97f9c3bddc461bf99393d331cbb717e771a984d2 Binary files /dev/null and b/src/__pycache__/__init__.cpython-313.pyc differ diff --git a/src/__pycache__/main.cpython-313.pyc b/src/__pycache__/main.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a86751c21116fb0cf79c8a1158279039ceccbe98 Binary files /dev/null and b/src/__pycache__/main.cpython-313.pyc differ diff --git a/src/api/__pycache__/dependencies.cpython-313.pyc b/src/api/__pycache__/dependencies.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e7644fc16ead04bef1e4a3737e7c08496361bf5 Binary files /dev/null and b/src/api/__pycache__/dependencies.cpython-313.pyc differ diff --git a/src/api/dependencies.py b/src/api/dependencies.py new file mode 100644 index 0000000000000000000000000000000000000000..3bedf9def0d516376656188d2a84ec5549c71ac8 --- /dev/null +++ b/src/api/dependencies.py @@ -0,0 +1,153 @@ +from fastapi import Depends +from sqlalchemy.orm import Session +from src.infrastructure.database import get_db + +# Adapters +from src.infrastructure.adapters.bge_embedder_adapter import BgeEmbedderAdapter +from src.infrastructure.adapters.qdrant_adapter import QdrantAdapter +from src.infrastructure.adapters.bge_reranker_adapter import BgeRerankerAdapter +from src.infrastructure.adapters.openai_adapter import OpenAiAdapter +from src.infrastructure.adapters.ollama_adapter import OllamaAdapter +from src.infrastructure.adapters.groq_adapter import GroqAdapter +from src.infrastructure.adapters.gemini_adapter import GeminiAdapter +from src.infrastructure.adapters.together_adapter import TogetherAdapter +from src.infrastructure.adapters.huggingface_adapter import HuggingFaceAdapter +from src.infrastructure.adapters.clickhouse_adapter import ClickHouseAdapter +from src.infrastructure.adapters.postgres_adapter import PostgresAdapter +from src.infrastructure.adapters.redis_adapter import RedisAdapter +from src.infrastructure.adapters.duckduckgo_adapter import DuckDuckGoAdapter + +# Hybrid Search Components +from src.core.orchestrator.query_orchestrator import QueryOrchestrator +from src.core.ranking.hybrid_result_ranker import HybridResultRanker + +# Use Cases +from src.core.use_cases.search_use_case import SearchUseCase +from src.core.use_cases.rag_chat_use_case import RagChatUseCase +from src.core.use_cases.analytics_use_case import AnalyticsUseCase + +# Global Singletons for Stateless Adapters to avoid reloading models per request +embedder_adapter = BgeEmbedderAdapter() +qdrant_adapter = QdrantAdapter() +reranker_adapter = BgeRerankerAdapter() +openai_adapter = OpenAiAdapter() +ollama_adapter = OllamaAdapter() +groq_adapter = GroqAdapter() +gemini_adapter = GeminiAdapter() +together_adapter = TogetherAdapter() +huggingface_adapter = HuggingFaceAdapter() +clickhouse_adapter = ClickHouseAdapter() +redis_adapter = RedisAdapter() + +# Hybrid Search Singletons +from src.core.config import settings +duckduckgo_adapter = DuckDuckGoAdapter( + timeout=settings.LIVE_SEARCH_TIMEOUT, + max_results=settings.LIVE_SEARCH_MAX_RESULTS +) +query_orchestrator = QueryOrchestrator( + live_search_adapter=duckduckgo_adapter, + enable_hybrid=settings.ENABLE_HYBRID_SEARCH, + default_live_weight=settings.LIVE_SEARCH_WEIGHT, + default_db_weight=settings.DB_SEARCH_WEIGHT +) +hybrid_result_ranker = HybridResultRanker(reranker=reranker_adapter) + +# Model Pre-warming (Triggered dynamically if needed, usually on startup) +def prewarm_models(): + embedder_adapter._load_model() + reranker_adapter._load_model() + +# --- Dependency Providers --- + +def get_embedder_port(): + return embedder_adapter + +def get_vector_store_port(): + return qdrant_adapter + +def get_reranker_port(): + return reranker_adapter + +from src.core.config import settings +def get_llm_port(): + provider = settings.LLM_PROVIDER.lower() + if provider == "groq": + return groq_adapter + elif provider == "gemini": + return gemini_adapter + elif provider == "together": + return together_adapter + elif provider == "huggingface" or provider == "hf": + return huggingface_adapter + elif provider == "ollama": + return ollama_adapter + return openai_adapter + +def get_analytics_db_port(): + return clickhouse_adapter + +def get_chat_history_port(db: Session = Depends(get_db)): + return PostgresAdapter(db) + +def get_cache_port(): + return redis_adapter + +def get_live_search_port(): + return duckduckgo_adapter + +def get_query_orchestrator(): + return query_orchestrator + +def get_hybrid_ranker(): + return hybrid_result_ranker + +# --- Use Case Providers --- + +def get_search_use_case( + embedder=Depends(get_embedder_port), + vector_store=Depends(get_vector_store_port) +): + return SearchUseCase(embedder, vector_store) + +def get_rag_chat_use_case( + embedder=Depends(get_embedder_port), + vector_store=Depends(get_vector_store_port), + reranker=Depends(get_reranker_port), + llm=Depends(get_llm_port), + chat_history=Depends(get_chat_history_port), + analytics_db=Depends(get_analytics_db_port), + cache=Depends(get_cache_port), + orchestrator=Depends(get_query_orchestrator), + hybrid_ranker=Depends(get_hybrid_ranker) +): + return RagChatUseCase( + embedder=embedder, + vector_store=vector_store, + reranker=reranker, + llm=llm, + chat_history_db=chat_history, + analytics_db=analytics_db, + cache=cache, + orchestrator=orchestrator, + hybrid_ranker=hybrid_ranker + ) + +from src.core.use_cases.account_use_case import AccountUseCase +from src.core.use_cases.agent_router_use_case import AgentRouterUseCase + +def get_analytics_use_case( + analytics_db=Depends(get_analytics_db_port) +): + return AnalyticsUseCase(analytics_db) + +def get_account_use_case(): + return AccountUseCase() + +def get_agent_router_use_case( + llm=Depends(get_llm_port), + rag_chat=Depends(get_rag_chat_use_case), + account=Depends(get_account_use_case), + chat_history=Depends(get_chat_history_port) +): + return AgentRouterUseCase(llm=llm, rag_chat=rag_chat, account=account, chat_history_db=chat_history) diff --git a/src/api/routes/__init__.py b/src/api/routes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a716ed7515a90c5ffd84473cf71fa3c005793bd5 --- /dev/null +++ b/src/api/routes/__init__.py @@ -0,0 +1,2 @@ +# Expose routers +from . import rag, analytics, interactions, accounts, news diff --git a/src/api/routes/__pycache__/__init__.cpython-313.pyc b/src/api/routes/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7fe20ca26bf88cc6f20ac5f3eba7912b16fd461d Binary files /dev/null and b/src/api/routes/__pycache__/__init__.cpython-313.pyc differ diff --git a/src/api/routes/__pycache__/accounts.cpython-313.pyc b/src/api/routes/__pycache__/accounts.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4bf6923dce77654c7d92fa0f472387862cfbafaa Binary files /dev/null and b/src/api/routes/__pycache__/accounts.cpython-313.pyc differ diff --git a/src/api/routes/__pycache__/analytics.cpython-313.pyc b/src/api/routes/__pycache__/analytics.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68a693e7a13603ad2b7c0276f72c91ed3214d6e9 Binary files /dev/null and b/src/api/routes/__pycache__/analytics.cpython-313.pyc differ diff --git a/src/api/routes/__pycache__/auth.cpython-313.pyc b/src/api/routes/__pycache__/auth.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c18ac4f3bd6c0fa0f2e1b753b63fe0d855478ad9 Binary files /dev/null and b/src/api/routes/__pycache__/auth.cpython-313.pyc differ diff --git a/src/api/routes/__pycache__/interactions.cpython-313.pyc b/src/api/routes/__pycache__/interactions.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e4a6e5ae5f2780ef42a701dae122f2ecde4f720 Binary files /dev/null and b/src/api/routes/__pycache__/interactions.cpython-313.pyc differ diff --git a/src/api/routes/__pycache__/news.cpython-313.pyc b/src/api/routes/__pycache__/news.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbd02bfdd814736903b5cc898e8ce719c1473072 Binary files /dev/null and b/src/api/routes/__pycache__/news.cpython-313.pyc differ diff --git a/src/api/routes/__pycache__/rag.cpython-313.pyc b/src/api/routes/__pycache__/rag.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ced4b1b419e9f6403524bb826e023539d0921645 Binary files /dev/null and b/src/api/routes/__pycache__/rag.cpython-313.pyc differ diff --git a/src/api/routes/accounts.py b/src/api/routes/accounts.py new file mode 100644 index 0000000000000000000000000000000000000000..c2441bebaab1af180246cc7dfb7a02ba158eb251 --- /dev/null +++ b/src/api/routes/accounts.py @@ -0,0 +1,140 @@ +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.orm import Session +from typing import List + +from src.infrastructure.database import get_db +from src.core.domain.schemas import UserCreate, UserResponse, UserUpdate, PasswordChange, AdminUserUpdate +from src.core.domain.db_models import User, UserRole +from src.core.security import ( + get_password_hash, verify_password, + get_current_user, require_super_admin +) + +router = APIRouter() + + +# ── Public ──────────────────────────────────────────────────────────────────── + +@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED) +def register(user: UserCreate, db: Session = Depends(get_db)): + """Register a new user account (role defaults to 'user').""" + existing = db.query(User).filter( + (User.username == user.username) | (User.email == user.email) + ).first() + if existing: + raise HTTPException(status_code=400, detail="Username or email already registered") + + new_user = User( + username=user.username, + email=user.email, + full_name=user.full_name, + hashed_password=get_password_hash(user.password), + role=UserRole.user, + ) + db.add(new_user) + db.commit() + db.refresh(new_user) + return new_user + + +# ── Authenticated user ──────────────────────────────────────────────────────── + +@router.patch("/me", response_model=UserResponse) +def update_profile( + body: UserUpdate, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """Update own profile (username, full_name, email).""" + if body.username and body.username != current_user.username: + if db.query(User).filter(User.username == body.username).first(): + raise HTTPException(status_code=400, detail="Username already taken") + current_user.username = body.username + if body.email and body.email != current_user.email: + if db.query(User).filter(User.email == body.email).first(): + raise HTTPException(status_code=400, detail="Email already in use") + current_user.email = body.email + if body.full_name is not None: + current_user.full_name = body.full_name + + db.commit() + db.refresh(current_user) + return current_user + + +@router.post("/me/change-password", status_code=status.HTTP_204_NO_CONTENT) +def change_password( + body: PasswordChange, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """Change own password.""" + if not verify_password(body.current_password, current_user.hashed_password): + raise HTTPException(status_code=400, detail="Current password is incorrect") + current_user.hashed_password = get_password_hash(body.new_password) + db.commit() + + +# ── Super admin only ────────────────────────────────────────────────────────── + +@router.get("/users", response_model=List[UserResponse]) +def list_users( + skip: int = 0, + limit: int = 50, + _admin: User = Depends(require_super_admin), + db: Session = Depends(get_db) +): + """List all users (super_admin only).""" + return db.query(User).offset(skip).limit(limit).all() + + +@router.get("/users/{user_id}", response_model=UserResponse) +def get_user( + user_id: int, + _admin: User = Depends(require_super_admin), + db: Session = Depends(get_db) +): + """Get a specific user by ID (super_admin only).""" + user = db.query(User).filter(User.id == user_id).first() + if not user: + raise HTTPException(status_code=404, detail="User not found") + return user + + +@router.patch("/users/{user_id}", response_model=UserResponse) +def admin_update_user( + user_id: int, + body: AdminUserUpdate, + _admin: User = Depends(require_super_admin), + db: Session = Depends(get_db) +): + """Update a user's role or active status (super_admin only).""" + user = db.query(User).filter(User.id == user_id).first() + if not user: + raise HTTPException(status_code=404, detail="User not found") + if body.is_active is not None: + user.is_active = body.is_active + if body.role is not None: + try: + user.role = UserRole(body.role) + except ValueError: + raise HTTPException(status_code=400, detail="Invalid role. Must be 'super_admin' or 'user'") + db.commit() + db.refresh(user) + return user + + +@router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT) +def delete_user( + user_id: int, + admin: User = Depends(require_super_admin), + db: Session = Depends(get_db) +): + """Delete a user (super_admin only). Cannot delete yourself.""" + if user_id == admin.id: + raise HTTPException(status_code=400, detail="Cannot delete your own account") + user = db.query(User).filter(User.id == user_id).first() + if not user: + raise HTTPException(status_code=404, detail="User not found") + db.delete(user) + db.commit() diff --git a/src/api/routes/analytics.py b/src/api/routes/analytics.py new file mode 100644 index 0000000000000000000000000000000000000000..b7dca37f09fc12b3f9f0c600cb90a56197335651 --- /dev/null +++ b/src/api/routes/analytics.py @@ -0,0 +1,117 @@ +from fastapi import APIRouter, Depends +from src.core.use_cases.analytics_use_case import AnalyticsUseCase +from src.core.ports.vector_store_port import VectorStorePort +from src.api.dependencies import get_analytics_use_case, get_vector_store_port + +router = APIRouter() + +@router.get("/sentiment") +def get_sentiment(analytics_use_case: AnalyticsUseCase = Depends(get_analytics_use_case)): + query = """ + SELECT + entity, + avg(sentiment_score) as avg_sentiment, + count() as mention_count + FROM sentiment_results + GROUP BY entity + ORDER BY mention_count DESC + LIMIT 10 + """ + results = analytics_use_case.execute_raw_query(query) + if not results or "error" in results: + return {"error": "Could not fetch sentiment."} + + data = [] + for row in results.get("rows", []): + data.append({ + "entity": row[0], + "avg_sentiment": float(row[1]), + "mention_count": int(row[2]) + }) + return {"data": data} + +@router.get("/trends") +def get_trends( + days: int = 7, + analytics_use_case: AnalyticsUseCase = Depends(get_analytics_use_case) +): + results = analytics_use_case.get_trends(days) + if not results or "error" in results: + return {"error": "Could not fetch trends."} + + data = [] + for row in results.get("rows", []): + data.append({ + "topic": row[0], + "momentum": float(row[1]), + "volume": int(row[2]) + }) + return {"data": data} + +@router.get("/articles-over-time") +def get_articles_over_time( + days: int = 30, + analytics_use_case: AnalyticsUseCase = Depends(get_analytics_use_case) +): + query = f""" + SELECT + toDate(scraped_at) as date, + count() as article_count + FROM sentiment_results + GROUP BY date + ORDER BY date ASC + """ + results = analytics_use_case.execute_raw_query(query) + if not results or "error" in results: + return {"error": "Could not fetch articles over time."} + + data = [] + for row in results.get("rows", []): + data.append({ + "date": str(row[0]), + "count": int(row[1]) + }) + return {"data": data} + +@router.get("/source-stats") +def get_source_stats(analytics_use_case: AnalyticsUseCase = Depends(get_analytics_use_case)): + query = """ + SELECT + source, + count() as article_count, + avg(sentiment_score) as avg_sentiment + FROM sentiment_results + GROUP BY source + ORDER BY article_count DESC + """ + results = analytics_use_case.execute_raw_query(query) + if not results or "error" in results: + return {"error": "Could not fetch source stats."} + + data = [] + for row in results.get("rows", []): + data.append({ + "source": row[0], + "article_count": int(row[1]), + "avg_sentiment": float(row[2]) + }) + return {"data": data} + +@router.get("/pipeline-stats") +def get_pipeline_stats( + analytics_use_case: AnalyticsUseCase = Depends(get_analytics_use_case), + vector_store: VectorStorePort = Depends(get_vector_store_port) +): + qdrant_stats = vector_store.get_collection_stats() + + query = "SELECT count() FROM sentiment_results" + ch_res = analytics_use_case.execute_raw_query(query) + + ch_count = 0 + if ch_res and not "error" in ch_res and ch_res.get("rows"): + ch_count = int(ch_res["rows"][0][0]) + + return { + "total_articles_in_vector_db": qdrant_stats.get("vectors_count", 0) if qdrant_stats else 0, + "total_sentiment_results": ch_count + } diff --git a/src/api/routes/auth.py b/src/api/routes/auth.py new file mode 100644 index 0000000000000000000000000000000000000000..8f3a5a6cd7babcf18655e34af3e8852071f6dcf8 --- /dev/null +++ b/src/api/routes/auth.py @@ -0,0 +1,56 @@ +from fastapi import APIRouter, Depends, HTTPException, status +from fastapi.security import OAuth2PasswordRequestForm +from sqlalchemy.orm import Session +from datetime import timedelta + +from src.infrastructure.database import get_db +from src.core.domain.db_models import User +from src.core.domain.schemas import TokenResponse, RefreshRequest, UserResponse +from src.core.security import ( + verify_password, create_access_token, create_refresh_token, + rotate_refresh_token, revoke_all_refresh_tokens, + get_current_user +) +from src.core.config import settings + +router = APIRouter() + + +@router.post("/login", response_model=TokenResponse) +def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)): + """Login with email + password. Returns access + refresh tokens.""" + user = db.query(User).filter(User.email == form_data.username).first() + if not user or not verify_password(form_data.password, user.hashed_password): + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect email or password") + if not user.is_active: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Account is disabled") + + access_token = create_access_token( + data={"sub": user.email, "role": user.role}, + expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + ) + refresh_token = create_refresh_token(user.id, db) + return TokenResponse(access_token=access_token, refresh_token=refresh_token) + + +@router.post("/refresh", response_model=TokenResponse) +def refresh_tokens(body: RefreshRequest, db: Session = Depends(get_db)): + """Exchange a valid refresh token for a new access + refresh token pair.""" + new_refresh, user = rotate_refresh_token(body.refresh_token, db) + access_token = create_access_token( + data={"sub": user.email, "role": user.role}, + expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + ) + return TokenResponse(access_token=access_token, refresh_token=new_refresh) + + +@router.post("/logout", status_code=status.HTTP_204_NO_CONTENT) +def logout(current_user: User = Depends(get_current_user), db: Session = Depends(get_db)): + """Revoke all refresh tokens for the current user (full logout).""" + revoke_all_refresh_tokens(current_user.id, db) + + +@router.get("/me", response_model=UserResponse) +def get_me(current_user: User = Depends(get_current_user)): + """Get the currently authenticated user's profile.""" + return current_user diff --git a/src/api/routes/interactions.py b/src/api/routes/interactions.py new file mode 100644 index 0000000000000000000000000000000000000000..00b609aaa9ba9c31c8719a57e2c1df667cd00b41 --- /dev/null +++ b/src/api/routes/interactions.py @@ -0,0 +1,127 @@ +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.orm import Session +from src.infrastructure.database import get_db +from sqlalchemy import func +from src.core.domain.schemas import FeedbackRequest, ChatSession +from src.core.domain.db_models import ChatHistory, Feedback +from src.core.security import get_current_user +from src.core.domain.db_models import User +from typing import Optional + +router = APIRouter() + +@router.get("/history/{session_id}") +def get_chat_history(session_id: str, db: Session = Depends(get_db)): + history = db.query(ChatHistory).filter( + ChatHistory.session_id == session_id + ).order_by(ChatHistory.timestamp.asc()).all() + + formatted_history = [] + for h in history: + formatted_history.append({ + "id": h.id, + "role": h.role, + "content": h.content, + "timestamp": h.timestamp, + "pinned": getattr(h, "pinned", False), + }) + + return {"session_id": session_id, "history": formatted_history} + + +@router.post("/feedback") +def submit_feedback(req: FeedbackRequest, db: Session = Depends(get_db)): + msg = db.query(ChatHistory).filter( + ChatHistory.id == req.message_id, + ChatHistory.session_id == req.session_id + ).first() + if not msg: + raise HTTPException(status_code=404, detail="Message not found in session") + + # Upsert: update existing feedback or create new + existing = db.query(Feedback).filter( + Feedback.message_id == req.message_id, + Feedback.session_id == req.session_id + ).first() + + if existing: + existing.rating = req.rating + existing.comment = req.comment + else: + feedback = Feedback( + session_id=req.session_id, + message_id=req.message_id, + rating=req.rating, + comment=req.comment + ) + db.add(feedback) + + db.commit() + return {"status": "success", "message": "Feedback recorded."} + + +@router.get("/feedback/{session_id}") +def get_session_feedback(session_id: str, db: Session = Depends(get_db)): + """Get all feedback ratings for a session (so UI can restore like/dislike state).""" + feedbacks = db.query(Feedback).filter(Feedback.session_id == session_id).all() + return {str(f.message_id): f.rating for f in feedbacks} + + +@router.post("/pin/{message_id}") +def pin_message(message_id: int, db: Session = Depends(get_db)): + """Toggle pin on a message.""" + msg = db.query(ChatHistory).filter(ChatHistory.id == message_id).first() + if not msg: + raise HTTPException(status_code=404, detail="Message not found") + # Toggle pinned — add column if missing via getattr + current = getattr(msg, "pinned", False) or False + msg.pinned = not current + db.commit() + return {"pinned": msg.pinned} + + +@router.get("/sessions") +def get_chat_sessions( + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """Retrieve sessions for the authenticated user only.""" + sessions = db.query( + ChatHistory.session_id, + func.count(ChatHistory.id).label("message_count"), + func.max(ChatHistory.timestamp).label("last_active") + ).filter( + ChatHistory.user_id == current_user.id + ).group_by(ChatHistory.session_id).order_by( + func.max(ChatHistory.timestamp).desc() + ).all() + + return [ + ChatSession( + session_id=s.session_id, + message_count=s.message_count, + last_active=s.last_active + ) + for s in sessions + ] + + +@router.delete("/sessions/{session_id}") +def delete_chat_session( + session_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """Delete session — only owner can delete.""" + # Verify ownership + owned = db.query(ChatHistory).filter( + ChatHistory.session_id == session_id, + ChatHistory.user_id == current_user.id + ).first() + if not owned: + raise HTTPException(status_code=404, detail="Session not found") + + db.query(Feedback).filter(Feedback.session_id == session_id).delete() + deleted_msgs = db.query(ChatHistory).filter(ChatHistory.session_id == session_id).delete() + db.commit() + return {"status": "success", "message": f"Deleted session {session_id} with {deleted_msgs} messages."} diff --git a/src/api/routes/news.py b/src/api/routes/news.py new file mode 100644 index 0000000000000000000000000000000000000000..50054731271ff2066a1f893c8970ecde939f683b --- /dev/null +++ b/src/api/routes/news.py @@ -0,0 +1,138 @@ +from fastapi import APIRouter, Query, HTTPException, Depends +from typing import Optional +from src.core.ports.vector_store_port import VectorStorePort +from src.core.ports.embedder_port import EmbedderPort +from src.core.use_cases.analytics_use_case import AnalyticsUseCase +from src.api.dependencies import get_vector_store_port, get_embedder_port, get_analytics_use_case +from src.core.domain.schemas import BrowseResponse, SearchResponse, NewsArticle + +router = APIRouter() + +def _dict_to_article(payload: dict, score: float = None) -> NewsArticle: + if payload is None: + payload = {} + + # Safely get metadata + metadata = payload.get("metadata") if payload else {} + if metadata is None: + metadata = {} + + # Title can be stored at top-level payload OR nested inside metadata + title = ( + payload.get("title") + or (metadata.get("title") if isinstance(metadata, dict) else None) + ) + + return NewsArticle( + doc_id=payload.get("doc_id", "unknown"), + url=payload.get("url"), + title=title, + content=payload.get("text", payload.get("content", "")), + source=payload.get("source"), + published_at=payload.get("published_at"), + score=score, + metadata=metadata if isinstance(metadata, dict) else {} + ) + +@router.get("/latest", response_model=BrowseResponse) +def get_latest_news( + limit: int = Query(10, le=50), + source: Optional[str] = None, + language: Optional[str] = None, + vector_store: VectorStorePort = Depends(get_vector_store_port) +): + """Get latest news articles sorted by publication date""" + try: + result = vector_store.browse(limit=limit, offset=0, source=source, language=language) + + # Convert Qdrant points to articles and sort by published_at + articles = [] + for p in result["articles"]: + article = _dict_to_article(p.payload or {}, getattr(p, "score", None)) + articles.append(article) + + # Sort by published_at descending (latest first) + articles.sort(key=lambda x: x.published_at or "", reverse=True) + + # Ensure next_offset is an integer or None + next_offset = result.get("next_offset") + if next_offset is not None and not isinstance(next_offset, int): + next_offset = None # If it's not an int, set to None + + return BrowseResponse( + total_returned=len(articles), + articles=articles, + next_offset=next_offset + ) + except Exception as e: + # Return empty response on error instead of 500 + print(f"Error fetching news: {e}") + return BrowseResponse( + total_returned=0, + articles=[], + next_offset=None + ) + +@router.get("/browse", response_model=BrowseResponse) +def browse_news( + limit: int = Query(20, le=100), + offset: int = 0, + source: Optional[str] = None, + language: Optional[str] = None, + vector_store: VectorStorePort = Depends(get_vector_store_port) +): + result = vector_store.browse(limit=limit, offset=offset, source=source, language=language) + + # Qdrant scroll returns points with .payload + articles = [] + for p in result["articles"]: + articles.append(_dict_to_article(p.payload or {}, getattr(p, "score", None))) + + return BrowseResponse( + total_returned=len(articles), + articles=articles, + next_offset=result["next_offset"] + ) + +@router.get("/search", response_model=SearchResponse) +def search_news( + q: str = Query(..., min_length=1), + top_k: int = Query(10, le=50), + source: Optional[str] = None, + language: Optional[str] = None, + embedder: EmbedderPort = Depends(get_embedder_port), + vector_store: VectorStorePort = Depends(get_vector_store_port) +): + try: + query_vector = embedder.encode_query(q) + results = vector_store.search( + query_vectors=query_vector, + limit=top_k, + source_filter=source, + language_filter=language + ) + # VectorStorePort.search returns SearchResult objects (content, metadata, score, doc_id) + articles = [] + for r in results: + articles.append(_dict_to_article(r.metadata, r.score)) + + return SearchResponse(results=articles) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +@router.get("/sources") +def get_sources(analytics: AnalyticsUseCase = Depends(get_analytics_use_case)): + query = "SELECT source, count() as article_count FROM sentiment_results GROUP BY source ORDER BY article_count DESC" + res = analytics.execute_raw_query(query) + if res and res.get("rows"): + sources = [row[0] for row in res["rows"] if row[0]] + return {"sources": sources} + return {"sources": []} + +@router.get("/{doc_id}", response_model=NewsArticle) +def get_news_article(doc_id: str, vector_store: VectorStorePort = Depends(get_vector_store_port)): + result = vector_store.get_by_doc_id(doc_id) + if not result: + raise HTTPException(status_code=404, detail="Article not found") + + return _dict_to_article(result.metadata, result.score) diff --git a/src/api/routes/rag.py b/src/api/routes/rag.py new file mode 100644 index 0000000000000000000000000000000000000000..f5d73f3d917da53c9e7bbc07dfca49d21e2f5520 --- /dev/null +++ b/src/api/routes/rag.py @@ -0,0 +1,95 @@ +import os +import uuid +from typing import Optional +from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import StreamingResponse +from fastapi.security import OAuth2PasswordBearer +from src.core.domain.schemas import ChatRequest, ChatResponse, SearchResponse +from src.core.use_cases.search_use_case import SearchUseCase +from src.core.use_cases.rag_chat_use_case import RagChatUseCase +from src.core.use_cases.agent_router_use_case import AgentRouterUseCase +from src.api.dependencies import get_search_use_case, get_rag_chat_use_case, get_agent_router_use_case +from src.core.security import get_current_user +from src.core.domain.db_models import User +from jose import jwt, JWTError +from src.core.config import settings +from src.infrastructure.database import get_db +from sqlalchemy.orm import Session + +router = APIRouter() + +# Optional bearer — doesn't raise if token is missing +_optional_bearer = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/auth/login", auto_error=False) + +def get_optional_user( + token: Optional[str] = Depends(_optional_bearer), + db: Session = Depends(get_db) +) -> Optional[User]: + """Returns the authenticated user or None for guests.""" + if not token: + return None + try: + payload = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"]) + if payload.get("type") != "access": + return None + email = payload.get("sub") + if not email: + return None + user = db.query(User).filter(User.email == email).first() + return user if user and user.is_active else None + except JWTError: + return None + +@router.post("/search") +def direct_search( + request: ChatRequest, + search_use_case: SearchUseCase = Depends(get_search_use_case), + current_user: User = Depends(get_current_user) +): + """Lightning-fast hybrid search bypassing the LLM.""" + try: + results = search_use_case.execute( + query=request.query, + limit=request.top_k, + source_filter=request.source_filter, + language_filter=request.language_filter, + days_back=getattr(request, 'days_back', None) + ) + hits = [{"content": r.content, "metadata": r.metadata, "score": r.score, "doc_id": r.doc_id} for r in results] + return {"results": hits, "query": request.query} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +@router.post("/chat/test", response_model=ChatResponse) +def chat_test( + request: ChatRequest, + agent_router_use_case: AgentRouterUseCase = Depends(get_agent_router_use_case) +): + """Test RAG chat endpoint without authentication for debugging""" + + # Get endpoint from environment variable or use default + os.getenv("RAG_ENDPOINT", "/rag/chat/test") + + result = agent_router_use_case.execute_chat(request) + return result + +@router.post("/chat/stream") +async def chat_with_rag_stream( + request: ChatRequest, + router_use_case: AgentRouterUseCase = Depends(get_agent_router_use_case), + current_user: Optional[User] = Depends(get_optional_user) +): + """Streaming RAG chat. Works for both authenticated users and guests.""" + try: + if current_user is None and not request.session_id: + request.session_id = f"guest_{uuid.uuid4().hex[:12]}" + + user_id = current_user.id if current_user else None + + return StreamingResponse( + router_use_case.execute_stream(request, is_guest=(current_user is None), user_id=user_id), + media_type="text/event-stream" + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + diff --git a/src/core/__pycache__/config.cpython-313.pyc b/src/core/__pycache__/config.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..debf6f3a1c44825137527dcb21be33be034e7a08 Binary files /dev/null and b/src/core/__pycache__/config.cpython-313.pyc differ diff --git a/src/core/__pycache__/security.cpython-313.pyc b/src/core/__pycache__/security.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7cda5e49d607dfd24299208e7104ed26f386e04f Binary files /dev/null and b/src/core/__pycache__/security.cpython-313.pyc differ diff --git a/src/core/config.py b/src/core/config.py new file mode 100644 index 0000000000000000000000000000000000000000..db5e27eeff137a37588495fb2f4656cb1f221648 --- /dev/null +++ b/src/core/config.py @@ -0,0 +1,92 @@ +import os +from pydantic_settings import BaseSettings, SettingsConfigDict + +class Settings(BaseSettings): + model_config = SettingsConfigDict(env_file=".env", env_file_encoding='utf-8', extra='ignore') + + PROJECT_NAME: str = "RAG API Service" + API_V1_STR: str = "/api/v1" + + QDRANT_HOST: str = os.getenv("QDRANT_HOST", "localhost") + QDRANT_PORT: int = int(os.getenv("QDRANT_PORT", "6333")) + QDRANT_URL: str = os.getenv("QDRANT_URL", "") # Cloud URL (overrides host/port) + QDRANT_API_KEY: str = os.getenv("QDRANT_API_KEY", "") # Cloud API Key + QDRANT_COLLECTION: str = os.getenv("QDRANT_COLLECTION", "news_articles") + + CLICKHOUSE_HOST: str = os.getenv("CLICKHOUSE_HOST", "localhost") + CLICKHOUSE_PORT: int = int(os.getenv("CLICKHOUSE_PORT", "8123")) + CLICKHOUSE_USER: str = os.getenv("CLICKHOUSE_USER", "default") + CLICKHOUSE_PASSWORD: str = os.getenv("CLICKHOUSE_PASSWORD", "") + CLICKHOUSE_DB: str = os.getenv("CLICKHOUSE_DB", "default") + CLICKHOUSE_SECURE: bool = os.getenv("CLICKHOUSE_SECURE", "false").lower() == "true" + + + # Embedding Model Config + EMBEDDING_MODEL: str = os.getenv("EMBEDDING_MODEL", "BAAI/bge-m3") + VECTOR_SIZE: int = int(os.getenv("VECTOR_SIZE", "1024")) + RERANKER_MODEL: str = os.getenv("RERANKER_MODEL", "BAAI/bge-reranker-v2-m3") + + # PostgreSQL / Neon Config + DATABASE_URL: str = os.getenv("DATABASE_URL", "") # Full Neon URL (overrides individual fields) + POSTGRES_USER: str = os.getenv("POSTGRES_USER", "postgres") + POSTGRES_PASSWORD: str = os.getenv("POSTGRES_PASSWORD", "postgres") + POSTGRES_SERVER: str = os.getenv("POSTGRES_SERVER", "localhost") + POSTGRES_PORT: str = os.getenv("POSTGRES_PORT", "5432") + POSTGRES_DB: str = os.getenv("POSTGRES_DB", "rag_interactions") + + @property + def SQLALCHEMY_DATABASE_URI(self) -> str: + if self.DATABASE_URL: + return self.DATABASE_URL + return f"postgresql://{self.POSTGRES_USER}:{self.POSTGRES_PASSWORD}@{self.POSTGRES_SERVER}:{self.POSTGRES_PORT}/{self.POSTGRES_DB}" + + # LLM Settings + # Supported providers: "groq", "gemini", "together", "openai", "ollama" + LLM_PROVIDER: str = os.getenv("LLM_PROVIDER", "groq") + + # Groq — free, 200+ tok/s, llama-3.3-70b-versatile | https://console.groq.com + OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "") + GROQ_API_KEY: str = os.getenv("GROQ_API_KEY", "") + GROQ_MODEL: str = os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile") + + # Google Gemini — free tier (15 RPM / 1M TPM) | https://aistudio.google.com/apikey + GEMINI_API_KEY: str = os.getenv("GEMINI_API_KEY", "") + GEMINI_MODEL: str = os.getenv("GEMINI_MODEL", "gemini-1.5-flash") + + # Together AI — free $25 credit | https://api.together.ai + TOGETHER_API_KEY: str = os.getenv("TOGETHER_API_KEY", "") + TOGETHER_MODEL: str = os.getenv("TOGETHER_MODEL", "meta-llama/Llama-3.3-70B-Instruct-Turbo") + + # HuggingFace Inference API — free with HF token | https://huggingface.co/settings/tokens + HF_TOKEN: str = os.getenv("HF_TOKEN", "") + HF_MODEL: str = os.getenv("HF_MODEL", "meta-llama/Llama-3.1-8B-Instruct") + + # Ollama — local inference + OLLAMA_HOST: str = os.getenv("OLLAMA_HOST", "http://localhost:11434") + OLLAMA_MODEL: str = os.getenv("OLLAMA_MODEL", "llama3.2") + + # Redis Settings + REDIS_URL: str = os.getenv("REDIS_URL", "") # Full URL (Upstash) - overrides host/port + REDIS_HOST: str = os.getenv("REDIS_HOST", "localhost") + REDIS_PORT: int = int(os.getenv("REDIS_PORT", "6380")) + REDIS_DB: int = int(os.getenv("REDIS_DB", "0")) + REDIS_PASSWORD: str = os.getenv("REDIS_PASSWORD", "") + + # Hybrid Search Settings + ENABLE_HYBRID_SEARCH: bool = os.getenv("ENABLE_HYBRID_SEARCH", "true").lower() == "true" + LIVE_SEARCH_TIMEOUT: float = float(os.getenv("LIVE_SEARCH_TIMEOUT", "2.0")) + LIVE_SEARCH_MAX_RESULTS: int = int(os.getenv("LIVE_SEARCH_MAX_RESULTS", "5")) + LIVE_SEARCH_WEIGHT: float = float(os.getenv("LIVE_SEARCH_WEIGHT", "0.5")) + DB_SEARCH_WEIGHT: float = float(os.getenv("DB_SEARCH_WEIGHT", "0.5")) + + # Cache Settings (TTL in seconds) + CACHE_RESPONSE_TTL: int = int(os.getenv("CACHE_RESPONSE_TTL", "300")) # 5 minutes + CACHE_LIVE_TTL: int = int(os.getenv("CACHE_LIVE_TTL", "600")) # 10 minutes + CACHE_TRANSLATION_TTL: int = int(os.getenv("CACHE_TRANSLATION_TTL", "3600")) # 1 hour + CACHE_INTENT_TTL: int = int(os.getenv("CACHE_INTENT_TTL", "3600")) # 1 hour + + # Security Settings + SECRET_KEY: str = os.getenv("SECRET_KEY", "a_very_secret_key_change_me_in_production") + ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "60")) + +settings = Settings() diff --git a/src/core/domain/__pycache__/db_models.cpython-313.pyc b/src/core/domain/__pycache__/db_models.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d2428de678529410467a1beb32e6b0e9286ad95 Binary files /dev/null and b/src/core/domain/__pycache__/db_models.cpython-313.pyc differ diff --git a/src/core/domain/__pycache__/schemas.cpython-313.pyc b/src/core/domain/__pycache__/schemas.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4485473ef0a0fb507931b00e084be1d2b961827e Binary files /dev/null and b/src/core/domain/__pycache__/schemas.cpython-313.pyc differ diff --git a/src/core/domain/db_models.py b/src/core/domain/db_models.py new file mode 100644 index 0000000000000000000000000000000000000000..f355428c9290f22a03ddd879444ce33b87ce03da --- /dev/null +++ b/src/core/domain/db_models.py @@ -0,0 +1,61 @@ +from sqlalchemy import Column, Integer, String, DateTime, Text, ForeignKey, CheckConstraint, Boolean, Enum +from sqlalchemy.orm import declarative_base, relationship +from datetime import datetime +import enum + +Base = declarative_base() + +class UserRole(str, enum.Enum): + super_admin = "super_admin" + user = "user" + +class User(Base): + __tablename__ = "users" + + id = Column(Integer, primary_key=True, index=True) + username = Column(String, unique=True, index=True, nullable=False) + email = Column(String, unique=True, index=True, nullable=False) + hashed_password = Column(String, nullable=False) + role = Column(Enum(UserRole), default=UserRole.user, nullable=False) + is_active = Column(Boolean, default=True, nullable=False) + full_name = Column(String, nullable=True) + created_at = Column(DateTime, default=datetime.utcnow) + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + + refresh_tokens = relationship("RefreshToken", back_populates="user", cascade="all, delete-orphan") + + +class RefreshToken(Base): + __tablename__ = "refresh_tokens" + + id = Column(Integer, primary_key=True, index=True) + token = Column(String, unique=True, index=True, nullable=False) + user_id = Column(Integer, ForeignKey("users.id"), nullable=False) + expires_at = Column(DateTime, nullable=False) + revoked = Column(Boolean, default=False, nullable=False) + created_at = Column(DateTime, default=datetime.utcnow) + + user = relationship("User", back_populates="refresh_tokens") + +class ChatHistory(Base): + __tablename__ = "chat_history" + + id = Column(Integer, primary_key=True, index=True) + session_id = Column(String, index=True, nullable=False) + user_id = Column(Integer, ForeignKey("users.id"), nullable=True) + role = Column(String, nullable=False) + content = Column(Text, nullable=False) + retrieved_doc_ids = Column(Text, nullable=True) + pinned = Column(Boolean, default=False, nullable=False) + timestamp = Column(DateTime, default=datetime.utcnow) + +class Feedback(Base): + __tablename__ = "feedback" + + id = Column(Integer, primary_key=True, index=True) + session_id = Column(String, index=True, nullable=False) + message_id = Column(Integer, nullable=False) + user_id = Column(Integer, ForeignKey("users.id"), nullable=True) + rating = Column(Integer, CheckConstraint('rating IN (1, -1)')) + comment = Column(Text, nullable=True) + timestamp = Column(DateTime, default=datetime.utcnow) diff --git a/src/core/domain/schemas.py b/src/core/domain/schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..6b13bd1da17bf477cb6d85851a351998693ae205 --- /dev/null +++ b/src/core/domain/schemas.py @@ -0,0 +1,107 @@ +from pydantic import BaseModel +from typing import List, Optional, Any +from datetime import datetime + +class ChatRequest(BaseModel): + query: str + session_id: Optional[str] = None + top_k: int = 7 # increased from 5 — gives multilingual diversity room + source_filter: Optional[str] = None + language_filter: Optional[str] = None + +class SourceDocument(BaseModel): + content: str + metadata: dict + score: float + +class ChatResponse(BaseModel): + answer: str + sources: List[SourceDocument] + session_id: str = "anonymous" + +class FeedbackRequest(BaseModel): + session_id: str + message_id: int + rating: int + comment: Optional[str] = None + +class SentimentData(BaseModel): + entity: str + avg_sentiment: float + mention_count: int + +class TrendData(BaseModel): + topic: str + volume: int + momentum: float + +class NewsArticle(BaseModel): + doc_id: str + url: Optional[str] = None + title: Optional[str] = None + content: str + source: Optional[str] = None + published_at: Optional[str] = None + score: Optional[float] = None + metadata: dict = {} + +class BrowseResponse(BaseModel): + total_returned: int + articles: List[NewsArticle] + next_offset: Optional[int] = None + +class SearchResponse(BaseModel): + results: List[NewsArticle] + +class SourceStat(BaseModel): + source: str + article_count: int + avg_sentiment: float + +class PipelineStats(BaseModel): + total_articles_in_vector_db: int + total_sentiment_results: int + +class ChatSession(BaseModel): + session_id: str + message_count: int + last_active: Optional[datetime] = None + +class UserCreate(BaseModel): + username: str + email: str + password: str + full_name: Optional[str] = None + +class UserResponse(BaseModel): + id: int + username: str + email: str + full_name: Optional[str] = None + role: str + is_active: bool + created_at: Optional[datetime] = None + + class Config: + from_attributes = True + +class UserUpdate(BaseModel): + username: Optional[str] = None + full_name: Optional[str] = None + email: Optional[str] = None + +class PasswordChange(BaseModel): + current_password: str + new_password: str + +class TokenResponse(BaseModel): + access_token: str + refresh_token: str + token_type: str = "bearer" + +class RefreshRequest(BaseModel): + refresh_token: str + +class AdminUserUpdate(BaseModel): + is_active: Optional[bool] = None + role: Optional[str] = None diff --git a/src/core/orchestrator/__init__.py b/src/core/orchestrator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d155a6b4809cc06005058716a9d0c1d4bbceecdf --- /dev/null +++ b/src/core/orchestrator/__init__.py @@ -0,0 +1 @@ +# Query Orchestrator Module diff --git a/src/core/orchestrator/query_orchestrator.py b/src/core/orchestrator/query_orchestrator.py new file mode 100644 index 0000000000000000000000000000000000000000..e40df0cbf57ceaf819de62a847ed20563469fe6e --- /dev/null +++ b/src/core/orchestrator/query_orchestrator.py @@ -0,0 +1,434 @@ +""" +Query Orchestrator + +Orchestrates hybrid search between live sources (DuckDuckGo) and database (Qdrant). +Integrates seamlessly with the existing multilingual RAG pipeline. + +Key Features: +- Intelligent search strategy selection (live, DB, or hybrid) +- Uses production-grade intent classification (v2) +- Parallel execution of live and database searches +- Integration with existing 6-language multilingual pipeline +- Graceful fallbacks when live search fails +- Cache-aware execution +""" + +import logging +import asyncio +from typing import Dict, Any, List, Optional, Tuple +from datetime import datetime + +logger = logging.getLogger(__name__) + + +class SearchStrategy: + """ + Search strategy configuration. + + Determines which sources to use and how to weight them. + """ + + def __init__( + self, + use_live: bool = True, + use_db: bool = True, + live_weight: float = 0.5, + db_weight: float = 0.5, + reason: str = "", + intent_result: Optional[Any] = None # IntentResult from v2 classifier + ): + self.use_live = use_live + self.use_db = use_db + self.live_weight = live_weight + self.db_weight = db_weight + self.reason = reason + self.intent_result = intent_result # Store full intent result for debugging + + def __repr__(self): + return ( + f"SearchStrategy(live={self.use_live}, db={self.use_db}, " + f"weights={self.live_weight:.1f}/{self.db_weight:.1f}, " + f"reason='{self.reason}')" + ) + + +class QueryOrchestrator: + """ + Orchestrates hybrid search between live sources and database. + + Integrates with existing multilingual pipeline: + - Reuses dense vectors (computed once from English) + - Reuses sparse vectors (batched for 6 languages) + - Adds live search in parallel with DB search + - Merges results for unified ranking + """ + + # Temporal keywords that indicate need for live search + TEMPORAL_KEYWORDS = [ + "today", "now", "latest", "breaking", "just", "current", + "this morning", "this afternoon", "this evening", "tonight", + "yesterday", "recent", "recently", "new", "fresh" + ] + + # Historical keywords that indicate DB-only search + HISTORICAL_KEYWORDS = [ + "history", "historical", "background", "context", "past", + "analysis", "overview", "summary", "explain", "what is", + "who is", "tell me about" + ] + + def __init__( + self, + live_search_adapter, + enable_hybrid: bool = True, + default_live_weight: float = 0.5, + default_db_weight: float = 0.5 + ): + """ + Initialize query orchestrator. + + Args: + live_search_adapter: DuckDuckGo adapter instance + enable_hybrid: Global flag to enable/disable hybrid search + default_live_weight: Default weight for live results + default_db_weight: Default weight for database results + """ + self.live_search = live_search_adapter + self.enable_hybrid = enable_hybrid + self.default_live_weight = default_live_weight + self.default_db_weight = default_db_weight + + def decide_search_strategy(self, query: str, intent: str = "NEWS", intent_result: Optional[Any] = None) -> SearchStrategy: + """ + Decide which search sources to use based on query characteristics. + + Now supports production-grade intent classification with multi-class intents: + - NEWS_TEMPORAL → prioritize live search + - NEWS_HISTORICAL → use DB only + - NEWS_GENERAL → balanced hybrid + - OTHER → skip search + + Args: + query: User query + intent: Simple intent ("NEWS" or "OTHER") for backward compatibility + intent_result: Full IntentResult from v2 classifier (if available) + + Returns: + SearchStrategy object with source selection and weights + """ + # If hybrid search is disabled globally, use DB only + if not self.enable_hybrid: + return SearchStrategy( + use_live=False, + use_db=True, + live_weight=0.0, + db_weight=1.0, + reason="Hybrid search disabled", + intent_result=intent_result + ) + + # If live search is unavailable, use DB only + if not self.live_search.is_available(): + return SearchStrategy( + use_live=False, + use_db=True, + live_weight=0.0, + db_weight=1.0, + reason="Live search unavailable", + intent_result=intent_result + ) + + # ── Use v2 Intent Result if available (production-grade) ────────────── + + if intent_result and hasattr(intent_result, 'intent'): + detailed_intent = intent_result.intent + confidence = intent_result.confidence + + logger.info( + f"Using v2 intent: {detailed_intent} " + f"(confidence={confidence:.2f}, method={intent_result.method})" + ) + + # OTHER → skip search + if detailed_intent == "OTHER": + return SearchStrategy( + use_live=False, + use_db=False, + live_weight=0.0, + db_weight=0.0, + reason=f"Small talk (confidence={confidence:.2f})", + intent_result=intent_result + ) + + # NEWS_TEMPORAL → prioritize live search + elif detailed_intent == "NEWS_TEMPORAL": + # High confidence → strong live bias + if confidence >= 0.80: + return SearchStrategy( + use_live=True, + use_db=True, + live_weight=0.8, + db_weight=0.2, + reason=f"Temporal query (high confidence={confidence:.2f})", + intent_result=intent_result + ) + # Medium confidence → moderate live bias + else: + return SearchStrategy( + use_live=True, + use_db=True, + live_weight=0.7, + db_weight=0.3, + reason=f"Temporal query (medium confidence={confidence:.2f})", + intent_result=intent_result + ) + + # NEWS_HISTORICAL → use DB only + elif detailed_intent == "NEWS_HISTORICAL": + return SearchStrategy( + use_live=False, + use_db=True, + live_weight=0.0, + db_weight=1.0, + reason=f"Historical query (confidence={confidence:.2f})", + intent_result=intent_result + ) + + # NEWS_GENERAL → balanced hybrid + elif detailed_intent == "NEWS_GENERAL": + return SearchStrategy( + use_live=True, + use_db=True, + live_weight=self.default_live_weight, + db_weight=self.default_db_weight, + reason=f"General news (confidence={confidence:.2f})", + intent_result=intent_result + ) + + # ── Fallback to v1 logic (backward compatibility) ───────────────────── + + # If intent is OTHER (small talk), no search needed + if intent == "OTHER": + return SearchStrategy( + use_live=False, + use_db=False, + live_weight=0.0, + db_weight=0.0, + reason="Small talk - no search needed (v1 fallback)", + intent_result=intent_result + ) + + query_lower = query.lower() + + # Check for temporal keywords → prioritize live search + has_temporal = any(kw in query_lower for kw in self.TEMPORAL_KEYWORDS) + + # Check for historical keywords → prioritize database + has_historical = any(kw in query_lower for kw in self.HISTORICAL_KEYWORDS) + + if has_temporal and not has_historical: + # Temporal query → prioritize live search + return SearchStrategy( + use_live=True, + use_db=True, + live_weight=0.7, + db_weight=0.3, + reason="Temporal query - prioritize live (v1 fallback)", + intent_result=intent_result + ) + + elif has_historical and not has_temporal: + # Historical query → use database only + return SearchStrategy( + use_live=False, + use_db=True, + live_weight=0.0, + db_weight=1.0, + reason="Historical query - database only (v1 fallback)", + intent_result=intent_result + ) + + else: + # Balanced hybrid search + return SearchStrategy( + use_live=True, + use_db=True, + live_weight=self.default_live_weight, + db_weight=self.default_db_weight, + reason="Balanced hybrid search (v1 fallback)", + intent_result=intent_result + ) + + async def execute_hybrid_search( + self, + query: str, + english_query: str, + strategy: SearchStrategy, + # Database search components (from existing pipeline) + embedder, + vector_store, + translated_queries: Dict[str, str], + top_k: int = 7 + ) -> Tuple[List[Any], List[Dict[str, Any]]]: + """ + Execute hybrid search with parallel live and database searches. + + Integrates with existing multilingual pipeline: + - Reuses dense vector (computed once from English) + - Reuses sparse vectors (batched for 6 languages) + - Adds live search in parallel + + Args: + query: Original user query + english_query: English translation + strategy: Search strategy from decide_search_strategy() + embedder: BGE-M3 embedder adapter + vector_store: Qdrant adapter + translated_queries: Dict of {lang: translated_query} + top_k: Results per language for DB search + + Returns: + Tuple of (db_results, live_results) + """ + tasks = [] + + # 1. Database search (if enabled) + if strategy.use_db: + db_task = self._execute_db_search( + english_query=english_query, + translated_queries=translated_queries, + embedder=embedder, + vector_store=vector_store, + top_k=top_k + ) + tasks.append(db_task) + + # 2. Live search (if enabled) + if strategy.use_live: + live_task = self._execute_live_search(english_query) + tasks.append(live_task) + + # Execute all searches in parallel + if not tasks: + return [], [] + + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Extract results with error handling + db_results = [] + live_results = [] + + if strategy.use_db: + if isinstance(results[0], Exception): + logger.error(f"Database search failed: {results[0]}") + else: + db_results = results[0] + + if strategy.use_live: + result_idx = 1 if strategy.use_db else 0 + if isinstance(results[result_idx], Exception): + logger.error(f"Live search failed: {results[result_idx]}") + else: + live_results = results[result_idx] + + logger.info( + f"Hybrid search completed: {len(db_results)} DB + {len(live_results)} live results" + ) + + return db_results, live_results + + async def _execute_db_search( + self, + english_query: str, + translated_queries: Dict[str, str], + embedder, + vector_store, + top_k: int + ) -> List[Any]: + """ + Execute multilingual database search (existing pipeline). + + Leverages existing optimizations: + - Dense vector computed once (language-agnostic) + - Sparse vectors batched (6 languages in 1 pass) + - Parallel Qdrant queries (6 lanes) + - Deduplication by doc_id + + Args: + english_query: English query + translated_queries: Dict of {lang: translated_query} + embedder: BGE-M3 embedder + vector_store: Qdrant adapter + top_k: Results per language + + Returns: + List of deduplicated SearchResult objects + """ + try: + # 1. Compute dense vector once (language-agnostic) + dense_embedding = embedder.encode_query(english_query) + dense_vec = dense_embedding.get("dense") + + # 2. Batch sparse encoding for all 6 languages (existing optimization) + languages = ["en", "ar", "am", "so", "sw", "fr"] + sparse_queries = [translated_queries.get(lang, english_query) for lang in languages] + sparse_embeddings = embedder.encode_sparse_batch(sparse_queries) + + # 3. Parallel search across 6 languages (existing pattern) + search_tasks = [] + for i, lang in enumerate(languages): + sparse_vec = sparse_embeddings[i].get("sparse") + task = vector_store.search_with_vectors( + dense_vec=dense_vec, + sparse_vec=sparse_vec, + limit=top_k, + language_filter=lang + ) + search_tasks.append(task) + + # Execute all 6 lanes in parallel + lane_results = await asyncio.gather(*search_tasks, return_exceptions=True) + + # 4. Flatten and deduplicate by doc_id (existing logic) + all_docs = [] + for lane in lane_results: + if not isinstance(lane, Exception): + all_docs.extend(lane) + + # Deduplicate: keep highest-scoring version of each doc + seen = {} + for doc in all_docs: + doc_id = doc.metadata.get("doc_id") + if doc_id: + if doc_id not in seen or doc.score > seen[doc_id].score: + seen[doc_id] = doc + else: + # No doc_id, keep it + seen[id(doc)] = doc + + unique_docs = list(seen.values()) + logger.info(f"DB search: {len(all_docs)} total → {len(unique_docs)} unique") + + return unique_docs + + except Exception as e: + logger.error(f"Database search error: {e}") + raise + + async def _execute_live_search(self, query: str) -> List[Dict[str, Any]]: + """ + Execute live search with timeout and error handling. + + Args: + query: Search query (English) + + Returns: + List of normalized live search results + """ + try: + results = await self.live_search.search(query) + logger.info(f"Live search: {len(results)} results") + return results + except Exception as e: + logger.error(f"Live search error: {e}") + raise diff --git a/src/core/ports/__pycache__/analytics_db_port.cpython-313.pyc b/src/core/ports/__pycache__/analytics_db_port.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..890eca0de85177741c9c06930c5c592951c6a6c0 Binary files /dev/null and b/src/core/ports/__pycache__/analytics_db_port.cpython-313.pyc differ diff --git a/src/core/ports/__pycache__/cache_port.cpython-313.pyc b/src/core/ports/__pycache__/cache_port.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..051afc19dc9be9cd8ab5e72ae7ced1e74aa9232c Binary files /dev/null and b/src/core/ports/__pycache__/cache_port.cpython-313.pyc differ diff --git a/src/core/ports/__pycache__/chat_history_port.cpython-313.pyc b/src/core/ports/__pycache__/chat_history_port.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e95bc921a71b86d0bbce41c1562d3eef7c1091e Binary files /dev/null and b/src/core/ports/__pycache__/chat_history_port.cpython-313.pyc differ diff --git a/src/core/ports/__pycache__/embedder_port.cpython-313.pyc b/src/core/ports/__pycache__/embedder_port.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7703bbc9577715aa0680a0bdd960fb9dffa9d191 Binary files /dev/null and b/src/core/ports/__pycache__/embedder_port.cpython-313.pyc differ diff --git a/src/core/ports/__pycache__/llm_port.cpython-313.pyc b/src/core/ports/__pycache__/llm_port.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..779066847d440abad8443dc1158fd3f3cad983d7 Binary files /dev/null and b/src/core/ports/__pycache__/llm_port.cpython-313.pyc differ diff --git a/src/core/ports/__pycache__/reranker_port.cpython-313.pyc b/src/core/ports/__pycache__/reranker_port.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b85741ff467418da5dbce8907828aca8eb3f60f5 Binary files /dev/null and b/src/core/ports/__pycache__/reranker_port.cpython-313.pyc differ diff --git a/src/core/ports/__pycache__/vector_store_port.cpython-313.pyc b/src/core/ports/__pycache__/vector_store_port.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e70d5547ceb1e8b128293ad7be220a2e1907b45 Binary files /dev/null and b/src/core/ports/__pycache__/vector_store_port.cpython-313.pyc differ diff --git a/src/core/ports/analytics_db_port.py b/src/core/ports/analytics_db_port.py new file mode 100644 index 0000000000000000000000000000000000000000..993b0fb76feb3414a1c56b3822ef07bae72f0d46 --- /dev/null +++ b/src/core/ports/analytics_db_port.py @@ -0,0 +1,21 @@ +import abc +from typing import Dict, Any, List, Optional + +class AnalyticsDbPort(abc.ABC): + """ + Primary interface for connecting to the Data Warehouse / OLAP Database (e.g., ClickHouse). + """ + + @abc.abstractmethod + def query(self, sql: str, parameters: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + """ + Executes a raw SQL query against the analytics database. + """ + pass + + @abc.abstractmethod + def get_trends(self, days: int = 3, limit: int = 10) -> Dict[str, Any]: + """ + Retrieves trending topics / sentiment aggregations. + """ + pass diff --git a/src/core/ports/cache_port.py b/src/core/ports/cache_port.py new file mode 100644 index 0000000000000000000000000000000000000000..5bc28147f0d0a73de5152db70bb37709d7fd9bd2 --- /dev/null +++ b/src/core/ports/cache_port.py @@ -0,0 +1,25 @@ +import abc +from typing import Optional, Dict, Any + +class CachePort(abc.ABC): + """ + Interface for semantic caching (e.g., Redis). + """ + + @abc.abstractmethod + def get(self, key: str) -> Optional[Any]: + """Retrieve a value by its literal key.""" + pass + + @abc.abstractmethod + def set(self, key: str, value: Any, expiration: int = 3600) -> bool: + """Store a value with an expiration in seconds.""" + pass + + @abc.abstractmethod + def search_similar(self, query_vector: list, threshold: float = 0.95) -> Optional[Dict[str, Any]]: + """ + Search the cache for a highly semantically similar previous query, relying + on an external vector store or Redis Vector Search implementation. + """ + pass diff --git a/src/core/ports/chat_history_port.py b/src/core/ports/chat_history_port.py new file mode 100644 index 0000000000000000000000000000000000000000..eb874dfb20000da9cb258ca30d779aae1b334ac7 --- /dev/null +++ b/src/core/ports/chat_history_port.py @@ -0,0 +1,24 @@ +import abc +from typing import List, Dict, Any + +class ChatMessage(abc.ABC): + role: str + content: str + +class ChatHistoryPort(abc.ABC): + """ + Primary interface for persisting chat interactions (e.g., PostgreSQL). + """ + + @abc.abstractmethod + def get_history(self, session_id: str, limit: int = 6) -> List[ChatMessage]: + """ + Retrieves recent chat history for a session. + """ + pass + + @abc.abstractmethod + def save_interaction(self, session_id: str, user_query: str, ai_response: str, + retrieved_doc_ids: List[str], user_id: int = None): + """Persists a user query and the AI's generated response.""" + pass diff --git a/src/core/ports/embedder_port.py b/src/core/ports/embedder_port.py new file mode 100644 index 0000000000000000000000000000000000000000..3465d6d39a212ba7eddc070eba78147ab98a05b4 --- /dev/null +++ b/src/core/ports/embedder_port.py @@ -0,0 +1,43 @@ +import abc +from typing import Dict, Any, List + +class EmbedderPort(abc.ABC): + """ + Primary interface for converting text into vectors. + """ + + @abc.abstractmethod + def encode_query(self, text: str) -> Dict[str, Any]: + """ + Encodes a query string into dense and sparse vectors. + + Returns: + Dict containing 'dense' (List[float]) and 'sparse' (Dict) representations. + """ + pass + + def encode_sparse_only(self, text: str) -> Dict[str, Any]: + """ + Encodes a query string into sparse (BM25/lexical) vectors only. + Skips dense computation — used for per-language sparse queries when + the dense vector has already been computed from the English query. + + Default implementation falls back to encode_query and discards dense. + Override in subclasses for efficiency. + + Returns: + Dict with 'sparse' key only: {"indices": [...], "values": [...]} + """ + result = self.encode_query(text) + return {"sparse": result.get("sparse")} + + def encode_sparse_batch(self, texts: List[str]) -> List[Dict[str, Any]]: + """ + Encode multiple texts into sparse vectors in a single forward pass. + Default implementation calls encode_sparse_only() per text. + Override in subclasses for batched efficiency. + + Returns: + List of dicts with 'sparse' key, in same order as texts. + """ + return [self.encode_sparse_only(t) for t in texts] diff --git a/src/core/ports/llm_port.py b/src/core/ports/llm_port.py new file mode 100644 index 0000000000000000000000000000000000000000..967cce0bd8bc4c5cc8e9ec6109bee84c2d0786ee --- /dev/null +++ b/src/core/ports/llm_port.py @@ -0,0 +1,21 @@ +import abc +from typing import AsyncGenerator + +class LlmPort(abc.ABC): + """ + Primary interface for Language Model generation. + """ + + @abc.abstractmethod + def generate(self, prompt: str) -> str: + """ + Generates a synchronous completion based on the prompt. + """ + pass + + @abc.abstractmethod + async def generate_stream(self, prompt: str) -> AsyncGenerator[str, None]: + """ + Generates an asynchronous stream of tokens based on the prompt. + """ + pass diff --git a/src/core/ports/reranker_port.py b/src/core/ports/reranker_port.py new file mode 100644 index 0000000000000000000000000000000000000000..ad8b00ef5f68c64e9ba52c63fa965c45e89c3f02 --- /dev/null +++ b/src/core/ports/reranker_port.py @@ -0,0 +1,14 @@ +import abc +from typing import List, Dict, Any + +class RerankerPort(abc.ABC): + """ + Primary interface for re-ranking search results based on a query. + """ + + @abc.abstractmethod + def rerank(self, query: str, docs: List[Dict[str, Any]], top_n: int = 5) -> List[Dict[str, Any]]: + """ + Re-ranks a list of documents against the initial query. + """ + pass diff --git a/src/core/ports/vector_store_port.py b/src/core/ports/vector_store_port.py new file mode 100644 index 0000000000000000000000000000000000000000..7ba68dbabde9d31a5f847cbf44779c648735aaa3 --- /dev/null +++ b/src/core/ports/vector_store_port.py @@ -0,0 +1,49 @@ +import abc +from typing import Dict, Any, List, Optional +from pydantic import BaseModel + +class SearchResult(BaseModel): + content: str + metadata: Dict[str, Any] + score: float + doc_id: Optional[str] = None + +class VectorStorePort(abc.ABC): + """ + Primary interface for connecting to the Vector Database (e.g., Qdrant). + """ + + @abc.abstractmethod + def search( + self, + query_vectors: Dict[str, Any], + limit: int = 5, + source_filter: Optional[str] = None, + language_filter: Optional[str] = None, + days_back: Optional[int] = None + ) -> List[SearchResult]: + """ + Performs a hybrid search (Dense + Sparse) against the vector database. + """ + pass + + @abc.abstractmethod + def get_by_doc_id(self, doc_id: str) -> Optional[SearchResult]: + """ + Retrieves a single contiguous document by its unique ID. + """ + pass + + @abc.abstractmethod + def get_collection_stats(self) -> Dict[str, Any]: + """ + Retrieves statistics about the underlying vector collection. + """ + pass + + @abc.abstractmethod + def browse(self, limit: int = 20, offset: int = 0, source: Optional[str] = None, language: Optional[str] = None) -> Dict[str, Any]: + """ + Browses the vector collection with pagination and filtering. + """ + pass diff --git a/src/core/ranking/__init__.py b/src/core/ranking/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3eed518ba95217cdc4ba86f1e83432bf65ec80e0 --- /dev/null +++ b/src/core/ranking/__init__.py @@ -0,0 +1 @@ +# Result Ranking Module diff --git a/src/core/ranking/hybrid_result_ranker.py b/src/core/ranking/hybrid_result_ranker.py new file mode 100644 index 0000000000000000000000000000000000000000..e052f8b692ac72c252af04dc9458e84bfd54dd18 --- /dev/null +++ b/src/core/ranking/hybrid_result_ranker.py @@ -0,0 +1,349 @@ +""" +Hybrid Result Ranker + +Merges and ranks results from multiple sources (live search + database). +Integrates with existing BGE-Reranker-v2-m3 for final cross-encoder ranking. + +Key Features: +- Normalizes results from different sources to common format +- Applies source-specific score adjustments +- Deduplicates by URL and title similarity +- Composite scoring (relevance + freshness + source quality) +- Final reranking with existing BGE reranker +""" + +import logging +from typing import List, Dict, Any +from datetime import datetime +import hashlib + +logger = logging.getLogger(__name__) + + +class HybridResultRanker: + """ + Merges and ranks results from live search and database. + + Process: + 1. Normalize both result types to common format + 2. Apply source-specific score adjustments + 3. Deduplicate by URL/title similarity + 4. Composite scoring (relevance + freshness + quality) + 5. Use existing BGE reranker for final ranking + """ + + # Source quality tiers (based on journalistic standards) + SOURCE_QUALITY = { + # Tier 1: International news agencies (1.0) + "reuters.com": 1.0, + "apnews.com": 1.0, + "afp.com": 1.0, + "bbc.com": 1.0, + "bbc.co.uk": 1.0, + + # Tier 2: Major international outlets (0.9) + "aljazeera.com": 0.9, + "cnn.com": 0.9, + "theguardian.com": 0.9, + "nytimes.com": 0.9, + "washingtonpost.com": 0.9, + + # Tier 3: Regional quality outlets (0.8) + "addisstandard.com": 0.8, + "ethiopianmonitor.com": 0.8, + "borkena.com": 0.8, + "ezega.com": 0.8, + + # Tier 4: Other known sources (0.7) + "fanabc.com": 0.7, + "ebc.et": 0.7, + "waltainfo.com": 0.7, + } + + def __init__(self, reranker): + """ + Initialize hybrid result ranker. + + Args: + reranker: Existing BGE-Reranker-v2-m3 adapter + """ + self.reranker = reranker + + def merge_and_rank( + self, + db_results: List[Any], # SearchResult objects from Qdrant + live_results: List[Dict[str, Any]], # Dicts from DuckDuckGo + strategy, # SearchStrategy object + query: str, + final_top_n: int = 10 + ) -> List[Dict[str, Any]]: + """ + Merge results from database and live sources, then rerank. + + Args: + db_results: SearchResult objects from Qdrant + live_results: Normalized dicts from DuckDuckGo + strategy: SearchStrategy with weights + query: Original user query + final_top_n: Number of final results to return + + Returns: + List of top-N ranked results + """ + # 1. Normalize both result types to common format + all_results = [] + + # Normalize database results + for r in db_results: + all_results.append(self._normalize_db_result(r, strategy.db_weight)) + + # Normalize live results + for r in live_results: + all_results.append(self._normalize_live_result(r, strategy.live_weight)) + + if not all_results: + logger.warning("No results to rank") + return [] + + logger.info( + f"Merging {len(db_results)} DB + {len(live_results)} live = " + f"{len(all_results)} total results" + ) + + # 2. Deduplicate by URL and title similarity + unique_results = self._deduplicate(all_results) + logger.info(f"After deduplication: {len(unique_results)} unique results") + + # 3. Apply composite scoring + for r in unique_results: + r["composite_score"] = self._calculate_composite_score(r) + + # 4. Pre-sort by composite score + unique_results.sort(key=lambda x: x["composite_score"], reverse=True) + + # 5. Use existing BGE reranker for final ranking + # Take top 20 candidates for reranking (balance quality vs speed) + top_candidates = unique_results[:20] + + if not top_candidates: + return [] + + # Reranker expects list of dicts with "content" key + try: + reranked = self.reranker.rerank( + query=query, + docs=top_candidates, + top_n=final_top_n + ) + logger.info(f"Reranking: {len(top_candidates)} candidates → {len(reranked)} final") + return reranked + except Exception as e: + logger.error(f"Reranking failed: {e}. Returning pre-sorted results.") + return top_candidates[:final_top_n] + + def _normalize_db_result(self, result: Any, weight: float) -> Dict[str, Any]: + """ + Normalize database SearchResult to common format. + + Args: + result: SearchResult object from Qdrant + weight: Weight to apply to score + + Returns: + Normalized dict + """ + return { + "content": result.content, + "metadata": result.metadata, + "score": result.score * weight, # Apply strategy weight + "doc_id": result.doc_id, + "source_type": "database", + "title": result.metadata.get("title", ""), + "url": result.metadata.get("url", ""), + "source": result.metadata.get("source", ""), + "published_at": result.metadata.get("published_at", ""), + "language": result.metadata.get("language", "en"), + "freshness_score": self._calculate_freshness( + result.metadata.get("published_at") + ), + "source_quality": self._get_source_quality( + result.metadata.get("source", "") + ) + } + + def _normalize_live_result(self, result: Dict[str, Any], weight: float) -> Dict[str, Any]: + """ + Normalize live search result to common format. + + Args: + result: Dict from DuckDuckGo adapter + weight: Weight to apply to score + + Returns: + Normalized dict + """ + # Live results already have most fields from adapter + result["score"] = result.get("freshness_score", 0.8) * weight + result["source_quality"] = self._get_source_quality(result.get("source", "")) + return result + + def _deduplicate(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Deduplicate by URL and title similarity. + Keeps highest-scoring version of each article. + + Args: + results: List of normalized results + + Returns: + List of unique results + """ + seen_urls = {} + seen_titles = {} + unique = [] + + for r in results: + url = r.get("url", "").lower().strip() + title = r.get("title", "").lower().strip() + + # Generate URL hash for exact matching + url_hash = hashlib.md5(url.encode()).hexdigest() if url else None + + # Generate title hash for similarity matching (first 50 chars) + title_key = title[:50] if title else "" + title_hash = hashlib.md5(title_key.encode()).hexdigest() if title_key else None + + # Check URL exact match + if url_hash and url_hash in seen_urls: + # Keep higher-scoring version + if r["score"] > seen_urls[url_hash]["score"]: + unique.remove(seen_urls[url_hash]) + unique.append(r) + seen_urls[url_hash] = r + continue + + # Check title similarity + if title_hash and title_hash in seen_titles: + # Keep higher-scoring version + if r["score"] > seen_titles[title_hash]["score"]: + unique.remove(seen_titles[title_hash]) + unique.append(r) + seen_titles[title_hash] = r + continue + + # New unique result + unique.append(r) + if url_hash: + seen_urls[url_hash] = r + if title_hash: + seen_titles[title_hash] = r + + return unique + + def _calculate_composite_score(self, result: Dict[str, Any]) -> float: + """ + Calculate composite score from multiple factors. + + Factors: + - Relevance (vector/base score): 40% + - Freshness (recency): 30% + - Source quality (reputation): 20% + - Source type boost (live vs DB): 10% + + Args: + result: Normalized result dict + + Returns: + Composite score (0.0 to 1.0+) + """ + relevance = result.get("score", 0.5) + freshness = result.get("freshness_score", 0.5) + source_quality = result.get("source_quality", 0.6) + + # Boost live results slightly (they're fresher by definition) + source_type_boost = 1.0 if result.get("source_type") == "live" else 0.5 + + composite = ( + relevance * 0.4 + + freshness * 0.3 + + source_quality * 0.2 + + source_type_boost * 0.1 + ) + + return composite + + def _calculate_freshness(self, published_at: Any) -> float: + """ + Calculate freshness score based on article age. + More recent = higher score. + + Args: + published_at: ISO date string or datetime object + + Returns: + Freshness score (0.0 to 1.0) + """ + if not published_at: + return 0.5 # Unknown age + + try: + # Parse date + if isinstance(published_at, str): + pub_date = datetime.fromisoformat(published_at.replace('Z', '+00:00')) + else: + pub_date = published_at + + # Calculate age + age = datetime.utcnow() - pub_date.replace(tzinfo=None) + age_minutes = age.total_seconds() / 60 + + # Freshness scoring + if age_minutes < 10: + return 1.0 # Very fresh (< 10 min) + elif age_minutes < 60: + return 0.95 # Fresh (< 1 hour) + elif age_minutes < 360: + return 0.85 # Recent (< 6 hours) + elif age_minutes < 1440: + return 0.7 # Today (< 24 hours) + elif age_minutes < 10080: + return 0.5 # This week (< 7 days) + else: + return 0.3 # Old (> 7 days) + + except Exception as e: + logger.debug(f"Failed to parse date '{published_at}': {e}") + return 0.5 + + def _get_source_quality(self, source: str) -> float: + """ + Get source quality score based on reputation. + + Args: + source: Source name or domain + + Returns: + Quality score (0.0 to 1.0) + """ + if not source: + return 0.6 # Unknown source + + source_lower = source.lower() + + # Check exact domain matches + for domain, quality in self.SOURCE_QUALITY.items(): + if domain in source_lower: + return quality + + # Check partial matches for known sources + if any(s in source_lower for s in ["reuters", "associated press", "ap news", "afp", "bbc"]): + return 1.0 + elif any(s in source_lower for s in ["aljazeera", "cnn", "guardian", "nytimes"]): + return 0.9 + elif any(s in source_lower for s in ["addis", "ethiopian", "borkena", "ezega"]): + return 0.8 + elif any(s in source_lower for s in ["fana", "ebc", "walta"]): + return 0.7 + + # Unknown source - neutral quality + return 0.6 diff --git a/src/core/security.py b/src/core/security.py new file mode 100644 index 0000000000000000000000000000000000000000..4f4dd368886f579ec68d060d7c2d197999480c77 --- /dev/null +++ b/src/core/security.py @@ -0,0 +1,103 @@ +import secrets +from datetime import datetime, timedelta +from typing import Optional +from jose import jwt, JWTError +import bcrypt +from fastapi import Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer +from sqlalchemy.orm import Session + +from src.core.config import settings +from src.infrastructure.database import get_db +from src.core.domain.db_models import User, RefreshToken, UserRole + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/auth/login") + +ALGORITHM = "HS256" +REFRESH_TOKEN_EXPIRE_DAYS = 30 + + +# ── Password helpers ────────────────────────────────────────────────────────── + +def verify_password(plain: str, hashed: str) -> bool: + return bcrypt.checkpw(plain.encode(), hashed.encode()) + +def get_password_hash(password: str) -> str: + return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode() + + +# ── JWT access token ────────────────────────────────────────────────────────── + +def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str: + to_encode = data.copy() + expire = datetime.utcnow() + (expires_delta or timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)) + to_encode.update({"exp": expire, "type": "access"}) + return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM) + + +# ── Refresh token ───────────────────────────────────────────────────────────── + +def create_refresh_token(user_id: int, db: Session) -> str: + token = secrets.token_urlsafe(64) + expires_at = datetime.utcnow() + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS) + db_token = RefreshToken(token=token, user_id=user_id, expires_at=expires_at) + db.add(db_token) + db.commit() + return token + +def rotate_refresh_token(old_token: str, db: Session) -> tuple[str, "User"]: + """Validate old refresh token, revoke it, issue a new one.""" + record = db.query(RefreshToken).filter( + RefreshToken.token == old_token, + RefreshToken.revoked == False + ).first() + if not record or record.expires_at < datetime.utcnow(): + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired refresh token") + record.revoked = True + db.commit() + new_token = create_refresh_token(record.user_id, db) + return new_token, record.user + +def revoke_all_refresh_tokens(user_id: int, db: Session): + db.query(RefreshToken).filter( + RefreshToken.user_id == user_id, + RefreshToken.revoked == False + ).update({"revoked": True}) + db.commit() + + +# ── Current user dependency ─────────────────────────────────────────────────── + +def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)) -> User: + exc = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + try: + payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM]) + if payload.get("type") != "access": + raise exc + email: str = payload.get("sub") + if not email: + raise exc + except JWTError: + raise exc + + user = db.query(User).filter(User.email == email).first() + if not user or not user.is_active: + raise exc + return user + +def get_current_active_user(current_user: User = Depends(get_current_user)) -> User: + if not current_user.is_active: + raise HTTPException(status_code=400, detail="Inactive user") + return current_user + + +# ── Role guards ─────────────────────────────────────────────────────────────── + +def require_super_admin(current_user: User = Depends(get_current_user)) -> User: + if current_user.role != UserRole.super_admin: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Super admin access required") + return current_user diff --git a/src/core/use_cases/__pycache__/account_use_case.cpython-313.pyc b/src/core/use_cases/__pycache__/account_use_case.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7dfc8e88a7bf90d776ced49e49dee047c9cf550 Binary files /dev/null and b/src/core/use_cases/__pycache__/account_use_case.cpython-313.pyc differ diff --git a/src/core/use_cases/__pycache__/agent_router_use_case.cpython-313.pyc b/src/core/use_cases/__pycache__/agent_router_use_case.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..978e5eba1fdbc8bfe99ac668425e2664eb68be21 Binary files /dev/null and b/src/core/use_cases/__pycache__/agent_router_use_case.cpython-313.pyc differ diff --git a/src/core/use_cases/__pycache__/analytics_use_case.cpython-313.pyc b/src/core/use_cases/__pycache__/analytics_use_case.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0348aba6ec3222dbab18702eed95f33f1f6fad5c Binary files /dev/null and b/src/core/use_cases/__pycache__/analytics_use_case.cpython-313.pyc differ diff --git a/src/core/use_cases/__pycache__/rag_chat_use_case.cpython-313.pyc b/src/core/use_cases/__pycache__/rag_chat_use_case.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eaebbf1984e320f576d934e97a6d040fece4b9a2 Binary files /dev/null and b/src/core/use_cases/__pycache__/rag_chat_use_case.cpython-313.pyc differ diff --git a/src/core/use_cases/__pycache__/search_use_case.cpython-313.pyc b/src/core/use_cases/__pycache__/search_use_case.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16e9af8fc7617e66cde2b1247bdb5e1d5116081c Binary files /dev/null and b/src/core/use_cases/__pycache__/search_use_case.cpython-313.pyc differ diff --git a/src/core/use_cases/account_use_case.py b/src/core/use_cases/account_use_case.py new file mode 100644 index 0000000000000000000000000000000000000000..8a9cf48f4d4ea6008f849b1c07df5daab2d9d393 --- /dev/null +++ b/src/core/use_cases/account_use_case.py @@ -0,0 +1,19 @@ +from typing import Dict, Any, AsyncGenerator +from src.core.domain.schemas import ChatRequest + +class AccountUseCase: + """ + Dummy Use Case for handling account-specific logic in the Multi-Agent router. + """ + def execute_chat(self, request: ChatRequest) -> Dict[str, Any]: + return { + "answer": "This is a placeholder for the Account History service. The Multi-Agent Router successfully classified your intent and directed you here instead of the News RAG pipeline.", + "sources": [], + "session_id": request.session_id, + "debug": {"routed_to": "AccountUseCase"} + } + + async def execute_stream(self, request: ChatRequest) -> AsyncGenerator[str, None]: + import json + yield f"data: {json.dumps({'token': 'This is a placeholder for the Account History service. The Multi-Agent Router successfully classified your intent and directed you here instead of the News RAG pipeline.'})}\n\n" + yield "data: [DONE]\n\n" diff --git a/src/core/use_cases/agent_router_use_case.py b/src/core/use_cases/agent_router_use_case.py new file mode 100644 index 0000000000000000000000000000000000000000..bf619d6944e2c8f168c462cc7c45d01159619777 --- /dev/null +++ b/src/core/use_cases/agent_router_use_case.py @@ -0,0 +1,74 @@ +from typing import Dict, Any, AsyncGenerator +from src.core.ports.llm_port import LlmPort +from src.core.ports.chat_history_port import ChatHistoryPort +from src.core.domain.schemas import ChatRequest +from src.core.use_cases.rag_chat_use_case import RagChatUseCase +from src.core.use_cases.account_use_case import AccountUseCase +from src.infrastructure.adapters.intent_classifier import intent_classifier + + +class AgentRouterUseCase: + """ + Multi-Agent Router — classifies user intent with a local DeBERTa NLI model + then delegates to the RAG pipeline (NEWS) or direct LLM (OTHER). + """ + def __init__(self, llm: LlmPort, rag_chat: RagChatUseCase, account: AccountUseCase, chat_history_db: ChatHistoryPort): + self.llm = llm + self.rag_chat = rag_chat + self.account = account + self.chat_history_db = chat_history_db + + def _classify_intent(self, query: str) -> str: + intent = intent_classifier.classify(query) + print(f"DEBUG: Intent Classification: {intent} for query: '{query[:80]}'") + return intent + + def execute_chat(self, request: ChatRequest) -> Dict[str, Any]: + intent = self._classify_intent(request.query) + session_id = request.session_id or "anonymous" + + if intent == "OTHER": + print(f"DEBUG: Routing to OTHER (Direct LLM Response)") + answer = self.llm.generate(request.query) + return { + "answer": answer, + "sources": [], + "session_id": session_id, + "debug": {"intent": "OTHER", "routed_to": "DirectLLM"} + } + else: + print(f"DEBUG: Routing to NEWS (RAG Pipeline)") + return self.rag_chat.execute_chat(request) + + async def execute_stream(self, request: ChatRequest, is_guest: bool = False, user_id: int = None) -> AsyncGenerator[str, None]: + intent = self._classify_intent(request.query) + session_id = request.session_id or "anonymous" + + if intent == "OTHER": + full_answer = "" + async for chunk in self.llm.generate_stream(request.query): + yield chunk + if chunk.startswith("data: "): + try: + import json + data = json.loads(chunk[6:]) + if "token" in data: + full_answer += data["token"] + except: + pass + + import json + final_response = { + "answer": full_answer, + "sources": [], + "session_id": session_id, + "debug": {"intent": "OTHER", "routed_to": "DirectLLM"} + } + yield f"data: {json.dumps(final_response)}\n\n" + yield "data: [DONE]\n\n" + + if not is_guest: + self.chat_history_db.save_interaction(session_id, request.query, full_answer, [], user_id=user_id) + else: + async for chunk in self.rag_chat.execute_stream(request, is_guest=is_guest, user_id=user_id): + yield chunk diff --git a/src/core/use_cases/analytics_use_case.py b/src/core/use_cases/analytics_use_case.py new file mode 100644 index 0000000000000000000000000000000000000000..43d7f4c91cf60e54abc36ef07a2006643451167c --- /dev/null +++ b/src/core/use_cases/analytics_use_case.py @@ -0,0 +1,15 @@ +from typing import Dict, Any, Optional +from src.core.ports.analytics_db_port import AnalyticsDbPort + +class AnalyticsUseCase: + """ + Use Case: Fetch statistical aggregations and entity trends from the Data Warehouse. + """ + def __init__(self, analytics_db: AnalyticsDbPort): + self.analytics_db = analytics_db + + def get_trends(self, days: int = 3, limit: int = 10) -> Dict[str, Any]: + return self.analytics_db.get_trends(days=days, limit=limit) + + def execute_raw_query(self, sql: str, parameters: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + return self.analytics_db.query(sql, parameters) diff --git a/src/core/use_cases/rag_chat_use_case.py b/src/core/use_cases/rag_chat_use_case.py new file mode 100644 index 0000000000000000000000000000000000000000..fee8d33537518c8529b1b4809a95e5097373fe0a --- /dev/null +++ b/src/core/use_cases/rag_chat_use_case.py @@ -0,0 +1,791 @@ +import tiktoken +import asyncio +import concurrent.futures +from datetime import datetime +from typing import List, Dict, Any, Tuple, AsyncGenerator, Optional + +from src.core.ports.embedder_port import EmbedderPort +from src.core.ports.vector_store_port import VectorStorePort, SearchResult +from src.core.ports.reranker_port import RerankerPort +from src.core.ports.llm_port import LlmPort +from src.core.ports.chat_history_port import ChatHistoryPort +from src.core.ports.analytics_db_port import AnalyticsDbPort +from src.core.ports.cache_port import CachePort +from src.core.domain.schemas import ChatRequest # We assume this still exists + +# Hybrid Search Components +from src.core.orchestrator.query_orchestrator import QueryOrchestrator +from src.core.ranking.hybrid_result_ranker import HybridResultRanker + +# Supported languages for multilingual parallel search +# Maps language code → human name (used in prompts and source labels) +SUPPORTED_LANGUAGES = { + "en": "English", + "ar": "Arabic", + "am": "Amharic", + "so": "Somali", + "sw": "Swahili", + "fr": "French", +} + +class RagChatUseCase: + """ + Core Business Logic for the full Retrieval-Augmented Generation pipeline. + """ + def __init__( + self, + embedder: EmbedderPort, + vector_store: VectorStorePort, + reranker: RerankerPort, + llm: LlmPort, + chat_history_db: ChatHistoryPort, + analytics_db: Optional[AnalyticsDbPort] = None, + cache: Optional[CachePort] = None, + orchestrator: Optional[QueryOrchestrator] = None, + hybrid_ranker: Optional[HybridResultRanker] = None + ): + self.embedder = embedder + self.vector_store = vector_store + self.reranker = reranker + self.llm = llm + self.chat_history_db = chat_history_db + self.analytics_db = analytics_db + self.cache = cache + self.orchestrator = orchestrator + self.hybrid_ranker = hybrid_ranker + + # Import intent classifier v2 for production-grade classification + try: + from src.infrastructure.adapters.intent_classifier_v2 import intent_classifier_v2 + self.intent_classifier_v2 = intent_classifier_v2 + self.use_v2_classifier = True + logger.info("Using production-grade intent classifier v2") + except ImportError: + # Fallback to v1 + from src.infrastructure.adapters.intent_classifier import intent_classifier + self.intent_classifier = intent_classifier + self.intent_classifier_v2 = None + self.use_v2_classifier = False + logger.warning("Falling back to intent classifier v1") + + def _compress_document(self, query: str, document_text: str) -> str: + prompt = f"""Extract pure facts and vital information from the following document that are relevant to answering the query: "{query}". +Be very concise and omit irrelevant details. Return ONLY the relevant facts. + +Document: +{document_text}""" + try: + return self.llm.generate(prompt) + except Exception: + return document_text[:500] + "... (truncated)" + + def _limit_context(self, query: str, docs: List[Dict[str, Any]], max_tokens: int = 3000) -> Tuple[str, List[Dict[str, Any]]]: + enc = tiktoken.get_encoding("cl100k_base") + total_tokens = 0 + context_text = "" + filtered_sources = [] + + for doc in docs: + content = doc.get("content", "") + metadata = doc.get("metadata", {}) + + # Extract source name from multiple possible fields + source_name = ( + metadata.get("source") or + metadata.get("title") or + doc.get("source") or + "Unknown" + ) + + # Append language tag to source label for non-English articles + search_lang = metadata.get("_search_lang", "en") + if search_lang and search_lang != "en": + lang_label = SUPPORTED_LANGUAGES.get(search_lang, search_lang.upper()) + source_label = f"{source_name} ({lang_label})" + else: + source_label = source_name + + # Format: prepend source so LLM can cite it + formatted = f"[Source: {source_label}]\n{content}" + + tokens = len(enc.encode(formatted)) + + if total_tokens + tokens > max_tokens: + formatted = self._compress_document(query, content) + formatted = f"[Source: {source_label}]\n{formatted}" + tokens = len(enc.encode(formatted)) + + if total_tokens + tokens > max_tokens: + remaining = max_tokens - total_tokens + if remaining > 50: + truncated = formatted[:remaining * 4] + context_text += truncated + "\n\n" + doc["content"] = content[:remaining * 4] + "... (truncated)" + filtered_sources.append(doc) + break + + context_text += formatted + "\n\n" + total_tokens += tokens + filtered_sources.append(doc) + + return context_text, filtered_sources + + def _extract_intents_and_translate(self, query: str) -> Dict[str, Any]: + """ + Single LLM call that does BOTH intent extraction AND multilingual translation. + Replaces the previous two-call approach (_extract_intents + _translate_query_to_all_languages). + + Returns: + { + "expanded_query": str, # cleaned English query + "days_back": int | None, # temporal filter + "source": str | None, # source filter + "translations": { # per-language queries for Qdrant + "en": str, "ar": str, "am": str, "so": str, "sw": str, "fr": str + } + } + Falls back gracefully on any LLM/parse failure. + """ + import re, json + + prompt = f"""You are a search query processor for a multilingual news system. Analyze the query and output ONLY valid JSON. + +Required fields: +- "days_back": integer or null + * 1=today/now/tonight, 2=yesterday, 7=this week/recently, 30=this month, 365=this year, null=no time +- "source": string or null — news outlet name if mentioned (e.g. "BBC", "Reuters"), else null +- "topic": string — the SPECIFIC search topic with typos fixed, time/source words removed. + IMPORTANT: Be precise. "Ethiopia peace talks" must stay "Ethiopia peace talks", NOT expand to "Sudan peace talks" or "Horn of Africa conflict". + Keep named entities exact. Do NOT broaden or generalize the topic. +- "translations": object with these 5 keys (translate "topic" into each language): + * "ar": Arabic + * "am": Amharic (Ethiopian script) + * "so": Somali + * "sw": Swahili + * "fr": French + Rules for translations: keep concise and specific, preserve named entities in local spelling, use English term if uncertain. + +Query: "{query}" + +JSON:""" + + result: Dict[str, Any] = { + "expanded_query": query, + "days_back": None, + "source": None, + "translations": {lang: query for lang in ["ar", "am", "so", "sw", "fr"]}, + } + + try: + response = self.llm.generate(prompt).strip() + match = re.search(r'\{.*\}', response, re.DOTALL) + if match: + parsed = json.loads(match.group(0)) + + if isinstance(parsed.get("days_back"), int): + result["days_back"] = parsed["days_back"] + print(f"DEBUG: days_back={parsed['days_back']}") + + if isinstance(parsed.get("source"), str) and parsed["source"]: + result["source"] = parsed["source"] + print(f"DEBUG: source={parsed['source']}") + + if isinstance(parsed.get("topic"), str) and parsed["topic"].strip(): + result["expanded_query"] = parsed["topic"].strip() + print(f"DEBUG: expanded_query='{result['expanded_query']}'") + + translations = parsed.get("translations", {}) + if isinstance(translations, dict): + for lang in ["ar", "am", "so", "sw", "fr"]: + val = translations.get(lang, "").strip() + result["translations"][lang] = val if val else result["expanded_query"] + print(f"DEBUG: translation [{lang}]: {result['translations'][lang]}") + + return result + + except Exception as e: + print(f"DEBUG: Combined LLM call failed: {e} — using NLI fallback for temporal, English for translations") + + # ── NLI fallback for temporal only (translations stay as English) ──── + try: + from src.infrastructure.adapters.intent_classifier import intent_classifier + pipe = intent_classifier._pipe + if pipe is None: + intent_classifier._load() + pipe = intent_classifier._pipe + + if pipe is not None: + temporal_result = pipe( + query, + candidate_labels=[ + "asking about news from today or right now", + "asking about news from yesterday", + "asking about news from this week or past few days", + "asking about news from this month", + "asking about news from this year", + "no specific time period mentioned", + ], + hypothesis_template="The user is {}.", + multi_label=False, + ) + top_label = temporal_result["labels"][0] + top_score = temporal_result["scores"][0] + if top_score > 0.4: + if "today" in top_label or "right now" in top_label: + result["days_back"] = 1 + elif "yesterday" in top_label: + result["days_back"] = 2 + elif "week" in top_label or "past few days" in top_label: + result["days_back"] = 7 + elif "month" in top_label: + result["days_back"] = 30 + elif "year" in top_label: + result["days_back"] = 365 + print(f"DEBUG: NLI temporal fallback → days_back={result['days_back']}") + except Exception as e: + print(f"DEBUG: NLI fallback also failed: {e}") + + # ── Source regex fallback ───────────────────────────────────────────── + source_match = re.search( + r'\b(?:from|on|by|via|source[:\s]+)\s*([A-Z][A-Za-z]+(?:\s[A-Z][A-Za-z]+)?)\b', + query + ) + if source_match: + result["source"] = source_match.group(1) + + return result + + def _search_single_language( + self, + lang_code: str, + dense_vec: List[float], # shared English dense vector — computed once + sparse_vec: Optional[Dict], # language-specific sparse vector + per_lang_limit: int, + source_filter: Optional[str], + days_back: Optional[int], + ) -> List[Dict[str, Any]]: + """ + Search Qdrant for one language lane using pre-split vectors. + + dense_vec — English query dense vector (language-agnostic, BGE-M3 cross-lingual space) + sparse_vec — translated query sparse/BM25 vector (language-specific keyword matching) + + By keeping dense fixed and only varying sparse per language, we avoid + re-running the expensive dense encoder 6 times. + """ + try: + results = self.vector_store.search_with_vectors( + dense_vec=dense_vec, + sparse_vec=sparse_vec, + limit=per_lang_limit, + source_filter=source_filter, + language_filter=lang_code, + days_back=days_back, + ) + + # Temporal bias scoring + now = datetime.utcnow() + docs = [] + for hit in results: + score_multiplier = 1.0 + pub_at = hit.metadata.get("published_at") + if pub_at: + try: + pub_date = datetime.fromisoformat(pub_at.replace("Z", "+00:00")) + days_old = (now.replace(tzinfo=None) - pub_date.replace(tzinfo=None)).days + score_multiplier = max(0.5, 1.0 - (days_old / 60)) + except: + pass + + meta = dict(hit.metadata) + meta["_search_lang"] = lang_code + docs.append({ + "content": hit.content, + "metadata": meta, + "score": hit.score * score_multiplier, + "doc_id": hit.doc_id, + }) + + print(f"DEBUG: [{lang_code}] search returned {len(docs)} results") + return docs + + except Exception as e: + print(f"DEBUG: [{lang_code}] search failed: {e}") + return [] + + def _build_context(self, query: str, top_k: int, source_filter=None, language_filter=None, days_back=None) -> Tuple[str, List[Dict[str, Any]]]: + # ── Step 1: Single LLM call — intent extraction + multilingual translation ── + expanded_query = query + + # Sanitize placeholder filter values from API + actual_source_filter = None if source_filter == "string" else source_filter + actual_language_filter = None if (not language_filter or language_filter == "string") else language_filter + + # ── HYBRID SEARCH INTEGRATION ────────────────────────────────────────── + # Check if hybrid search is enabled and orchestrator is available + use_hybrid = self.orchestrator is not None and self.hybrid_ranker is not None + + if use_hybrid: + print(f"DEBUG: Hybrid search enabled - checking intent and strategy") + + # Classify intent using v2 (production-grade) or v1 (fallback) + if self.use_v2_classifier and self.intent_classifier_v2: + intent_result = self.intent_classifier_v2.classify(query) + intent = "NEWS" if intent_result.intent != "OTHER" else "OTHER" + + print(f"DEBUG: Intent classification v2: {intent_result.intent} " + f"(confidence={intent_result.confidence:.2f}, " + f"method={intent_result.method}, " + f"time={intent_result.inference_time_ms:.1f}ms)") + else: + intent = self.intent_classifier.classify(query) + intent_result = None + print(f"DEBUG: Intent classification v1: {intent}") + + # Decide search strategy (pass full intent_result for v2) + strategy = self.orchestrator.decide_search_strategy(query, intent, intent_result) + print(f"DEBUG: Search strategy: {strategy}") + + # If intent is OTHER (small talk), skip search entirely + if intent == "OTHER": + print(f"DEBUG: Small talk detected - skipping search") + return "", [] + else: + print(f"DEBUG: Hybrid search disabled - using traditional pipeline") + use_hybrid = False + strategy = None + + if actual_language_filter: + # Explicit language override — single-language mode, no translation needed + print(f"DEBUG: Language filter '{actual_language_filter}' — single-language mode") + lang_sparse_queries: Dict[str, str] = {actual_language_filter: expanded_query} + else: + print(f"DEBUG: Running combined intent extraction + translation...") + combined = self._extract_intents_and_translate(query) + + if combined.get("days_back") and isinstance(combined["days_back"], int): + days_back = combined["days_back"] + if combined.get("source") and isinstance(combined["source"], str): + # Normalize source name: "BBC" → stored as "bbc_english", "bbc_arabic" etc. + # We store the raw name and do prefix matching in Qdrant via the source filter. + # Convert to lowercase for case-insensitive matching. + actual_source_filter = combined["source"].lower().strip() + if combined.get("expanded_query") and isinstance(combined["expanded_query"], str): + expanded_query = combined["expanded_query"] + + # Map: lang_code → text to use for sparse (BM25) encoding + # English uses the cleaned expanded query; others use their translations + translations = combined.get("translations", {}) + lang_sparse_queries = {"en": expanded_query} + for lang in ["ar", "am", "so", "sw", "fr"]: + lang_sparse_queries[lang] = translations.get(lang, expanded_query) + + per_lang_limit = max(8, top_k * 2) + + # ── HYBRID SEARCH EXECUTION ──────────────────────────────────────────── + if use_hybrid and strategy and (strategy.use_live or strategy.use_db): + print(f"DEBUG: Executing hybrid search...") + + # Execute hybrid search (parallel live + DB) + try: + db_results, live_results = asyncio.run( + self.orchestrator.execute_hybrid_search( + query=query, + english_query=expanded_query, + strategy=strategy, + embedder=self.embedder, + vector_store=self.vector_store, + translated_queries=lang_sparse_queries, + top_k=per_lang_limit + ) + ) + + print(f"DEBUG: Hybrid search returned {len(db_results)} DB + {len(live_results)} live results") + + # Merge and rank results + all_docs = self.hybrid_ranker.merge_and_rank( + db_results=db_results, + live_results=live_results, + strategy=strategy, + query=expanded_query, + final_top_n=top_k * 3 # Get more candidates for quality filtering + ) + + print(f"DEBUG: After hybrid ranking: {len(all_docs)} results") + + except Exception as e: + print(f"DEBUG: Hybrid search failed: {e} - falling back to traditional pipeline") + use_hybrid = False + all_docs = [] + + # ── TRADITIONAL PIPELINE (fallback or when hybrid disabled) ──────────── + if not use_hybrid or not all_docs: + print(f"DEBUG: Using traditional multilingual pipeline") + + # ── Step 2: Compute dense vector ONCE from English query ────────────── + # BGE-M3 dense space is language-agnostic — one English dense vector + # semantically matches content in all 6 languages. + print(f"DEBUG: Computing dense vector for: '{expanded_query}'") + english_vectors = self.embedder.encode_query(expanded_query) + dense_vec: List[float] = english_vectors["dense"] + print(f"DEBUG: Dense vector ready ({len(dense_vec)} dims)") + + # ── Step 3: Batch sparse encoding — ONE forward pass for all languages ─ + # BGE-M3 holds the GIL during inference. ThreadPoolExecutor gives zero + # benefit for CPU-bound model calls — all 6 would run sequentially anyway. + # encode_sparse_batch() sends all 6 texts in a single model.encode() call: + # one tokenization pass + one transformer forward pass = ~5x faster. + lang_codes = list(lang_sparse_queries.keys()) + lang_texts = [lang_sparse_queries[lc] for lc in lang_codes] + + print(f"DEBUG: Batch sparse encoding {len(lang_texts)} language queries...") + sparse_results = self.embedder.encode_sparse_batch(lang_texts) + + lang_sparse_vecs: Dict[str, Optional[Dict]] = {} + for lc, result in zip(lang_codes, sparse_results): + sparse = result.get("sparse") + lang_sparse_vecs[lc] = sparse + token_count = len(sparse["indices"]) if sparse else 0 + print(f"DEBUG: [{lc}] sparse ready — {token_count} tokens") + + # ── Step 4: Fan out to Qdrant — 6 parallel searches ────────────────── + # Each lane: shared dense_vec + language-specific sparse_vec + language filter + all_docs: List[Dict[str, Any]] = [] + seen_doc_ids: set = set() + + def _run_search(lang_code: str, days_back_val) -> List[Dict[str, Any]]: + return self._search_single_language( + lang_code=lang_code, + dense_vec=dense_vec, + sparse_vec=lang_sparse_vecs.get(lang_code), + per_lang_limit=per_lang_limit, + source_filter=actual_source_filter, + days_back=days_back_val, + ) + + with concurrent.futures.ThreadPoolExecutor(max_workers=len(lang_sparse_queries)) as executor: + futures = { + executor.submit(_run_search, lc, days_back): lc + for lc in lang_sparse_queries + } + for future in concurrent.futures.as_completed(futures): + lc = futures[future] + try: + for doc in future.result(): + doc_id = doc.get("doc_id") + if doc_id and doc_id in seen_doc_ids: + continue + if doc_id: + seen_doc_ids.add(doc_id) + all_docs.append(doc) + except Exception as e: + print(f"DEBUG: [{lc}] future failed: {e}") + + print(f"DEBUG: Total pooled candidates after dedup: {len(all_docs)}") + + # ── Step 5: Temporal fallback — retry without date filter if zero results ── + self._temporal_fallback_used = False + if len(all_docs) == 0 and days_back is not None: + print(f"DEBUG: No results with days_back={days_back} — retrying without temporal filter") + with concurrent.futures.ThreadPoolExecutor(max_workers=len(lang_sparse_queries)) as executor: + futures = { + executor.submit(_run_search, lc, None): lc + for lc in lang_sparse_queries + } + for future in concurrent.futures.as_completed(futures): + lc = futures[future] + try: + for doc in future.result(): + doc_id = doc.get("doc_id") + if doc_id and doc_id in seen_doc_ids: + continue + if doc_id: + seen_doc_ids.add(doc_id) + all_docs.append(doc) + except Exception as e: + print(f"DEBUG: [{lc}] fallback future failed: {e}") + self._temporal_fallback_used = True + print(f"DEBUG: Fallback returned {len(all_docs)} total candidates") + + # ── Step 6: Multilingual reranking ──────────────────────────────────── + # bge-reranker-v2-m3 scores (English query, any-language content) natively + all_docs = self.reranker.rerank(expanded_query, all_docs, top_n=top_k * 3) + + # ── Step 7: Quality filter + relevance threshold + language diversity ─ + _BLOCKED_SOURCES = { + "naturalnews", "freerepublic", "infowars", "breitbart", + "journals", "dailymail", "thesun", "nypost", "thenation", + } + quality_docs = [ + d for d in all_docs + if (d.get("metadata", {}).get("source") or "").lower().replace(" ", "") not in _BLOCKED_SOURCES + ] or all_docs + + # ── Relevance threshold — drop docs the reranker scored too low ─────── + # Raised from 0.15 → 0.25 based on live testing. + # The airport article (bbc_swahili) was scoring ~0.18 on GERD queries + # because it mentioned "Addis Ababa" in a rankings list. + # 0.25 cuts it while keeping genuinely relevant multilingual content. + RERANK_THRESHOLD = 0.25 + above_threshold = [d for d in quality_docs if d.get("rerank_score", 1.0) >= RERANK_THRESHOLD] + if above_threshold: + quality_docs = above_threshold + print(f"DEBUG: {len(quality_docs)} docs above rerank threshold {RERANK_THRESHOLD}") + else: + # All scores low — keep top 3 anyway rather than returning nothing + quality_docs = quality_docs[:3] + print(f"DEBUG: All docs below threshold — keeping top 3 by rerank score") + + # ── Keyword overlap filter — drop docs with zero query term overlap ─── + # Catches articles that score above threshold due to incidental mentions + # (e.g. airport rankings mentioning "Addis Ababa" in a list). + # Extract meaningful query terms (3+ chars, skip stopwords). + _STOPWORDS = {"the", "and", "for", "are", "was", "what", "tell", "about", + "latest", "news", "from", "this", "that", "with", "have", + "did", "say", "said", "week", "today", "report", "reporting"} + query_terms = { + w.lower() for w in expanded_query.split() + if len(w) >= 3 and w.lower() not in _STOPWORDS + } + if query_terms: + def _has_overlap(doc: Dict[str, Any]) -> bool: + content_lower = doc.get("content", "").lower() + return any(term in content_lower for term in query_terms) + + overlapping = [d for d in quality_docs if _has_overlap(d)] + if overlapping: + quality_docs = overlapping + print(f"DEBUG: {len(quality_docs)} docs after keyword overlap filter") + else: + print(f"DEBUG: No keyword overlap — keeping all {len(quality_docs)} docs") + + # Guarantee at least 1 non-English result if available + non_english = [d for d in quality_docs if d.get("metadata", {}).get("_search_lang", "en") != "en"] + if non_english: + final_pool = [non_english[0]] + [d for d in quality_docs if d is not non_english[0]] + final_pool = final_pool[:top_k] + else: + final_pool = quality_docs[:top_k] + + # Deduplicate by doc_id + seen: set = set() + deduped_final: List[Dict[str, Any]] = [] + for d in final_pool: + did = d.get("doc_id") + if did in seen: + continue + seen.add(did) + deduped_final.append(d) + + langs_in_result = list({d.get("metadata", {}).get("_search_lang", "en") for d in deduped_final}) + print(f"DEBUG: Final {len(deduped_final)} docs — languages: {langs_in_result}") + + # ── Step 8: Token limitation ────────────────────────────────────────── + return self._limit_context(query, deduped_final) + + def _get_history_text(self, session_id: str) -> str: + past_messages = self.chat_history_db.get_history(session_id, limit=6) + return "".join([f"{msg.role}: {msg.content}\n" for msg in past_messages]) + + def _get_cache_keys(self, query: str) -> Dict[str, str]: + """ + Generate cache keys for different caching layers. + + Returns dict with keys: response, live, translation, intent + """ + if not self.cache: + return {} + + query_hash = self.cache.generate_exact_hash(query) + query_prefix_hash = self.cache.generate_exact_hash(query[:50]) + + from src.core.config import settings + + return { + "response": f"hybrid_response:{query_hash}", + "live": f"live_search:{query_hash}", + "translation": f"translation:{query_hash}", + "intent": f"intent:{query_prefix_hash}" + } + + def execute_chat(self, request: ChatRequest) -> Dict[str, Any]: + print(f"DEBUG: execute_chat called with query: {request.query}") + session_id = request.session_id or "default_session" + + # ── Layer 1: Full Response Cache (5 min TTL) ────────────────────────── + cache_keys = self._get_cache_keys(request.query) + + if self.cache and cache_keys.get("response"): + cached_result = self.cache.get(cache_keys["response"]) + if cached_result: + print("DEBUG: Cache HIT - Layer 1 (full response)") + self.chat_history_db.save_interaction( + session_id, + request.query, + cached_result["answer"], + [s.get("doc_id") for s in cached_result.get("sources", [])] + ) + cached_result["debug"] = cached_result.get("debug", {}) + cached_result["debug"]["cache_hit"] = True + cached_result["debug"]["cache_layer"] = "response" + return cached_result + + print("DEBUG: Cache MISS - Layer 1 (full response) - proceeding with RAG pipeline") + history_text = self._get_history_text(session_id) + + context_text, final_sources = self._build_context( + request.query, + request.top_k, + request.source_filter, + request.language_filter, + getattr(request, 'days_back', None) + ) + + # Optional Trend Fusion - Disabled for testing + # if self.analytics_db: + # try: + # trends = self.analytics_db.get_trends(days=3) + # if trends and trends["rows"]: + # trend_text = "Current Trending News Entities and Sentiment:\n" + # for row in trends["rows"][:5]: + # trend_text += f"- {row[0]}: Sentiment {row[1]:.2f}\n" + # context_text = f"{trend_text}\n\nRetrieved Search Context:\n{context_text}" + # except: pass + + prompt = f"""You are NEXUS, a real-time news assistant. Today's date is {datetime.utcnow().strftime("%B %d, %Y")}. + +STRICT RULES — READ CAREFULLY BEFORE ANSWERING: + +STEP 1 — CLASSIFY THE QUESTION: +- Is this a math, calculation, or general knowledge question (not about news)? → Say ONLY: "I couldn't find relevant news on that topic in today's feed." STOP. Do not calculate. Do not answer. +- Is this asking about a specific news source (e.g. "What did NYT report")? → Check if that source appears in the [Source:] tags below. If NOT found, say: "I couldn't find any [source name] articles in today's feed." STOP. + +STEP 2 — RELEVANCE CHECK: +- Read the News Context below carefully. +- Does the context DIRECTLY answer the user's specific question? (e.g. user asks "Ethiopia peace talks" — does context contain Ethiopia peace talks, not Sudan peace talks?) +- If NO direct match: say ONLY "I couldn't find relevant news on that topic in today's feed." STOP. +- If YES: proceed to Step 3. + +STEP 3 — ANSWER RULES: +1. Use ONLY facts from the News Context below. NEVER use training data or general knowledge. +2. CITATIONS: You may ONLY cite sources whose exact [Source: name] tag appears in the News Context below. NEVER invent or guess a source name. If a fact has no matching [Source:] tag, do not include that fact. +3. After EVERY fact, add inline citation: "— Source: name" using the exact name from the [Source:] tag. +4. Synthesize into numbered points with **bold** headlines. +5. Non-English articles — translate content to English, note language: "— Source: Al Jazeera (Arabic)". +6. Always respond in English. No hedging. No "based on my knowledge." + +News Context (from live multilingual database): +{context_text} + +Conversation History: +{history_text} + +User Question: {request.query} + +Answer:""" + + answer = self.llm.generate(prompt) + retrieved_ids = [str(doc.get("doc_id")) for doc in final_sources] + + self.chat_history_db.save_interaction(session_id, request.query, answer, retrieved_ids) + + # Check if any live results were included + has_live_results = any( + doc.get("source_type") == "live" or doc.get("is_live") + for doc in final_sources + ) + + result = { + "answer": answer, + "sources": final_sources, + "session_id": session_id, + "debug": { + "search_query": request.query, + "multilingual_search": True, + "languages_searched": list(SUPPORTED_LANGUAGES.keys()), + "temporal_bias": True, + "cache_hit": False, + "cache_layer": None, + "hybrid_search_used": self.orchestrator is not None, + "live_results_included": has_live_results, + "retrieved_docs_count": len(final_sources), + "context_provided": context_text[:1000] + "..." if len(context_text) > 1000 else context_text, + "full_context_available": True + } + } + + # ── Cache the full response (Layer 1) ───────────────────────────────── + if self.cache and cache_keys.get("response"): + from src.core.config import settings + self.cache.set( + cache_keys["response"], + result, + expiration=settings.CACHE_RESPONSE_TTL + ) + print(f"DEBUG: Cached full response (TTL={settings.CACHE_RESPONSE_TTL}s)") + + return result + + async def execute_stream(self, request: ChatRequest, is_guest: bool = False, user_id: int = None) -> AsyncGenerator[str, None]: + session_id = request.session_id or "default_session" + history_text = "" if is_guest else self._get_history_text(session_id) + context_text, final_sources = self._build_context( + request.query, request.top_k, request.source_filter, request.language_filter, getattr(request, 'days_back', None) + ) + + prompt_stream = f"""You are NEXUS, a real-time news assistant. Today's date is {datetime.utcnow().strftime("%B %d, %Y")}. + +STRICT RULES — READ CAREFULLY BEFORE ANSWERING: + +STEP 1 — CLASSIFY THE QUESTION: +- Is this a math, calculation, or general knowledge question (not about news)? → Say ONLY: "I couldn't find relevant news on that topic in today's feed." STOP. Do not calculate. Do not answer. +- Is this asking about a specific news source (e.g. "What did NYT report")? → Check if that source appears in the [Source:] tags below. If NOT found, say: "I couldn't find any [source name] articles in today's feed." STOP. + +STEP 2 — RELEVANCE CHECK: +- Read the News Context below carefully. +- Does the context DIRECTLY answer the user's specific question? (e.g. user asks "Ethiopia peace talks" — does context contain Ethiopia peace talks, not Sudan peace talks?) +- If NO direct match: say ONLY "I couldn't find relevant news on that topic in today's feed." STOP. +- If YES: proceed to Step 3. + +STEP 3 — ANSWER RULES: +1. Use ONLY facts from the News Context below. NEVER use training data or general knowledge. +2. CITATIONS: You may ONLY cite sources whose exact [Source: name] tag appears in the News Context below. NEVER invent or guess a source name. If a fact has no matching [Source:] tag, do not include that fact. +3. After EVERY fact, add inline citation: "— Source: name" using the exact name from the [Source:] tag. +4. Synthesize into numbered points with **bold** headlines. +5. Non-English articles — translate content to English, note language: "— Source: Al Jazeera (Arabic)". +6. Always respond in English. No hedging. No "based on my knowledge." + +News Context (from live multilingual database): +{context_text} + +Conversation History: +{history_text} + +User Question: {request.query} + +Answer:""" + full_answer = "" + async for chunk in self.llm.generate_stream(prompt_stream): + yield chunk + if chunk.startswith("data: "): + try: + import json + data = json.loads(chunk[6:]) + if "token" in data: + full_answer += data["token"] + except: + pass + + import json + final_response = { + "answer": full_answer, + "sources": final_sources, + "session_id": session_id + } + yield f"data: {json.dumps(final_response)}\n\n" + yield "data: [DONE]\n\n" + + # Only persist history for authenticated users + if not is_guest: + retrieved_ids = [str(doc.get("doc_id")) for doc in final_sources] + self.chat_history_db.save_interaction(session_id, request.query, full_answer, retrieved_ids, user_id=user_id) diff --git a/src/core/use_cases/search_use_case.py b/src/core/use_cases/search_use_case.py new file mode 100644 index 0000000000000000000000000000000000000000..8fbb5f67f79125807946dbc18ad4190cbe5c253b --- /dev/null +++ b/src/core/use_cases/search_use_case.py @@ -0,0 +1,34 @@ +from typing import List, Optional +from src.core.ports.embedder_port import EmbedderPort +from src.core.ports.vector_store_port import VectorStorePort, SearchResult + +class SearchUseCase: + """ + Use Case: Perform a lightning-fast hybrid search to find the most relevant articles. + """ + def __init__(self, embedder: EmbedderPort, vector_store: VectorStorePort): + self.embedder = embedder + self.vector_store = vector_store + + def execute( + self, + query: str, + limit: int = 10, + source_filter: Optional[str] = None, + language_filter: Optional[str] = None, + days_back: Optional[int] = None + ) -> List[SearchResult]: + + # 1. Transform query text into dense/sparse vectors + vectors = self.embedder.encode_query(query) + + # 2. Retrieve results from Vector Store + results = self.vector_store.search( + query_vectors=vectors, + limit=limit, + source_filter=source_filter, + language_filter=language_filter, + days_back=days_back + ) + + return results diff --git a/src/infrastructure/__pycache__/database.cpython-313.pyc b/src/infrastructure/__pycache__/database.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aed2704a798d4dbdc6b43c8f1e0589e53447c62f Binary files /dev/null and b/src/infrastructure/__pycache__/database.cpython-313.pyc differ diff --git a/src/infrastructure/adapters/__pycache__/bge_embedder_adapter.cpython-313.pyc b/src/infrastructure/adapters/__pycache__/bge_embedder_adapter.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4f489ba4ee7d4bd1ed5e5e155906cbe00f40f78 Binary files /dev/null and b/src/infrastructure/adapters/__pycache__/bge_embedder_adapter.cpython-313.pyc differ diff --git a/src/infrastructure/adapters/__pycache__/bge_reranker_adapter.cpython-313.pyc b/src/infrastructure/adapters/__pycache__/bge_reranker_adapter.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7fd0ad06a8d15bcb78991d7ec324c1d733f5272 Binary files /dev/null and b/src/infrastructure/adapters/__pycache__/bge_reranker_adapter.cpython-313.pyc differ diff --git a/src/infrastructure/adapters/__pycache__/clickhouse_adapter.cpython-313.pyc b/src/infrastructure/adapters/__pycache__/clickhouse_adapter.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77857ba460125335efc02e600691d1debcaadcb5 Binary files /dev/null and b/src/infrastructure/adapters/__pycache__/clickhouse_adapter.cpython-313.pyc differ diff --git a/src/infrastructure/adapters/__pycache__/ollama_adapter.cpython-313.pyc b/src/infrastructure/adapters/__pycache__/ollama_adapter.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ec2ea5330136a82ea7752779a61e2acc218ed3c Binary files /dev/null and b/src/infrastructure/adapters/__pycache__/ollama_adapter.cpython-313.pyc differ diff --git a/src/infrastructure/adapters/__pycache__/openai_adapter.cpython-313.pyc b/src/infrastructure/adapters/__pycache__/openai_adapter.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91d372c6399bb391f71131af87598d9e002efa85 Binary files /dev/null and b/src/infrastructure/adapters/__pycache__/openai_adapter.cpython-313.pyc differ diff --git a/src/infrastructure/adapters/__pycache__/postgres_adapter.cpython-313.pyc b/src/infrastructure/adapters/__pycache__/postgres_adapter.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2dc950515feefbf3a75b2dd87ee272ab24a3170 Binary files /dev/null and b/src/infrastructure/adapters/__pycache__/postgres_adapter.cpython-313.pyc differ diff --git a/src/infrastructure/adapters/__pycache__/qdrant_adapter.cpython-313.pyc b/src/infrastructure/adapters/__pycache__/qdrant_adapter.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08b3a68ed7928baaf890c2e5a14f678dc1cdec8f Binary files /dev/null and b/src/infrastructure/adapters/__pycache__/qdrant_adapter.cpython-313.pyc differ diff --git a/src/infrastructure/adapters/__pycache__/redis_adapter.cpython-313.pyc b/src/infrastructure/adapters/__pycache__/redis_adapter.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..816599d6127a3a66af3489624b0825ba6badb14a Binary files /dev/null and b/src/infrastructure/adapters/__pycache__/redis_adapter.cpython-313.pyc differ diff --git a/src/infrastructure/adapters/bge_embedder_adapter.py b/src/infrastructure/adapters/bge_embedder_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..cc42f63ee053c52e7b9f8b2cf93ad7733bef4694 --- /dev/null +++ b/src/infrastructure/adapters/bge_embedder_adapter.py @@ -0,0 +1,152 @@ +import logging +import os +from typing import Dict, Any, List + +from src.core.config import settings +from src.core.ports.embedder_port import EmbedderPort + +logger = logging.getLogger(__name__) + +if os.name == 'nt': + os.environ["HF_HUB_DISABLE_SYMLINKS"] = "1" + +try: + import transformers.utils.import_utils + if not hasattr(transformers.utils.import_utils, 'is_torch_fx_available'): + transformers.utils.import_utils.is_torch_fx_available = lambda: False + + from FlagEmbedding import BGEM3FlagModel + HAS_FLAG_EMBEDDING = True +except ImportError as e: + HAS_FLAG_EMBEDDING = False + logger.warning(f"FlagEmbedding not installed: {e}. Using dummy embeddings.") + +class BgeEmbedderAdapter(EmbedderPort): + def __init__(self): + self.model = None + self.model_name = settings.EMBEDDING_MODEL + + def _load_model(self): + if self.model is None: + if not HAS_FLAG_EMBEDDING: + logger.warning("FlagEmbedding not installed. Using dummy embeddings.") + return + + logger.info(f"Loading embedding model: {self.model_name}") + try: + self.model = BGEM3FlagModel(self.model_name, use_fp16=True) + logger.info(f"Successfully loaded {self.model_name} (Hybrid Mode)") + except Exception as e: + logger.error(f"Failed to load embedding model: {e}", exc_info=True) + raise e + + def encode_query(self, text: str) -> Dict[str, Any]: + """Encodes a query string into dense and sparse vectors.""" + if self.model is None: + self._load_model() + + if not HAS_FLAG_EMBEDDING or self.model is None: + return { + "dense": [0.1] * settings.VECTOR_SIZE, + "sparse": None + } + + embeddings = self.model.encode( + sentences=[text], + batch_size=1, + max_length=512, + return_dense=True, + return_sparse=True, + return_colbert_vecs=False + ) + + dense_vec = embeddings['dense_vecs'][0].tolist() + lexical_dict = embeddings['lexical_weights'][0] + sparse_vec = { + "indices": [int(k) for k in lexical_dict.keys()], + "values": [float(v) for v in lexical_dict.values()] + } + + return { + "dense": dense_vec, + "sparse": sparse_vec + } + + def encode_sparse_only(self, text: str) -> Dict[str, Any]: + """ + Encodes only the sparse (BM25/lexical) vector for a single query. + Skips dense computation — ~2x faster than encode_query. + Used for per-language sparse queries when the dense vector is + already available from the English query. + """ + if self.model is None: + self._load_model() + + if not HAS_FLAG_EMBEDDING or self.model is None: + return {"sparse": None} + + embeddings = self.model.encode( + sentences=[text], + batch_size=1, + max_length=512, + return_dense=False, # skip dense — saves ~60% compute + return_sparse=True, + return_colbert_vecs=False + ) + + lexical_dict = embeddings['lexical_weights'][0] + sparse_vec = { + "indices": [int(k) for k in lexical_dict.keys()], + "values": [float(v) for v in lexical_dict.values()] + } + return {"sparse": sparse_vec} + + def encode_sparse_batch(self, texts: List[str]) -> List[Dict[str, Any]]: + """ + Encode multiple texts into sparse vectors in a SINGLE model forward pass. + + Why this matters: + BGE-M3 holds the Python GIL during inference — ThreadPoolExecutor gives + zero benefit for CPU-bound model calls. Calling encode_sparse_only() 6 + times in a thread pool still runs sequentially. This method batches all + 6 language queries into one model.encode() call, which is ~5x faster + than 6 sequential calls because: + - One tokenization pass for all texts + - One forward pass through the transformer + - GPU/CPU utilisation is much higher with batch_size=6 vs batch_size=1 + + Returns a list of sparse dicts in the same order as `texts`. + Falls back to empty sparse vectors on failure. + """ + if not texts: + return [] + + if self.model is None: + self._load_model() + + if not HAS_FLAG_EMBEDDING or self.model is None: + return [{"sparse": None} for _ in texts] + + try: + embeddings = self.model.encode( + sentences=texts, + batch_size=len(texts), # all in one shot + max_length=512, + return_dense=False, # skip dense — not needed here + return_sparse=True, + return_colbert_vecs=False + ) + + results = [] + for lexical_dict in embeddings['lexical_weights']: + results.append({ + "sparse": { + "indices": [int(k) for k in lexical_dict.keys()], + "values": [float(v) for v in lexical_dict.values()], + } + }) + return results + + except Exception as e: + logger.error(f"encode_sparse_batch failed: {e} — returning empty sparse vectors") + return [{"sparse": None} for _ in texts] diff --git a/src/infrastructure/adapters/bge_reranker_adapter.py b/src/infrastructure/adapters/bge_reranker_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..14661532cc02405bf25fe5069ec24ee8e5dc4dd3 --- /dev/null +++ b/src/infrastructure/adapters/bge_reranker_adapter.py @@ -0,0 +1,127 @@ +import logging +import threading +from typing import List, Dict, Any +from src.core.config import settings +from src.core.ports.reranker_port import RerankerPort + +logger = logging.getLogger(__name__) + +try: + import transformers.utils.import_utils + if not hasattr(transformers.utils.import_utils, 'is_torch_fx_available'): + transformers.utils.import_utils.is_torch_fx_available = lambda: False +except Exception: + pass + +# BGE-Reranker-v2-m3 uses FlagReranker from FlagEmbedding (same package as BGE-M3 embedder) +# Fallback to sentence-transformers CrossEncoder if FlagEmbedding is unavailable +try: + from FlagEmbedding import FlagReranker + HAS_FLAG_RERANKER = True +except ImportError: + HAS_FLAG_RERANKER = False + logger.warning("FlagEmbedding not available for FlagReranker — trying CrossEncoder fallback.") + +try: + from sentence_transformers import CrossEncoder + HAS_CROSS_ENCODER = True +except ImportError: + HAS_CROSS_ENCODER = False + +if not HAS_FLAG_RERANKER and not HAS_CROSS_ENCODER: + logger.warning("Neither FlagEmbedding nor sentence-transformers installed. Reranking disabled.") + + +class BgeRerankerAdapter(RerankerPort): + """ + Multilingual reranker using BAAI/bge-reranker-v2-m3. + + Why bge-reranker-v2-m3: + - Same model family as BGE-M3 embedder — consistent representation space + - Natively multilingual: Arabic, Amharic, Somali, Swahili, French, English + - Significantly better than ms-marco-TinyBERT for non-English content + - Uses FlagReranker (FlagEmbedding) as primary, CrossEncoder as fallback + + Thread-safe lazy loading — model loads once on first rerank call. + """ + + def __init__(self): + self.model = None + self.model_name = settings.RERANKER_MODEL + self._lock = threading.Lock() + self._load_failed = False + + def _load_model(self): + if self.model is not None or self._load_failed: + return + with self._lock: + if self.model is not None or self._load_failed: + return + logger.info(f"Loading reranker model: {self.model_name}") + try: + if HAS_FLAG_RERANKER and "bge-reranker" in self.model_name.lower(): + # FlagReranker: use_fp16=True halves memory, normalize=True gives [0,1] scores + self.model = FlagReranker(self.model_name, use_fp16=True, normalize=True) + self._use_flag = True + logger.info(f"✅ Loaded {self.model_name} via FlagReranker (multilingual, fp16)") + elif HAS_CROSS_ENCODER: + self.model = CrossEncoder(self.model_name) + self._use_flag = False + logger.info(f"✅ Loaded {self.model_name} via CrossEncoder (fallback)") + else: + logger.error("No reranker backend available.") + self._load_failed = True + except Exception as e: + logger.error(f"Failed to load reranker model '{self.model_name}': {e}", exc_info=True) + self._load_failed = True + + def rerank(self, query: str, docs: List[Dict[str, Any]], top_n: int = 5) -> List[Dict[str, Any]]: + if not docs: + return [] + + if self.model is None: + self._load_model() + + # Fallback: sort by vector score if model unavailable + if self.model is None: + logger.warning("Reranker unavailable — falling back to vector score ordering.") + return sorted(docs, key=lambda x: x.get("score", 0), reverse=True)[:top_n] + + # Build (query, content) pairs. + # PERFORMANCE: Truncate content to 512 chars (~128 tokens) before scoring. + # The reranker only needs the opening paragraph to judge topical relevance. + # Full articles waste ~60% of inference time on boilerplate text. + MAX_CONTENT_CHARS = 512 + pairs = [] + valid_docs = [] + for doc in docs: + content = doc.get("content", "").strip() + if content: + truncated = content[:MAX_CONTENT_CHARS] + pairs.append([query, truncated]) + valid_docs.append(doc) + + if not pairs: + return [] + + try: + if getattr(self, "_use_flag", False): + # FlagReranker.compute_score — batch_size=64 safe now that + # content is truncated (much smaller tensors per pair) + scores = self.model.compute_score(pairs, batch_size=64) + if isinstance(scores, float): + scores = [scores] + else: + scores = self.model.predict(pairs) + if isinstance(scores, float): + scores = [scores] + + for i, doc in enumerate(valid_docs): + doc["rerank_score"] = float(scores[i]) + + valid_docs.sort(key=lambda x: x["rerank_score"], reverse=True) + return valid_docs[:top_n] + + except Exception as e: + logger.error(f"Reranker scoring failed: {e} — falling back to vector score ordering.") + return sorted(docs, key=lambda x: x.get("score", 0), reverse=True)[:top_n] diff --git a/src/infrastructure/adapters/clickhouse_adapter.py b/src/infrastructure/adapters/clickhouse_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..f4e7090dcee4c585360763c4b67f9b00ecc04447 --- /dev/null +++ b/src/infrastructure/adapters/clickhouse_adapter.py @@ -0,0 +1,53 @@ +import logging +from typing import Dict, Any, List, Optional +import clickhouse_connect +from src.core.config import settings +from src.core.ports.analytics_db_port import AnalyticsDbPort + +logger = logging.getLogger(__name__) + +class ClickHouseAdapter(AnalyticsDbPort): + def __init__(self): + try: + self.client = clickhouse_connect.get_client( + host=settings.CLICKHOUSE_HOST, + port=settings.CLICKHOUSE_PORT, + username=settings.CLICKHOUSE_USER, + password=settings.CLICKHOUSE_PASSWORD, + database=settings.CLICKHOUSE_DB, + secure=settings.CLICKHOUSE_SECURE + ) + + logger.info("Connected to ClickHouse") + except Exception as e: + logger.error(f"Failed to connect to ClickHouse: {e}") + self.client = None + + def query(self, sql: str, parameters: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + if not self.client: + return {"columns": [], "rows": []} + + try: + result = self.client.query(sql, parameters) + return { + "columns": list(result.column_names), + "rows": result.result_rows + } + except Exception as e: + logger.error(f"Query error: {e}") + return {"columns": [], "rows": [], "error": str(e)} + + def get_trends(self, days: int = 3, limit: int = 10) -> Dict[str, Any]: + sql = f""" + SELECT + JSONExtractString(entity, 'name') as entity_name, + avg(sentiment_score) as avg_sentiment, + count(*) as mention_count + FROM sentiment_results + ARRAY JOIN entities AS entity + WHERE published_at >= now() - INTERVAL {days} DAY + GROUP BY entity_name + ORDER BY mention_count DESC + LIMIT {limit} + """ + return self.query(sql) diff --git a/src/infrastructure/adapters/duckduckgo_adapter.py b/src/infrastructure/adapters/duckduckgo_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..57fb92f556260a31944767f18382e76f9c4bbf36 --- /dev/null +++ b/src/infrastructure/adapters/duckduckgo_adapter.py @@ -0,0 +1,309 @@ +""" +DuckDuckGo Live Search Adapter + +Provides real-time news search using DuckDuckGo's news search API. +Integrates with the hybrid RAG pipeline to supplement database results +with fresh, live content. + +Features: +- Async execution with timeout (2s default) +- Ethiopia-focused filtering +- Error handling and graceful fallbacks +- Result normalization for hybrid ranking +""" + +import logging +import asyncio +from typing import List, Dict, Any, Optional +from datetime import datetime +import traceback + +logger = logging.getLogger(__name__) + +try: + from duckduckgo_search import DDGS + HAS_DDGS = True +except ImportError: + HAS_DDGS = False + logger.warning("duckduckgo-search not installed. Live search disabled.") + + +class DuckDuckGoAdapter: + """ + Adapter for DuckDuckGo news search. + + Provides real-time news results to complement database search. + Designed to be fast (2s timeout) and resilient (graceful fallbacks). + """ + + def __init__(self, timeout: float = 2.0, max_results: int = 5): + """ + Initialize DuckDuckGo adapter. + + Args: + timeout: Maximum time to wait for results (seconds) + max_results: Maximum number of results to return + """ + self.timeout = timeout + self.max_results = max_results + self.ddgs = DDGS() if HAS_DDGS else None + + if not HAS_DDGS: + logger.error( + "DuckDuckGo search unavailable. " + "Install with: pip install duckduckgo-search" + ) + + async def search( + self, + query: str, + max_results: Optional[int] = None, + region: str = "et-en", # Ethiopia English + add_ethiopia_filter: bool = True + ) -> List[Dict[str, Any]]: + """ + Search DuckDuckGo news for the given query. + + Args: + query: Search query + max_results: Override default max_results + region: DuckDuckGo region code (et-en = Ethiopia English) + add_ethiopia_filter: Add "Ethiopia" to query for relevance + + Returns: + List of normalized search results + """ + if not self.ddgs: + logger.warning("DuckDuckGo unavailable - returning empty results") + return [] + + max_results = max_results or self.max_results + + # Add Ethiopia filter for relevance (optional) + search_query = f"{query} Ethiopia" if add_ethiopia_filter else query + + try: + # Run sync DuckDuckGo search in thread pool with timeout + loop = asyncio.get_event_loop() + results = await asyncio.wait_for( + loop.run_in_executor( + None, + self._search_sync, + search_query, + max_results, + region + ), + timeout=self.timeout + ) + + logger.info( + f"DuckDuckGo search completed: '{query[:50]}' → {len(results)} results" + ) + return results + + except asyncio.TimeoutError: + logger.warning( + f"DuckDuckGo search timeout ({self.timeout}s) for: '{query[:50]}'" + ) + return [] + + except Exception as e: + logger.error( + f"DuckDuckGo search error for '{query[:50]}': {e}\n" + f"{traceback.format_exc()}" + ) + return [] + + def _search_sync( + self, + query: str, + max_results: int, + region: str + ) -> List[Dict[str, Any]]: + """ + Synchronous DuckDuckGo search (runs in thread pool). + + Args: + query: Search query + max_results: Maximum results to return + region: DuckDuckGo region code + + Returns: + List of normalized results + """ + results = [] + + try: + # DuckDuckGo news search + raw_results = self.ddgs.news( + keywords=query, + region=region, + max_results=max_results + ) + + # Normalize results to common format + for r in raw_results: + normalized = self._normalize_result(r) + if normalized: + results.append(normalized) + + except Exception as e: + logger.error(f"DuckDuckGo API error: {e}") + raise + + return results + + def _normalize_result(self, raw_result: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """ + Normalize DuckDuckGo result to common format. + + Args: + raw_result: Raw result from DuckDuckGo API + + Returns: + Normalized result dict or None if invalid + """ + try: + # Extract fields (DuckDuckGo news format) + title = raw_result.get("title", "").strip() + url = raw_result.get("url", "").strip() + snippet = raw_result.get("body", "").strip() + source = raw_result.get("source", "").strip() + date_str = raw_result.get("date") + + # Validate required fields + if not title or not url: + logger.debug(f"Skipping invalid result: missing title or URL") + return None + + # Parse date + published_at = self._parse_date(date_str) + + # Calculate freshness score (live results are freshest) + freshness_score = self._calculate_freshness(published_at) + + return { + "title": title, + "url": url, + "content": snippet or title, # Use title if no snippet + "snippet": snippet, + "source": source or self._extract_domain(url), + "published_at": published_at, + "source_type": "live", + "is_live": True, + "freshness_score": freshness_score, + "language": "en", # DuckDuckGo returns English + "metadata": { + "title": title, + "url": url, + "source": source, + "published_at": published_at, + "search_engine": "duckduckgo" + } + } + + except Exception as e: + logger.warning(f"Failed to normalize result: {e}") + return None + + def _parse_date(self, date_str: Optional[str]) -> str: + """ + Parse date string to ISO format. + + Args: + date_str: Date string from DuckDuckGo + + Returns: + ISO format date string or current time if parsing fails + """ + if not date_str: + return datetime.utcnow().isoformat() + + try: + # DuckDuckGo returns ISO-like format + # Try parsing common formats + from dateutil import parser + parsed = parser.parse(date_str) + return parsed.isoformat() + except: + # Fallback to current time + return datetime.utcnow().isoformat() + + def _calculate_freshness(self, published_at: str) -> float: + """ + Calculate freshness score based on article age. + + Args: + published_at: ISO format date string + + Returns: + Freshness score (0.0 to 1.0) + """ + try: + pub_date = datetime.fromisoformat(published_at.replace('Z', '+00:00')) + age = datetime.utcnow() - pub_date.replace(tzinfo=None) + age_minutes = age.total_seconds() / 60 + + # Live results are very fresh + if age_minutes < 10: + return 1.0 # < 10 min + elif age_minutes < 60: + return 0.95 # < 1 hour + elif age_minutes < 360: + return 0.9 # < 6 hours + elif age_minutes < 1440: + return 0.85 # < 24 hours + else: + return 0.8 # Older but still from live search + except: + return 1.0 # Default to fresh for live results + + def _extract_domain(self, url: str) -> str: + """ + Extract domain name from URL. + + Args: + url: Full URL + + Returns: + Domain name (e.g., "bbc.com") + """ + try: + from urllib.parse import urlparse + parsed = urlparse(url) + domain = parsed.netloc + # Remove www. prefix + if domain.startswith("www."): + domain = domain[4:] + return domain + except: + return "unknown" + + def is_available(self) -> bool: + """ + Check if DuckDuckGo search is available. + + Returns: + True if available, False otherwise + """ + return HAS_DDGS and self.ddgs is not None + + +# Module-level singleton for easy import +_default_adapter = None + +def get_duckduckgo_adapter(timeout: float = 2.0, max_results: int = 5) -> DuckDuckGoAdapter: + """ + Get or create the default DuckDuckGo adapter instance. + + Args: + timeout: Search timeout in seconds + max_results: Maximum results to return + + Returns: + DuckDuckGoAdapter instance + """ + global _default_adapter + if _default_adapter is None: + _default_adapter = DuckDuckGoAdapter(timeout=timeout, max_results=max_results) + return _default_adapter diff --git a/src/infrastructure/adapters/gemini_adapter.py b/src/infrastructure/adapters/gemini_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..cfcde9afa3ef5c5cf0616a34a89010cc64f4b69c --- /dev/null +++ b/src/infrastructure/adapters/gemini_adapter.py @@ -0,0 +1,126 @@ +import json +import logging +import requests +from typing import AsyncGenerator + +from src.core.ports.llm_port import LlmPort +from src.core.config import settings + +logger = logging.getLogger(__name__) + +# Try these in order until one works — covers all API key vintages +_MODEL_FALLBACK_CHAIN = [ + "gemini-2.0-flash", + "gemini-2.0-flash-lite", + "gemini-1.5-flash", + "gemini-1.5-flash-latest", + "gemini-flash-latest", + "gemini-pro", +] + +_BASE = "https://generativelanguage.googleapis.com/v1beta/models" + + +class GeminiAdapter(LlmPort): + """ + Google Gemini adapter using the native generateContent REST API. + Auto-discovers the first working model for the given API key. + Free tier: 15 RPM, 1M TPM, 1500 RPD — https://aistudio.google.com/apikey + """ + + def __init__(self): + self.api_key = settings.GEMINI_API_KEY + self.model: str | None = None + + if not self.api_key or self.api_key == "your-gemini-api-key-here": + self.api_key = None + logger.warning("GEMINI_API_KEY not set — Gemini adapter disabled.") + return + + # Prefer whatever is explicitly configured — but reject obvious non-model values + configured = (settings.GEMINI_MODEL or "").strip() + # Reject if it looks like an API key (starts with known prefixes or is too long) + if configured and ( + configured.startswith("key_") or + configured.startswith("gsk_") or + configured.startswith("AIza") or + len(configured) > 60 + ): + logger.warning(f"GEMINI_MODEL='{configured}' looks like an API key, ignoring it.") + configured = "" + + chain = ([configured] if configured else []) + _MODEL_FALLBACK_CHAIN + + for model in chain: + url = f"{_BASE}/{model}:generateContent" + try: + resp = requests.post( + url, + params={"key": self.api_key}, + json={"contents": [{"parts": [{"text": "hi"}]}]}, + timeout=10, + ) + if resp.status_code == 200: + self.model = model + logger.info(f"✅ Gemini adapter ready — model: {self.model}") + break + elif resp.status_code == 404: + logger.debug(f"Gemini model {model} not available (404), trying next...") + continue + else: + # 429 rate-limit etc — model exists, use it + self.model = model + logger.info(f"✅ Gemini adapter ready — model: {self.model} (status {resp.status_code})") + break + except Exception as e: + logger.debug(f"Gemini probe failed for {model}: {e}") + continue + + if not self.model: + logger.error("❌ No working Gemini model found for this API key.") + + def _url(self) -> str: + return f"{_BASE}/{self.model}:generateContent" + + def _call(self, prompt: str) -> str: + resp = requests.post( + self._url(), + params={"key": self.api_key}, + json={"contents": [{"parts": [{"text": prompt}]}]}, + timeout=60, + ) + resp.raise_for_status() + return resp.json()["candidates"][0]["content"]["parts"][0]["text"] + + def generate(self, prompt: str) -> str: + if not self.api_key or not self.model: + return "Gemini not available." + try: + return self._call(prompt) + except requests.HTTPError as e: + code = e.response.status_code if e.response else 0 + if code == 429: + return "Gemini rate limit reached. Please try again shortly." + logger.error(f"Gemini HTTP error: {e}") + return f"Gemini error: {e}" + except Exception as e: + logger.error(f"Gemini generate error: {e}") + return f"Gemini error: {e}" + + async def generate_stream(self, prompt: str) -> AsyncGenerator[str, None]: + if not self.api_key or not self.model: + yield f"data: {json.dumps({'token': 'Gemini not available.'})}\n\n" + yield "data: [DONE]\n\n" + return + try: + answer = self._call(prompt) + yield f"data: {json.dumps({'token': answer})}\n\n" + yield "data: [DONE]\n\n" + except requests.HTTPError as e: + code = e.response.status_code if e.response else 0 + msg = "Gemini rate limit reached." if code == 429 else f"Gemini error: {e}" + yield f"data: {json.dumps({'token': msg})}\n\n" + yield "data: [DONE]\n\n" + except Exception as e: + yield f"data: {json.dumps({'token': f'Gemini error: {e}'})}\n\n" + yield "data: [DONE]\n\n" diff --git a/src/infrastructure/adapters/groq_adapter.py b/src/infrastructure/adapters/groq_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..c0a1ebf1fa793d1b4d6d319ab125d0b2993dd6e3 --- /dev/null +++ b/src/infrastructure/adapters/groq_adapter.py @@ -0,0 +1,114 @@ +from typing import AsyncGenerator +from src.core.ports.llm_port import LlmPort +from langchain_openai import ChatOpenAI +from src.core.config import settings +import json +import logging + +logger = logging.getLogger(__name__) + +class GroqAdapter(LlmPort): + """ + Groq LLM adapter using the OpenAI-compatible API. + Free tier: ~14,400 RPD on llama-3.3-70b-versatile, 200+ tok/s. + Get your key: https://console.groq.com/keys + """ + GROQ_BASE_URL = "https://api.groq.com/openai/v1" + DEFAULT_MODEL = "llama-3.3-70b-versatile" + + def __init__(self): + self.llm = None + if settings.GROQ_API_KEY and settings.GROQ_API_KEY != "your-groq-api-key-here": + try: + self.llm = ChatOpenAI( + api_key=settings.GROQ_API_KEY, + base_url=self.GROQ_BASE_URL, + model=settings.GROQ_MODEL or self.DEFAULT_MODEL, + temperature=0.2, + max_tokens=1024, + ) + logger.info(f"✅ Groq adapter initialized with model: {settings.GROQ_MODEL or self.DEFAULT_MODEL}") + except Exception as e: + logger.error(f"Failed to initialize Groq adapter: {e}") + else: + logger.warning("GROQ_API_KEY not set — Groq adapter disabled.") + + def _is_rate_limit(self, error_msg: str) -> bool: + return "rate_limit" in error_msg.lower() or "429" in error_msg or "quota" in error_msg.lower() + + def generate(self, prompt: str) -> str: + if not self.llm: + return "Groq API key not configured." + try: + return self.llm.invoke(prompt).content + except Exception as e: + error_msg = str(e) + if self._is_rate_limit(error_msg): + logger.warning("Groq rate limit hit — trying fallback providers.") + return self._fallback_generate(prompt) + logger.error(f"Groq generate error: {e}") + return f"Error generating response: {error_msg}" + + def _fallback_generate(self, prompt: str) -> str: + """Try Gemini → Together AI → error message.""" + if settings.GEMINI_API_KEY and settings.GEMINI_API_KEY != "your-gemini-api-key-here": + try: + from src.infrastructure.adapters.gemini_adapter import GeminiAdapter + result = GeminiAdapter().generate(prompt) + if not result.startswith("Gemini"): + logger.info("Gemini fallback succeeded.") + return result + logger.warning(f"Gemini fallback returned error: {result}") + except Exception as e: + logger.warning(f"Gemini fallback failed: {e}") + + # Try Together AI + if settings.TOGETHER_API_KEY and settings.TOGETHER_API_KEY != "your-together-api-key-here": + try: + from langchain_openai import ChatOpenAI + together = ChatOpenAI( + api_key=settings.TOGETHER_API_KEY, + base_url="https://api.together.xyz/v1", + model=settings.TOGETHER_MODEL or "meta-llama/Llama-3.3-70B-Instruct-Turbo", + temperature=0.2, + max_tokens=1024, + ) + logger.info("Falling back to Together AI.") + return together.invoke(prompt).content + except Exception as e: + logger.warning(f"Together AI fallback failed: {e}") + + # Try HuggingFace Inference API + if settings.HF_TOKEN and settings.HF_TOKEN != "your-hf-token-here": + try: + from src.infrastructure.adapters.huggingface_adapter import HuggingFaceAdapter + result = HuggingFaceAdapter().generate(prompt) + if result and not result.startswith("HuggingFace"): + logger.info("Falling back to HuggingFace.") + return result + logger.warning(f"HuggingFace fallback returned: {result}") + except Exception as e: + logger.warning(f"HuggingFace fallback failed: {e}") + + return "All LLM providers are currently unavailable. Please try again in a few minutes." + + async def generate_stream(self, prompt: str) -> AsyncGenerator[str, None]: + if not self.llm: + yield f"data: {json.dumps({'token': 'Groq API key not configured.'})}\n\n" + yield "data: [DONE]\n\n" + return + + try: + for chunk in self.llm.stream(prompt): + if hasattr(chunk, 'content') and chunk.content: + yield f"data: {json.dumps({'token': chunk.content})}\n\n" + yield "data: [DONE]\n\n" + except Exception as e: + error_msg = str(e) + if self._is_rate_limit(error_msg): + logger.warning("Groq rate limit hit during stream — falling back.") + fallback_answer = self._fallback_generate(prompt) + yield f"data: {json.dumps({'token': fallback_answer})}\n\n" + else: + yield f"data: {json.dumps({'token': f'Error: {error_msg}'})}\n\n" + yield "data: [DONE]\n\n" diff --git a/src/infrastructure/adapters/huggingface_adapter.py b/src/infrastructure/adapters/huggingface_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..0e99ab57cde1e536599c9623eecd5a78c132c12b --- /dev/null +++ b/src/infrastructure/adapters/huggingface_adapter.py @@ -0,0 +1,84 @@ +import json +import logging +from typing import AsyncGenerator + +from src.core.ports.llm_port import LlmPort +from src.core.config import settings + +logger = logging.getLogger(__name__) + +# HF Inference Providers router — OpenAI-compatible +# Only models with chat/instruct fine-tuning work on the chat completions endpoint +_ROUTER_URL = "https://router.huggingface.co/v1" + +# Chat-compatible instruct models available on HF router (free tier) +_DEFAULT_MODEL = "meta-llama/Llama-3.1-8B-Instruct" +_FALLBACK_MODELS = [ + "meta-llama/Llama-3.1-8B-Instruct", + "mistralai/Mistral-Nemo-Instruct-2407", + "Qwen/Qwen2.5-7B-Instruct", + "microsoft/Phi-3.5-mini-instruct", + "HuggingFaceH4/zephyr-7b-beta", +] + + +class HuggingFaceAdapter(LlmPort): + """ + HuggingFace Inference Providers adapter — free tier with HF token. + Uses the new router.huggingface.co OpenAI-compatible endpoint. + + Get token: https://huggingface.co/settings/tokens + → New token → Fine-grained → enable "Make calls to Inference Providers" + Set HF_TOKEN env var. Optionally set HF_MODEL to choose a specific model. + + Free limits: generous free tier, no credit card needed. + """ + + def __init__(self): + self.token = settings.HF_TOKEN + self.model = (settings.HF_MODEL or _DEFAULT_MODEL).strip() + self.llm = None + + if self.token and self.token != "your-hf-token-here": + try: + from langchain_openai import ChatOpenAI + self.llm = ChatOpenAI( + api_key=self.token, + base_url=_ROUTER_URL, + model=f"{self.model}:fastest", # :fastest = auto-select best provider + temperature=0.2, + max_tokens=1024, + ) + logger.info(f"✅ HuggingFace adapter ready — model: {self.model}:fastest") + except Exception as e: + logger.error(f"Failed to initialize HuggingFace adapter: {e}") + else: + logger.warning("HF_TOKEN not set — HuggingFace adapter disabled.") + + def generate(self, prompt: str) -> str: + if not self.llm: + return "HuggingFace token not configured." + try: + return self.llm.invoke(prompt).content + except Exception as e: + error_msg = str(e) + if "429" in error_msg or "rate" in error_msg.lower(): + return "HuggingFace rate limit reached. Please try again shortly." + logger.error(f"HuggingFace generate error: {e}") + return f"HuggingFace error: {error_msg}" + + async def generate_stream(self, prompt: str) -> AsyncGenerator[str, None]: + if not self.llm: + yield f"data: {json.dumps({'token': 'HuggingFace token not configured.'})}\n\n" + yield "data: [DONE]\n\n" + return + try: + for chunk in self.llm.stream(prompt): + if hasattr(chunk, "content") and chunk.content: + yield f"data: {json.dumps({'token': chunk.content})}\n\n" + yield "data: [DONE]\n\n" + except Exception as e: + error_msg = str(e) + msg = "HuggingFace rate limit reached." if "429" in error_msg else f"HuggingFace error: {error_msg}" + yield f"data: {json.dumps({'token': msg})}\n\n" + yield "data: [DONE]\n\n" diff --git a/src/infrastructure/adapters/intent_classifier.py b/src/infrastructure/adapters/intent_classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..05f77b58e5e6aeb4483add42fa21bea208876250 --- /dev/null +++ b/src/infrastructure/adapters/intent_classifier.py @@ -0,0 +1,137 @@ +import logging +import re +import threading + +logger = logging.getLogger(__name__) + +# ── Instant keyword shortcuts ───────────────────────────────────────────────── + +_SMALL_TALK_EXACT = { + "hi", "hello", "hey", "thanks", "thank you", "bye", "goodbye", + "good morning", "good afternoon", "good evening", "sup", "yo", + "hello there", "hey there", "hi there", "greetings", + # frustration / profanity + "wtf", "lol", "lmao", "omg", "damn", "shit", "fuck", + "for fuck sake", "for fucks sake", "oh my god", "are you kidding", + "seriously", "come on", "ugh", "argh", +} +_SMALL_TALK_PREFIX = ( + "how are you", "what are you", "who are you", "what can you do", + "tell me a joke", "make me laugh", "what's up", "whats up", + "for fuck", "for fucks", "what the fuck", "what the hell", + "are you serious", "you must be", "hello ", "hi ", "hey ", +) + +# Temporal patterns → always NEWS (user is asking about time-scoped news) +_TEMPORAL_PATTERNS = re.compile( + r"\b(" + r"today|yesterday|tomorrow|tonight|" + r"this (week|month|year|morning|evening|afternoon)|" + r"last (week|month|year|night|monday|tuesday|wednesday|thursday|friday|saturday|sunday)|" + r"next (week|month|year)|" + r"past (\d+ )?(day|days|week|weeks|month|months|year|years)|" + r"recent(ly)?|latest|breaking|just (now|happened|announced)|" + r"(monday|tuesday|wednesday|thursday|friday|saturday|sunday)|" + r"january|february|march|april|may|june|july|august|september|october|november|december|" + r"\d{4}|" # year like 2024, 2025 + r"\d+(st|nd|rd|th)" # ordinal like 1st, 2nd + r")\b", + re.IGNORECASE +) + + +class IntentClassifier: + """ + Local zero-shot intent classifier using MoritzLaurer/DeBERTa-v3-small-mnli. + - 140 MB model, ~20 ms inference, no API calls, no rate limits. + - Lazy-loaded on first use so startup is not blocked. + - Thread-safe singleton load. + + Classification priority: + 1. Small-talk exact/prefix match → OTHER (instant) + 2. Temporal pattern match → NEWS (instant, handles "this week", "yesterday", etc.) + 3. DeBERTa NLI model → NEWS or OTHER (~20ms) + 4. Keyword fallback if model failed to load + 5. Default → NEWS (always prefer RAG over hallucination) + """ + MODEL_NAME = "MoritzLaurer/deberta-v3-base-zeroshot-v2.0" + + def __init__(self): + self._pipe = None + self._lock = threading.Lock() + self._load_failed = False + + def _load(self): + if self._pipe is not None or self._load_failed: + return + with self._lock: + if self._pipe is not None or self._load_failed: + return + try: + from transformers import pipeline + logger.info(f"Loading intent classifier: {self.MODEL_NAME} ...") + self._pipe = pipeline( + "zero-shot-classification", + model=self.MODEL_NAME, + device=-1, # CPU + multi_label=False, + ) + logger.info("✅ Intent classifier loaded.") + except Exception as e: + logger.error(f"❌ Failed to load intent classifier: {e}. Falling back to keyword matching.") + self._load_failed = True + + def classify(self, query: str) -> str: + """Returns 'NEWS' or 'OTHER'.""" + q = query.strip().lower() + + # 1. Instant small-talk shortcuts + if q in _SMALL_TALK_EXACT: + logger.debug(f"Intent: OTHER (small-talk exact) — '{q}'") + return "OTHER" + if any(q.startswith(p) for p in _SMALL_TALK_PREFIX): + logger.debug(f"Intent: OTHER (small-talk prefix) — '{q}'") + return "OTHER" + + # 2. Temporal pattern → always NEWS + if _TEMPORAL_PATTERNS.search(query): + logger.debug(f"Intent: NEWS (temporal pattern) — '{query[:60]}'") + return "NEWS" + + # 3. DeBERTa NLI model + self._load() + if self._pipe is not None: + try: + result = self._pipe( + query, + candidate_labels=[ + "news, current events, politics, economy, sports, technology, world affairs", + "small talk, greeting, joke, or general question unrelated to news", + ], + hypothesis_template="This message is about {}.", + ) + top_label = result["labels"][0] + score = result["scores"][0] + intent = "NEWS" if "news" in top_label else "OTHER" + logger.debug(f"Intent: {intent} (DeBERTa score={score:.2f}) — '{query[:60]}'") + return intent + except Exception as e: + logger.warning(f"Intent classifier inference failed: {e}. Defaulting to NEWS.") + + # 4. Keyword fallback + news_signals = [ + "latest", "recent", "news", "update", "development", "what happened", + "who is", "what is", "when did", "why did", "how did", "report", + "conflict", "election", "economy", "war", "crisis", "deal", + "agreement", "president", "minister", "market", "price", + "attack", "protest", "africa", "ethiopia", + ] + if any(s in q for s in news_signals): + return "NEWS" + + # 5. Default — always prefer RAG over hallucination + return "NEWS" + + +# Module-level singleton — shared across all requests +intent_classifier = IntentClassifier() diff --git a/src/infrastructure/adapters/intent_classifier_v2.py b/src/infrastructure/adapters/intent_classifier_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..f386ad58290c1596bdeb56a4841ae54673276f1d --- /dev/null +++ b/src/infrastructure/adapters/intent_classifier_v2.py @@ -0,0 +1,512 @@ +""" +Production-Grade Intent Classifier v2 + +Enhanced intent classification for hybrid RAG system with: +- Multi-class classification (NEWS_TEMPORAL, NEWS_HISTORICAL, NEWS_GENERAL, OTHER) +- Confidence scoring with thresholds +- Query complexity analysis +- Metrics tracking +- Fallback strategies +- Thread-safe lazy loading + +Classification Hierarchy: +1. Instant shortcuts (regex patterns) - 0ms +2. DeBERTa zero-shot NLI - ~20ms +3. Keyword fallback - 0ms +4. Default (NEWS_GENERAL) - safe fallback +""" + +import logging +import re +import threading +from typing import Dict, Any, Optional, Tuple +from dataclasses import dataclass +from datetime import datetime +import time + +logger = logging.getLogger(__name__) + + +# ═══════════════════════════════════════════════════════════════════════════ +# PATTERN DEFINITIONS +# ═══════════════════════════════════════════════════════════════════════════ + +# Small talk patterns (instant OTHER classification) +_SMALL_TALK_EXACT = { + "hi", "hello", "hey", "thanks", "thank you", "bye", "goodbye", + "good morning", "good afternoon", "good evening", "sup", "yo", + "hello there", "hey there", "hi there", "greetings", "howdy", + # Frustration / profanity + "wtf", "lol", "lmao", "omg", "damn", "shit", "fuck", + "for fuck sake", "for fucks sake", "oh my god", "are you kidding", + "seriously", "come on", "ugh", "argh", "ffs", +} + +_SMALL_TALK_PREFIX = ( + "how are you", "what are you", "who are you", "what can you do", + "tell me a joke", "make me laugh", "what's up", "whats up", + "for fuck", "for fucks", "what the fuck", "what the hell", + "are you serious", "you must be", "hello ", "hi ", "hey ", + "can you help", "i need help", "help me", +) + +# Temporal patterns (instant NEWS_TEMPORAL classification) +_TEMPORAL_PATTERNS = re.compile( + r"\b(" + r"today|yesterday|tomorrow|tonight|now|currently|" + r"this (week|month|year|morning|evening|afternoon)|" + r"last (week|month|year|night|hour|" + r"monday|tuesday|wednesday|thursday|friday|saturday|sunday)|" + r"next (week|month|year)|" + r"past (\d+ )?(hour|hours|day|days|week|weeks|month|months)|" + r"recent(ly)?|latest|breaking|just (now|happened|announced|reported)|" + r"(monday|tuesday|wednesday|thursday|friday|saturday|sunday)|" + r"january|february|march|april|may|june|july|august|september|october|november|december|" + r"\d{4}|" # year like 2024, 2025 + r"\d+(st|nd|rd|th)|" # ordinal like 1st, 2nd + r"current|ongoing|live|real[- ]?time" + r")\b", + re.IGNORECASE +) + +# Historical patterns (instant NEWS_HISTORICAL classification) +_HISTORICAL_PATTERNS = re.compile( + r"\b(" + r"history|historical|background|context|origin|" + r"how (did|was|were)|why (did|was|were)|" + r"what (led to|caused|resulted in)|" + r"timeline|chronology|evolution|development|" + r"past|previous|former|old|ancient|" + r"analysis|overview|summary|explanation|" + r"tell me about|explain|describe" + r")\b", + re.IGNORECASE +) + +# News signal keywords (fallback NEWS classification) +_NEWS_KEYWORDS = { + "news", "report", "update", "development", "announcement", + "conflict", "war", "peace", "crisis", "deal", "agreement", + "election", "vote", "campaign", "president", "minister", "government", + "economy", "market", "price", "inflation", "trade", + "protest", "demonstration", "strike", "rally", + "attack", "violence", "security", "military", + "ethiopia", "addis", "abiy", "fano", "tigray", "amhara", "oromia", + "africa", "african", "horn of africa", +} + + +# ═══════════════════════════════════════════════════════════════════════════ +# DATA CLASSES +# ═══════════════════════════════════════════════════════════════════════════ + +@dataclass +class IntentResult: + """ + Intent classification result with confidence and metadata. + """ + intent: str # NEWS_TEMPORAL, NEWS_HISTORICAL, NEWS_GENERAL, OTHER + confidence: float # 0.0 to 1.0 + method: str # "regex", "deberta", "keyword", "default" + inference_time_ms: float # Time taken for classification + query_complexity: str # "simple", "medium", "complex" + should_use_live: bool # Recommendation for live search + should_use_db: bool # Recommendation for DB search + metadata: Dict[str, Any] # Additional info + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for logging/caching""" + return { + "intent": self.intent, + "confidence": self.confidence, + "method": self.method, + "inference_time_ms": self.inference_time_ms, + "query_complexity": self.query_complexity, + "should_use_live": self.should_use_live, + "should_use_db": self.should_use_db, + "metadata": self.metadata + } + + +# ═══════════════════════════════════════════════════════════════════════════ +# PRODUCTION-GRADE INTENT CLASSIFIER +# ═══════════════════════════════════════════════════════════════════════════ + +class IntentClassifierV2: + """ + Production-grade intent classifier with multi-class classification. + + Intent Classes: + - NEWS_TEMPORAL: Time-sensitive news queries (use live search) + - NEWS_HISTORICAL: Historical/background queries (use DB only) + - NEWS_GENERAL: General news queries (use hybrid) + - OTHER: Non-news queries (skip search) + + Features: + - Multi-stage classification (regex → DeBERTa → keyword → default) + - Confidence scoring with thresholds + - Query complexity analysis + - Metrics tracking + - Thread-safe lazy loading + """ + + MODEL_NAME = "MoritzLaurer/deberta-v3-base-zeroshot-v2.0" + + # Confidence thresholds + HIGH_CONFIDENCE = 0.75 + MEDIUM_CONFIDENCE = 0.50 + LOW_CONFIDENCE = 0.30 + + def __init__(self): + self._pipe = None + self._lock = threading.Lock() + self._load_failed = False + + # Metrics tracking + self._metrics = { + "total_classifications": 0, + "by_intent": {"NEWS_TEMPORAL": 0, "NEWS_HISTORICAL": 0, "NEWS_GENERAL": 0, "OTHER": 0}, + "by_method": {"regex": 0, "deberta": 0, "keyword": 0, "default": 0}, + "avg_inference_time_ms": 0.0, + "total_inference_time_ms": 0.0, + } + + def _load(self): + """Lazy load DeBERTa model (thread-safe)""" + if self._pipe is not None or self._load_failed: + return + + with self._lock: + if self._pipe is not None or self._load_failed: + return + + try: + from transformers import pipeline + logger.info(f"Loading intent classifier: {self.MODEL_NAME} ...") + + self._pipe = pipeline( + "zero-shot-classification", + model=self.MODEL_NAME, + device=-1, # CPU (use device=0 for GPU) + multi_label=False, + ) + + logger.info("✅ Intent classifier v2 loaded successfully") + + except Exception as e: + logger.error(f"❌ Failed to load intent classifier: {e}") + self._load_failed = True + + def classify(self, query: str, use_cache: bool = True) -> IntentResult: + """ + Classify query intent with confidence scoring. + + Args: + query: User query string + use_cache: Whether to use cached results (if available) + + Returns: + IntentResult with classification and metadata + """ + start_time = time.time() + + # Normalize query + query_normalized = query.strip() + query_lower = query_normalized.lower() + + # Analyze query complexity + complexity = self._analyze_complexity(query_normalized) + + # ── Stage 1: Instant Regex Shortcuts ────────────────────────────────── + + # Check small talk (OTHER) + if query_lower in _SMALL_TALK_EXACT: + return self._create_result( + intent="OTHER", + confidence=1.0, + method="regex_exact", + start_time=start_time, + complexity=complexity, + metadata={"pattern": "small_talk_exact"} + ) + + if any(query_lower.startswith(p) for p in _SMALL_TALK_PREFIX): + return self._create_result( + intent="OTHER", + confidence=0.95, + method="regex_prefix", + start_time=start_time, + complexity=complexity, + metadata={"pattern": "small_talk_prefix"} + ) + + # Check temporal patterns (NEWS_TEMPORAL) + temporal_match = _TEMPORAL_PATTERNS.search(query_normalized) + if temporal_match: + return self._create_result( + intent="NEWS_TEMPORAL", + confidence=0.90, + method="regex_temporal", + start_time=start_time, + complexity=complexity, + metadata={"pattern": "temporal", "matched": temporal_match.group(0)} + ) + + # Check historical patterns (NEWS_HISTORICAL) + historical_match = _HISTORICAL_PATTERNS.search(query_normalized) + if historical_match: + return self._create_result( + intent="NEWS_HISTORICAL", + confidence=0.85, + method="regex_historical", + start_time=start_time, + complexity=complexity, + metadata={"pattern": "historical", "matched": historical_match.group(0)} + ) + + # ── Stage 2: DeBERTa Zero-Shot Classification ───────────────────────── + + self._load() + + if self._pipe is not None: + try: + result = self._classify_with_deberta(query_normalized) + + if result: + return self._create_result( + intent=result["intent"], + confidence=result["confidence"], + method="deberta", + start_time=start_time, + complexity=complexity, + metadata=result["metadata"] + ) + + except Exception as e: + logger.warning(f"DeBERTa classification failed: {e}") + + # ── Stage 3: Keyword Fallback ───────────────────────────────────────── + + keyword_result = self._classify_with_keywords(query_lower) + if keyword_result: + return self._create_result( + intent=keyword_result["intent"], + confidence=keyword_result["confidence"], + method="keyword", + start_time=start_time, + complexity=complexity, + metadata=keyword_result["metadata"] + ) + + # ── Stage 4: Default (Safe Fallback) ────────────────────────────────── + + return self._create_result( + intent="NEWS_GENERAL", + confidence=0.50, + method="default", + start_time=start_time, + complexity=complexity, + metadata={"reason": "no_pattern_match"} + ) + + def _classify_with_deberta(self, query: str) -> Optional[Dict[str, Any]]: + """ + Classify using DeBERTa zero-shot model. + + Returns dict with intent, confidence, metadata or None if failed. + """ + try: + # Multi-class classification + result = self._pipe( + query, + candidate_labels=[ + "breaking news, current events, today's news, latest updates, real-time news", + "historical background, past events, context, analysis, explanation", + "general news, politics, economy, world affairs, sports, technology", + "small talk, greeting, joke, general question unrelated to news", + ], + hypothesis_template="This message is about {}.", + ) + + top_label = result["labels"][0] + top_score = result["scores"][0] + + # Map label to intent + if "breaking" in top_label or "current" in top_label or "latest" in top_label: + intent = "NEWS_TEMPORAL" + elif "historical" in top_label or "background" in top_label or "context" in top_label: + intent = "NEWS_HISTORICAL" + elif "general news" in top_label or "politics" in top_label: + intent = "NEWS_GENERAL" + elif "small talk" in top_label or "greeting" in top_label: + intent = "OTHER" + else: + intent = "NEWS_GENERAL" # Default to general news + + # Only return if confidence is above threshold + if top_score >= self.LOW_CONFIDENCE: + return { + "intent": intent, + "confidence": float(top_score), + "metadata": { + "top_label": top_label, + "all_scores": { + label: float(score) + for label, score in zip(result["labels"], result["scores"]) + } + } + } + + return None + + except Exception as e: + logger.error(f"DeBERTa inference error: {e}") + return None + + def _classify_with_keywords(self, query_lower: str) -> Optional[Dict[str, Any]]: + """ + Classify using keyword matching (fallback). + + Returns dict with intent, confidence, metadata or None if no match. + """ + # Count news keyword matches + matches = [kw for kw in _NEWS_KEYWORDS if kw in query_lower] + + if matches: + # More matches = higher confidence + confidence = min(0.70, 0.50 + (len(matches) * 0.05)) + + return { + "intent": "NEWS_GENERAL", + "confidence": confidence, + "metadata": { + "matched_keywords": matches[:5], # Top 5 + "match_count": len(matches) + } + } + + return None + + def _analyze_complexity(self, query: str) -> str: + """ + Analyze query complexity based on length and structure. + + Returns: "simple", "medium", or "complex" + """ + word_count = len(query.split()) + char_count = len(query) + + # Check for question words + question_words = ["what", "when", "where", "who", "why", "how"] + has_question = any(qw in query.lower() for qw in question_words) + + if word_count <= 3 and not has_question: + return "simple" + elif word_count <= 10: + return "medium" + else: + return "complex" + + def _create_result( + self, + intent: str, + confidence: float, + method: str, + start_time: float, + complexity: str, + metadata: Dict[str, Any] + ) -> IntentResult: + """ + Create IntentResult with recommendations and metrics. + """ + inference_time_ms = (time.time() - start_time) * 1000 + + # Determine search recommendations + should_use_live = intent == "NEWS_TEMPORAL" + should_use_db = intent in ["NEWS_TEMPORAL", "NEWS_HISTORICAL", "NEWS_GENERAL"] + + # Update metrics + self._update_metrics(intent, method, inference_time_ms) + + result = IntentResult( + intent=intent, + confidence=confidence, + method=method, + inference_time_ms=inference_time_ms, + query_complexity=complexity, + should_use_live=should_use_live, + should_use_db=should_use_db, + metadata=metadata + ) + + # Log classification + logger.debug( + f"Intent: {intent} (conf={confidence:.2f}, method={method}, " + f"time={inference_time_ms:.1f}ms, complexity={complexity})" + ) + + return result + + def _update_metrics(self, intent: str, method: str, inference_time_ms: float): + """Update classification metrics""" + self._metrics["total_classifications"] += 1 + self._metrics["by_intent"][intent] = self._metrics["by_intent"].get(intent, 0) + 1 + self._metrics["by_method"][method] = self._metrics["by_method"].get(method, 0) + 1 + self._metrics["total_inference_time_ms"] += inference_time_ms + self._metrics["avg_inference_time_ms"] = ( + self._metrics["total_inference_time_ms"] / self._metrics["total_classifications"] + ) + + def get_metrics(self) -> Dict[str, Any]: + """Get classification metrics for monitoring""" + return dict(self._metrics) + + def reset_metrics(self): + """Reset metrics (useful for testing)""" + self._metrics = { + "total_classifications": 0, + "by_intent": {"NEWS_TEMPORAL": 0, "NEWS_HISTORICAL": 0, "NEWS_GENERAL": 0, "OTHER": 0}, + "by_method": {"regex": 0, "deberta": 0, "keyword": 0, "default": 0}, + "avg_inference_time_ms": 0.0, + "total_inference_time_ms": 0.0, + } + + +# ═══════════════════════════════════════════════════════════════════════════ +# MODULE-LEVEL SINGLETON +# ═══════════════════════════════════════════════════════════════════════════ + +# Global singleton instance +intent_classifier_v2 = IntentClassifierV2() + + +# ═══════════════════════════════════════════════════════════════════════════ +# BACKWARD COMPATIBILITY WRAPPER +# ═══════════════════════════════════════════════════════════════════════════ + +class IntentClassifier: + """ + Backward-compatible wrapper for existing code. + Maps v2 multi-class intents to v1 binary (NEWS/OTHER). + """ + + def __init__(self): + self._classifier_v2 = intent_classifier_v2 + + def classify(self, query: str) -> str: + """ + Classify query intent (backward compatible). + + Returns: "NEWS" or "OTHER" + """ + result = self._classifier_v2.classify(query) + + # Map v2 intents to v1 binary + if result.intent == "OTHER": + return "OTHER" + else: + return "NEWS" # All NEWS_* intents map to NEWS + + +# Backward-compatible singleton +intent_classifier = IntentClassifier() diff --git a/src/infrastructure/adapters/ollama_adapter.py b/src/infrastructure/adapters/ollama_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..ff98c6405df1717d3d8fdb6e4bbe358bee969cae --- /dev/null +++ b/src/infrastructure/adapters/ollama_adapter.py @@ -0,0 +1,69 @@ +import json +import logging +import requests +import aiohttp +from typing import AsyncGenerator + +from src.core.ports.llm_port import LlmPort +from src.core.config import settings + +logger = logging.getLogger(__name__) + +class OllamaAdapter(LlmPort): + def __init__(self): + self.host = settings.OLLAMA_HOST.rstrip('/') + self.model = settings.OLLAMA_MODEL + self.api_url = f"{self.host}/api/generate" + + def generate(self, prompt: str) -> str: + print(f"OLLAMA DEBUG: Using model: {self.model}") + print(f"OLLAMA DEBUG: API URL: {self.api_url}") + print(f"OLLAMA DEBUG: Prompt length: {len(prompt)} chars") + + payload = { + "model": self.model, + "prompt": prompt, + "stream": False + } + + try: + print(f"OLLAMA DEBUG: Sending request to Ollama...") + response = requests.post(self.api_url, json=payload, timeout=180) + print(f"OLLAMA DEBUG: Response status: {response.status_code}") + response.raise_for_status() + data = response.json() + result = data.get("response", "") + print(f"OLLAMA DEBUG: Generated response length: {len(result)} chars") + return result + except requests.exceptions.RequestException as e: + logger.error(f"Error communicating with Ollama: {e}") + return f"Error communicating with local LLM: {e}" + + async def generate_stream(self, prompt: str) -> AsyncGenerator[str, None]: + payload = { + "model": self.model, + "prompt": prompt, + "stream": True + } + + try: + async with aiohttp.ClientSession() as session: + async with session.post(self.api_url, json=payload) as response: + response.raise_for_status() + async for line in response.content: + if line: + try: + data = json.loads(line.decode('utf-8')) + token = data.get("response", "") + # Format as Server-Sent Events (SSE) for the frontend + yield f"data: {json.dumps({'token': token})}\n\n" + + if data.get("done", False): + break + except json.JSONDecodeError: + pass + except Exception as e: + logger.error(f"Streaming error from Ollama: {e}") + yield f"data: {json.dumps({'token': f'[Error: {e}]'})}\n\n" + + yield "data: [DONE]\n\n" diff --git a/src/infrastructure/adapters/openai_adapter.py b/src/infrastructure/adapters/openai_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..8771823d24eb7e37d982d420ea63b51469ee3fbf --- /dev/null +++ b/src/infrastructure/adapters/openai_adapter.py @@ -0,0 +1,29 @@ +from typing import AsyncGenerator +from src.core.ports.llm_port import LlmPort +from langchain_openai import ChatOpenAI +from src.core.config import settings +import json + +class OpenAiAdapter(LlmPort): + def __init__(self): + if settings.OPENAI_API_KEY: + self.llm = ChatOpenAI(api_key=settings.OPENAI_API_KEY, model="gpt-4", temperature=0.2) + else: + self.llm = None + + def generate(self, prompt: str) -> str: + if not self.llm: + return "LLM API Key not configured." + return self.llm.invoke(prompt).content + + async def generate_stream(self, prompt: str) -> AsyncGenerator[str, None]: + if not self.llm: + yield "data: [DONE]\n\n" + return + + for chunk in self.llm.stream(prompt): + if hasattr(chunk, 'content'): + yield f"data: {json.dumps({'token': chunk.content})}\n\n" + else: + yield f"data: {json.dumps({'token': str(chunk)})}\n\n" + yield "data: [DONE]\n\n" diff --git a/src/infrastructure/adapters/postgres_adapter.py b/src/infrastructure/adapters/postgres_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..945cdee22eb5e2fd841d40b7439a33e79f8911f4 --- /dev/null +++ b/src/infrastructure/adapters/postgres_adapter.py @@ -0,0 +1,39 @@ +import json +from typing import List, Optional +from sqlalchemy.orm import Session +from src.core.ports.chat_history_port import ChatHistoryPort, ChatMessage + +class Message(ChatMessage): + def __init__(self, role: str, content: str): + self.role = role + self.content = content + +class PostgresAdapter(ChatHistoryPort): + def __init__(self, db_session: Session): + self.db = db_session + + def get_history(self, session_id: str, limit: int = 6) -> List[ChatMessage]: + from src.core.domain.db_models import ChatHistory + past_messages = ( + self.db.query(ChatHistory) + .filter(ChatHistory.session_id == session_id) + .order_by(ChatHistory.timestamp.asc()) + .limit(limit) + .all() + ) + return [Message(role=msg.role, content=msg.content) for msg in past_messages] + + def save_interaction(self, session_id: str, user_query: str, ai_response: str, + retrieved_doc_ids: List[str], user_id: Optional[int] = None): + from src.core.domain.db_models import ChatHistory + user_msg = ChatHistory(session_id=session_id, role="user", content=user_query, user_id=user_id) + ai_msg = ChatHistory( + session_id=session_id, + role="assistant", + content=ai_response, + retrieved_doc_ids=json.dumps(retrieved_doc_ids), + user_id=user_id + ) + self.db.add(user_msg) + self.db.add(ai_msg) + self.db.commit() diff --git a/src/infrastructure/adapters/qdrant_adapter.py b/src/infrastructure/adapters/qdrant_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..3068e7f4f370d7a596f35ece9dabb1dbad5c1c3c --- /dev/null +++ b/src/infrastructure/adapters/qdrant_adapter.py @@ -0,0 +1,363 @@ +import logging +import traceback +from datetime import datetime, timedelta +from typing import Dict, Any, List, Optional + +from qdrant_client import QdrantClient +from qdrant_client.http import models + +from src.core.config import settings +from src.core.ports.vector_store_port import VectorStorePort, SearchResult + +logger = logging.getLogger(__name__) + + +def _build_filter( + source_filter: Optional[str], + language_filter: Optional[str], + days_back: Optional[int], +) -> Optional[models.Filter]: + """Build a Qdrant filter object from optional search constraints. + + source_filter supports prefix matching: "bbc" matches "bbc_english", "bbc_arabic", etc. + This handles the case where users say "BBC" but sources are stored as "bbc_english". + """ + must: list = [] + if source_filter: + src = source_filter.lower().strip() + # Use prefix match so "bbc" matches "bbc_english", "bbc_arabic", "bbc_swahili" etc. + must.append(models.FieldCondition( + key="source", + match=models.MatchText(text=src) # full-text contains match + )) + if language_filter: + must.append(models.FieldCondition( + key="language", match=models.MatchValue(value=language_filter) + )) + if days_back is not None: + min_date = datetime.utcnow() - timedelta(days=days_back) + must.append(models.FieldCondition( + key="published_at", + range=models.DatetimeRange(gte=min_date) + )) + return models.Filter(must=must) if must else None + + +def _points_to_results(points: list) -> List[SearchResult]: + hits = [] + for p in points: + payload = p.payload or {} + hits.append(SearchResult( + content=payload.get("text", payload.get("content", "")), + metadata=payload, + score=p.score, + doc_id=payload.get("doc_id"), + )) + return hits + + +class QdrantAdapter(VectorStorePort): + + def __init__(self): + try: + if settings.QDRANT_URL and settings.QDRANT_API_KEY: + self.client = QdrantClient( + url=settings.QDRANT_URL, + api_key=settings.QDRANT_API_KEY, + ) + logger.info(f"Connected to Qdrant Cloud: {settings.QDRANT_URL}") + else: + self.client = QdrantClient( + host=settings.QDRANT_HOST, port=settings.QDRANT_PORT + ) + logger.info(f"Connected to Qdrant at {settings.QDRANT_HOST}:{settings.QDRANT_PORT}") + self._ensure_indexes() + except Exception as e: + logger.error(f"Failed to connect to Qdrant: {e}") + self.client = None + + # ── Index management ────────────────────────────────────────────────────── + + def _ensure_indexes(self): + """Ensure required payload indexes exist — creates them if missing.""" + if not self.client: + return + try: + indexes = { + "source": models.PayloadSchemaType.KEYWORD, + "language": models.PayloadSchemaType.KEYWORD, + "published_at": models.PayloadSchemaType.DATETIME, + } + info = self.client.get_collection(settings.QDRANT_COLLECTION) + existing = set(info.payload_schema.keys()) if info.payload_schema else set() + for field, schema in indexes.items(): + if field not in existing: + logger.info(f"Creating missing payload index: {field}") + self.client.create_payload_index( + collection_name=settings.QDRANT_COLLECTION, + field_name=field, + field_schema=schema, + ) + logger.info(f"✅ Index created: {field}") + except Exception as e: + logger.warning(f"Could not ensure indexes: {e}") + + # ── Internal query executor ─────────────────────────────────────────────── + + def _execute_query( + self, + dense_vec: Optional[List[float]], + sparse_vec: Optional[Dict[str, Any]], + filter_obj: Optional[models.Filter], + limit: int, + label: str = "", + ) -> List[SearchResult]: + """ + Core Qdrant query executor. + - If both dense and sparse are provided → hybrid RRF search + - If only dense → pure dense ANN search + - If only sparse → pure sparse BM25 search + """ + try: + has_sparse = bool(sparse_vec and sparse_vec.get("indices")) + has_dense = bool(dense_vec) + + if has_dense and has_sparse: + prefetch = [ + models.Prefetch( + query=models.SparseVector( + indices=sparse_vec["indices"], + values=sparse_vec["values"], + ), + using="sparse", + limit=limit, + ), + models.Prefetch( + query=dense_vec, + using="dense", + limit=limit, + ), + ] + response = self.client.query_points( + collection_name=settings.QDRANT_COLLECTION, + prefetch=prefetch, + query=models.FusionQuery(fusion=models.Fusion.RRF), + query_filter=filter_obj, + limit=limit, + with_payload=True, + with_vectors=False, + ) + print(f"QDRANT [{label}]: hybrid RRF → {len(response.points)} hits") + + elif has_dense: + response = self.client.query_points( + collection_name=settings.QDRANT_COLLECTION, + query=dense_vec, + using="dense", + query_filter=filter_obj, + limit=limit, + with_payload=True, + with_vectors=False, + ) + print(f"QDRANT [{label}]: dense-only → {len(response.points)} hits") + + elif has_sparse: + response = self.client.query_points( + collection_name=settings.QDRANT_COLLECTION, + query=models.SparseVector( + indices=sparse_vec["indices"], + values=sparse_vec["values"], + ), + using="sparse", + query_filter=filter_obj, + limit=limit, + with_payload=True, + with_vectors=False, + ) + print(f"QDRANT [{label}]: sparse-only → {len(response.points)} hits") + + else: + logger.warning(f"QDRANT [{label}]: no vectors provided — returning empty") + return [] + + return _points_to_results(response.points) + + except Exception as e: + logger.error(f"QDRANT [{label}]: query failed — {e}\n{traceback.format_exc()}") + return [] + + # ── Public search interface ─────────────────────────────────────────────── + + def search( + self, + query_vectors: Dict[str, Any], + limit: int = 5, + source_filter: Optional[str] = None, + language_filter: Optional[str] = None, + days_back: Optional[int] = None, + ) -> List[SearchResult]: + """ + Standard hybrid search — dense + sparse from the same query. + Used by the /news/search endpoint and as a fallback. + Retries without the datetime filter if the index is missing. + """ + if not self.client: + return [] + + dense_vec = query_vectors.get("dense") + sparse_vec = query_vectors.get("sparse") + + # Attempt with datetime filter first, then without if index missing + for use_date_filter in ([True, False] if days_back is not None else [False]): + effective_days = days_back if use_date_filter else None + filter_obj = _build_filter(source_filter, language_filter, effective_days) + + try: + results = self._execute_query( + dense_vec, sparse_vec, filter_obj, limit, label=language_filter or "all" + ) + return results + except Exception as e: + if "Index required but not found" in str(e) and use_date_filter: + print("QDRANT: datetime index missing — retrying without date filter") + continue + logger.error(f"QDRANT search error: {e}") + return [] + + return [] + + def search_with_vectors( + self, + dense_vec: Optional[List[float]], + sparse_vec: Optional[Dict[str, Any]], + limit: int = 5, + source_filter: Optional[str] = None, + language_filter: Optional[str] = None, + days_back: Optional[int] = None, + ) -> List[SearchResult]: + """ + Optimised multilingual search — accepts pre-computed dense and sparse + vectors separately so callers can mix them freely. + + Key use case (multilingual pipeline): + - dense_vec = English query dense vector (language-agnostic, computed once) + - sparse_vec = translated query sparse vec (language-specific BM25, per lane) + - language_filter = the target language code for this lane + + This avoids recomputing the dense vector 6 times — it is computed once + from the English query and reused across all language lanes. + Retries without the datetime filter if the Qdrant index is missing. + """ + if not self.client: + return [] + + for use_date_filter in ([True, False] if days_back is not None else [False]): + effective_days = days_back if use_date_filter else None + filter_obj = _build_filter(source_filter, language_filter, effective_days) + + try: + results = self._execute_query( + dense_vec, sparse_vec, filter_obj, limit, + label=language_filter or "all" + ) + return results + except Exception as e: + if "Index required but not found" in str(e) and use_date_filter: + print(f"QDRANT [{language_filter}]: datetime index missing — retrying without date filter") + continue + logger.error(f"QDRANT search_with_vectors error [{language_filter}]: {e}") + return [] + + return [] + + # ── Other VectorStorePort methods ───────────────────────────────────────── + + def get_by_doc_id(self, doc_id: str) -> Optional[SearchResult]: + if not self.client: + return None + try: + results, _ = self.client.scroll( + collection_name=settings.QDRANT_COLLECTION, + scroll_filter=models.Filter( + must=[models.FieldCondition( + key="doc_id", match=models.MatchValue(value=doc_id) + )] + ), + limit=1, + with_payload=True, + with_vectors=False, + ) + if results: + payload = results[0].payload or {} + return SearchResult( + content=payload.get("text", payload.get("content", "")), + metadata=payload, + score=1.0, + doc_id=payload.get("doc_id"), + ) + return None + except Exception as e: + logger.error(f"Error getting doc from Qdrant: {e}") + return None + + def get_collection_stats(self) -> Dict[str, Any]: + if not self.client: + return {"vectors_count": 0} + try: + collection = self.client.get_collection(settings.QDRANT_COLLECTION) + return {"vectors_count": collection.points_count} + except Exception as e: + logger.error(f"Error getting collection stats: {e}") + return {"vectors_count": 0} + + def browse( + self, + limit: int = 20, + offset: int = 0, + source: Optional[str] = None, + language: Optional[str] = None, + ) -> Dict[str, Any]: + if not self.client: + return {"articles": [], "next_offset": None} + + must: list = [] + if source: + must.append(models.FieldCondition( + key="source", match=models.MatchValue(value=source) + )) + if language: + must.append(models.FieldCondition( + key="language", match=models.MatchValue(value=language) + )) + filter_obj = models.Filter(must=must) if must else None + + try: + # Fetch more than needed so we can deduplicate to first chunk per article + results, next_page_offset = self.client.scroll( + collection_name=settings.QDRANT_COLLECTION, + scroll_filter=filter_obj, + limit=limit * 8, + offset=offset, + with_payload=True, + with_vectors=False, + ) + + # Keep only the lowest chunk_index per doc_id (first chunk of each article) + seen_docs: dict = {} + for point in results: + payload = point.payload or {} + doc_id = payload.get("doc_id", point.id) + chunk_index = payload.get("chunk_index", 0) + if doc_id not in seen_docs or chunk_index < seen_docs[doc_id][1]: + seen_docs[doc_id] = (point, chunk_index) + + deduped = [v[0] for v in seen_docs.values()] + deduped.sort( + key=lambda p: (p.payload or {}).get("published_at") or "", + reverse=True, + ) + return {"articles": deduped[:limit], "next_offset": next_page_offset} + + except Exception as e: + logger.error(f"Error browsing Qdrant: {e}") + return {"articles": [], "next_offset": None} diff --git a/src/infrastructure/adapters/redis_adapter.py b/src/infrastructure/adapters/redis_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..0b50b452ef3e4e91ff71d02e83dd9126036dadd9 --- /dev/null +++ b/src/infrastructure/adapters/redis_adapter.py @@ -0,0 +1,66 @@ +import json +import logging +from typing import Optional, Dict, Any +import redis +import hashlib + +from src.core.ports.cache_port import CachePort +from src.core.config import settings + +logger = logging.getLogger(__name__) + +class RedisAdapter(CachePort): + def __init__(self): + try: + if hasattr(settings, 'REDIS_URL') and settings.REDIS_URL: + url = settings.REDIS_URL + # Upstash requires TLS — upgrade redis:// to rediss:// + if url.startswith("redis://") and "upstash.io" in url: + url = "rediss://" + url[len("redis://"):] + self.client = redis.from_url(url, decode_responses=True) + else: + pool = redis.ConnectionPool( + host=settings.REDIS_HOST, + port=settings.REDIS_PORT, + db=settings.REDIS_DB, + password=settings.REDIS_PASSWORD or None, + decode_responses=True + ) + self.client = redis.Redis(connection_pool=pool) + self.client.ping() + logger.info("Connected to Redis cache.") + except Exception as e: + logger.warning(f"Could not connect to Redis: {e}. Cache disabled.") + self.client = None + + + + def get(self, key: str) -> Optional[Any]: + if not self.client: return None + try: + data = self.client.get(key) + return json.loads(data) if data else None + except Exception as e: + logger.error(f"Redis get error: {e}") + return None + + def set(self, key: str, value: Any, expiration: int = 3600) -> bool: + if not self.client: return False + try: + self.client.setex(key, expiration, json.dumps(value)) + return True + except Exception as e: + logger.error(f"Redis set error: {e}") + return False + + def search_similar(self, query_vector: list, threshold: float = 0.95) -> Optional[Dict[str, Any]]: + """ + In a full enterprise setup, this would use Redisearch vector indexes. + For this simplified adapter, we will rely on exact match (SHA256 of the prompt) + as an interim caching mechanism until Redis vector extensions are configured. + """ + pass + + def generate_exact_hash(self, text: str) -> str: + """Helper for exact match caching if vector search is unavailable.""" + return hashlib.sha256(text.encode('utf-8')).hexdigest() diff --git a/src/infrastructure/adapters/together_adapter.py b/src/infrastructure/adapters/together_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..a9bd09cba11e695eb7c4891705b69931bcf7be03 --- /dev/null +++ b/src/infrastructure/adapters/together_adapter.py @@ -0,0 +1,73 @@ +import json +import logging +from typing import AsyncGenerator + +from src.core.ports.llm_port import LlmPort +from src.core.config import settings + +logger = logging.getLogger(__name__) + +class TogetherAdapter(LlmPort): + """ + Together AI adapter via OpenAI-compatible API. + Free $25 credit on signup, then pay-as-you-go (very cheap). + Available models (set TOGETHER_MODEL env var): + meta-llama/Llama-3.3-70B-Instruct-Turbo ← default, best quality + mistralai/Mistral-Small-24B-Instruct-2501 + mistralai/Mixtral-8x7B-Instruct-v0.1 + mistralai/Mistral-7B-Instruct-v0.1 + Docs: https://docs.together.ai + """ + BASE_URL = "https://api.together.xyz/v1" + DEFAULT_MODEL = "meta-llama/Llama-3.3-70B-Instruct-Turbo" + + def __init__(self): + self.llm = None + key = settings.TOGETHER_API_KEY + if key and key != "your-together-api-key-here": + try: + from langchain_openai import ChatOpenAI + model = settings.TOGETHER_MODEL or self.DEFAULT_MODEL + self.llm = ChatOpenAI( + api_key=key, + base_url=self.BASE_URL, + model=model, + temperature=0.2, + max_tokens=1024, + ) + logger.info(f"✅ Together AI adapter initialized with model: {model}") + except Exception as e: + logger.error(f"Failed to initialize Together adapter: {e}") + else: + logger.warning("TOGETHER_API_KEY not set — Together AI adapter disabled.") + + def generate(self, prompt: str) -> str: + if not self.llm: + return "Together AI API key not configured." + try: + return self.llm.invoke(prompt).content + except Exception as e: + error_msg = str(e) + if "429" in error_msg or "rate" in error_msg.lower(): + return "Together AI rate limit reached. Please try again shortly." + logger.error(f"Together generate error: {e}") + return f"Error generating response: {error_msg}" + + async def generate_stream(self, prompt: str) -> AsyncGenerator[str, None]: + if not self.llm: + yield f"data: {json.dumps({'token': 'Together AI API key not configured.'})}\n\n" + yield "data: [DONE]\n\n" + return + try: + for chunk in self.llm.stream(prompt): + if hasattr(chunk, "content") and chunk.content: + yield f"data: {json.dumps({'token': chunk.content})}\n\n" + yield "data: [DONE]\n\n" + except Exception as e: + error_msg = str(e) + if "429" in error_msg or "rate" in error_msg.lower(): + msg = "Together AI rate limit reached. Please try again shortly." + else: + msg = f"Error: {error_msg}" + yield f"data: {json.dumps({'token': msg})}\n\n" + yield "data: [DONE]\n\n" diff --git a/src/infrastructure/database.py b/src/infrastructure/database.py new file mode 100644 index 0000000000000000000000000000000000000000..bff98613ab3d882c0d12f1458e7fcd17bfb498b5 --- /dev/null +++ b/src/infrastructure/database.py @@ -0,0 +1,37 @@ +from sqlalchemy import create_engine, text +from sqlalchemy.orm import sessionmaker +from src.core.config import settings +from src.core.domain.db_models import Base + +engine = create_engine( + settings.SQLALCHEMY_DATABASE_URI, + pool_recycle=300, + pool_pre_ping=True +) + +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +def init_db(): + Base.metadata.create_all(bind=engine) + _run_migrations() + +def _run_migrations(): + """Apply any missing schema changes that create_all won't handle (existing tables).""" + migrations = [ + "ALTER TABLE chat_history ADD COLUMN IF NOT EXISTS pinned BOOLEAN DEFAULT FALSE", + "ALTER TABLE chat_history ADD COLUMN IF NOT EXISTS user_id INTEGER REFERENCES users(id) ON DELETE SET NULL", + ] + with engine.connect() as conn: + for sql in migrations: + try: + conn.execute(text(sql)) + except Exception: + pass # Column may already exist or constraint may differ — safe to ignore + conn.commit() + +def get_db(): + db = SessionLocal() + try: + yield db + finally: + db.close() diff --git a/src/main.py b/src/main.py new file mode 100644 index 0000000000000000000000000000000000000000..0b6ea99207a7b1ceb9840cf63e279924b1d91a56 --- /dev/null +++ b/src/main.py @@ -0,0 +1,63 @@ +import os +import sys +from pathlib import Path + +# Add project root and local src to path to allow direct execution +current_dir = Path(__file__).resolve().parent +project_root = current_dir.parent.parent.parent.parent # d:\...pipeline +sys.path.append(str(project_root)) +sys.path.append(str(current_dir.parent)) # d:\...\rag-api + +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from src.core.config import settings +from src.api.routes import rag, analytics, interactions, accounts, news, auth +from src.infrastructure.database import init_db +from src.api.dependencies import prewarm_models +import threading +import uvicorn + +app = FastAPI( + title=settings.PROJECT_NAME, + openapi_url=f"{settings.API_V1_STR}/openapi.json" +) + +# Set all CORS enabled origins +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +@app.on_event("startup") +def on_startup(): + init_db() + # Pre-warming moved to lazy loading or manual script to prevent + # initialization timeouts/crashes on the first run. + # thread = threading.Thread(target=prewarm_models) + # thread.start() + +app.include_router(rag.router, prefix=f"{settings.API_V1_STR}/rag", tags=["RAG"]) +app.include_router(analytics.router, prefix=f"{settings.API_V1_STR}/analytics", tags=["Analytics"]) +app.include_router(interactions.router, prefix=f"{settings.API_V1_STR}/interactions", tags=["Interactions"]) +app.include_router(accounts.router, prefix=f"{settings.API_V1_STR}/accounts", tags=["Accounts"]) +# Alias for registration convenience +app.include_router(accounts.router, prefix=f"{settings.API_V1_STR}/users", tags=["Users"], include_in_schema=False) +app.include_router(auth.router, prefix=f"{settings.API_V1_STR}/auth", tags=["Auth"]) +app.include_router(news.router, prefix=f"{settings.API_V1_STR}/news", tags=["News"]) + +@app.get(f"{settings.API_V1_STR}/status") +def get_status(): + return {"status": "online", "message": "RAG API Service is operational"} + +@app.get("/") +def read_root(): + return {"message": "Welcome to the RAG API Service", "status": "online"} + +if __name__ == "__main__": + # Start the server directly from this script + # Use the app object directly for simplicity when running as main + uvicorn.run(app, host="0.0.0.0", port=8000) + diff --git a/tests/test_rag_pipeline.py b/tests/test_rag_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..faa78bae86c198048a67302085328f84476c5eba --- /dev/null +++ b/tests/test_rag_pipeline.py @@ -0,0 +1,34 @@ +import sys +import os +from unittest.mock import MagicMock + +# Add src to path +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'src'))) + +try: + import tiktoken + from src.api.routes.rag import limit_context_tokens + print("Tiktoken and limit_context_tokens imported successfully.") +except ImportError as e: + print(f"Import error: {e}") + +def test_context_limiter(): + sources = [ + {"content": "This is a long piece of text " * 100, "metadata": {}}, + {"content": "Another short text", "metadata": {}}, + {"content": "Third piece of text", "metadata": {}} + ] + + # Test with 100 tokens + context, filtered = limit_context_tokens(sources, max_tokens=100) + enc = tiktoken.get_encoding("cl100k_base") + tokens = len(enc.encode(context)) + print(f"Context token count: {tokens}") + assert tokens <= 110 # slight buffer for delimiters + print("Context limiter test passed!") + +if __name__ == "__main__": + try: + test_context_limiter() + except Exception as e: + print(f"Test failed: {e}")