Commit 690700c
Initial deployment to HuggingFace Spaces
- .env.example +177 -0
- .gitignore +479 -0
- Dockerfile +27 -0
- README_SPACE.md +28 -0
- app/__init__.py +11 -0
- app/api/__init__.py +6 -0
- app/api/v1/__init__.py +8 -0
- app/api/v1/auth.py +193 -0
- app/api/v1/chat.py +1346 -0
- app/config.py +236 -0
- app/core/__init__.py +6 -0
- app/core/llm_manager.py +445 -0
- app/db/__init__.py +6 -0
- app/db/mongodb.py +390 -0
- app/db/repositories/__init__.py +6 -0
- app/db/repositories/conversation_repository.py +1049 -0
- app/db/repositories/user_repository.py +155 -0
- app/main.py +301 -0
- app/ml/__init__.py +6 -0
- app/ml/policy_network.py +610 -0
- app/ml/retriever.py +522 -0
- app/models/__init__.py +3 -0
- app/models/user.py +69 -0
- app/services/__init__.py +6 -0
- app/services/chat_service.py +335 -0
- app/utils/dependencies.py +87 -0
- app/utils/security.py +190 -0
- backups/backup_chat_service.py +340 -0
- backups/backup_config.py +640 -0
- backups/backup_llm_manager.py +430 -0
- backups/backup_main.py +275 -0
- backups/backup_requirements.txt +182 -0
- build_faiss_index.py +339 -0
- folder_structure.txt +40 -0
- requirements.txt +38 -0
.env.example
ADDED
@@ -0,0 +1,177 @@
+# ================================================================================
+# BANKING RAG CHATBOT API - ENVIRONMENT VARIABLES
+# Copy this file to .env and fill in your actual values
+# ================================================================================
+
+# ============================================================================
+# APPLICATION SETTINGS
+# ============================================================================
+DEBUG=False
+ENVIRONMENT=production
+
+# ============================================================================
+# MONGODB (Get from: https://www.mongodb.com/cloud/atlas)
+# ============================================================================
+# Connection string format:
+# example string here
+MONGODB_URI=example
+DATABASE_NAME=banking_rag_db
+
+# ============================================================================
+# SECURITY
+# ============================================================================
+# Generate a secure secret key with:
+# python -c "import secrets; print(secrets.token_urlsafe(32))"
+SECRET_KEY=your-secret-key-here-change-this-in-production-min-32-characters
+ALGORITHM=HS256
+ACCESS_TOKEN_EXPIRE_MINUTES=1440
+
+# ============================================================================
+# LLM API KEYS - ALL THREE CO-EXIST (No fallback logic)
+# ============================================================================
+
+# --- GOOGLE GEMINI API (PRIMARY) ---
+# Get from: https://aistudio.google.com/app/apikey
+# You have Google Pro - this is your main LLM for response generation
+GOOGLE_API_KEY=AIzaSyXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
+
+# Which Gemini model to use
+# Options: gemini-2.0-flash-lite, gemini-1.5-flash
+GEMINI_MODEL=gemini-2.0-flash-lite
+
+# Gemini rate limits (Pro tier)
+GEMINI_REQUESTS_PER_MINUTE=60
+GEMINI_TOKENS_PER_MINUTE=60000
+
+
+# --- GROQ API (SECONDARY) ---
+# Get from: https://console.groq.com/keys
+# Single key for specific fast inference tasks (llama models)
+GROQ_API_KEY=gsk_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
+
+# Groq model (fast inference for policy evaluations)
+GROQ_MODEL=llama3-70b-8192
+
+# Groq rate limits (Free tier)
+GROQ_REQUESTS_PER_MINUTE=30
+GROQ_TOKENS_PER_MINUTE=30000
+
+
+# --- HUGGING FACE TOKEN (REQUIRED) ---
+# Get from: https://huggingface.co/settings/tokens
+# Required for: Model downloads (e5-base-v2, BERT), embeddings
+HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
+
+# ============================================================================
+# MODEL PATHS (Local storage)
+# ============================================================================
+RETRIEVER_MODEL_PATH=models/best_retriever_model.pth
+POLICY_MODEL_PATH=models/policy_network.pt
+FAISS_INDEX_PATH=models/faiss_index.pkl
+KB_PATH=data/final_knowledge_base.jsonl
+
+# ============================================================================
+# RAG PARAMETERS
+# ============================================================================
+# Number of documents to retrieve from FAISS
+TOP_K=5
+
+# Minimum similarity threshold for retrieval
+SIMILARITY_THRESHOLD=0.5
+
+# Maximum context length to send to LLM (in characters)
+MAX_CONTEXT_LENGTH=2000
+
+# ============================================================================
+# POLICY NETWORK PARAMETERS
+# ============================================================================
+# Maximum sequence length for policy input
+POLICY_MAX_LEN=256
+
+# Confidence threshold for policy decisions
+CONFIDENCE_THRESHOLD=0.7
+
+# ============================================================================
+# LLM GENERATION PARAMETERS
+# ============================================================================
+# Temperature for response generation (0.0 = deterministic, 1.0 = creative)
+LLM_TEMPERATURE=0.7
+
+# Maximum tokens to generate in response
+LLM_MAX_TOKENS=512
+
+# System prompt template
+SYSTEM_PROMPT=You are a helpful banking assistant. Answer questions clearly and concisely.
+
+# ============================================================================
+# LLM ROUTING STRATEGY
+# ============================================================================
+# Define which LLM to use for which task
+# Options: gemini, groq
+
+# Main chat responses (user-facing) - Use Gemini Pro (best quality)
+CHAT_LLM=gemini
+
+# Response evaluation (RL training) - Use Groq (fast, good enough)
+EVALUATION_LLM=groq
+
+# Policy network inference - Local BERT model (no API call)
+POLICY_LLM=local
+
+# ============================================================================
+# CORS SETTINGS (for frontend)
+# ============================================================================
+# Comma-separated list of allowed origins
+# Use "*" for development (allows all origins)
+# For production, specify exact domains:
+# ALLOWED_ORIGINS=https://yourdomain.com,https://www.yourdomain.com
+ALLOWED_ORIGINS=*
+
+# ============================================================================
+# LOGGING
+# ============================================================================
+LOG_LEVEL=INFO
+
+# ============================================================================
+# OPTIONAL: Advanced Settings
+# ============================================================================
+# Maximum conversation history to include in context
+MAX_HISTORY_TURNS=4
+
+# Enable/disable response caching
+ENABLE_CACHE=True
+
+# Cache TTL in seconds (1 hour)
+CACHE_TTL=3600
+
+# Environment
+ENVIRONMENT=production
+DEBUG=False
+
+# MongoDB
+MONGODB_URI=your_mongodb_uri_here
+
+# JWT
+SECRET_KEY=your-secret-key-here
+ALGORITHM=HS256
+ACCESS_TOKEN_EXPIRE_MINUTES=1440
+
+# Groq API Keys
+GROQ_API_KEY_1=your_groq_key_1
+GROQ_API_KEY_2=your_groq_key_2
+GROQ_API_KEY_3=your_groq_key_3
+
+# HuggingFace Tokens
+HF_TOKEN_1=your_hf_token_1
+HF_TOKEN_2=your_hf_token_2
+HF_TOKEN_3=your_hf_token_3
+
+# HuggingFace Model Repository
+HF_MODEL_REPO=YOUR_USERNAME/questrag-models
+
+# CORS
+ALLOWED_ORIGINS=*
+
+# Device
+DEVICE=cpu
+
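Two things stand out when consuming this file. Several keys are defined twice (`ENVIRONMENT`, `DEBUG`, `MONGODB_URI`, `SECRET_KEY`, `ALGORITHM`, `ACCESS_TOKEN_EXPIRE_MINUTES`, `ALLOWED_ORIGINS`), and most dotenv loaders keep the last occurrence, so the values near the bottom of the file win. The commit's `app/config.py` (+236 lines, not shown in this excerpt) presumably maps these variables onto a `settings` object, since other modules do `from app.config import settings`; below is a minimal sketch of that pattern with pydantic-settings, with field names assumed to mirror the env keys above, not the committed implementation.

```python
# Minimal sketch (not the committed app/config.py): loading a subset of
# these variables with pydantic-settings. Field names are assumptions
# mirroring the env keys in .env.example.
from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    # Read .env if present; ignore keys this class doesn't declare.
    model_config = SettingsConfigDict(env_file=".env", extra="ignore")

    DEBUG: bool = False
    ENVIRONMENT: str = "production"
    MONGODB_URI: str
    DATABASE_NAME: str = "banking_rag_db"
    SECRET_KEY: str
    ALGORITHM: str = "HS256"
    ACCESS_TOKEN_EXPIRE_MINUTES: int = 1440
    TOP_K: int = 5
    SIMILARITY_THRESHOLD: float = 0.5
    CONFIDENCE_THRESHOLD: float = 0.7


# Reads .env, falling back to the process environment; raises a validation
# error if a required key (e.g. MONGODB_URI) is missing from both.
settings = Settings()
```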
.gitignore
ADDED
@@ -0,0 +1,479 @@
+# # ================================================================================
+# # BANKING RAG CHATBOT - .gitignore
+# # Prevents committing sensitive files, large models, and temporary files
+# # IMPORTANT: This protects your API keys and credentials!
+# # ================================================================================
+
+# # ============================================================================
+# # ENVIRONMENT VARIABLES (MOST IMPORTANT)
+# # ============================================================================
+# .env
+# .env.local
+# .env.production
+# .env.development
+# .env.*.local
+
+# # ============================================================================
+# # PYTHON
+# # ============================================================================
+# # Byte-compiled / optimized / DLL files
+# __pycache__/
+# *.py[cod]
+# *$py.class
+# *.so
+
+# # C extensions
+# *.so
+
+# # Distribution / packaging
+# .Python
+# build/
+# develop-eggs/
+# dist/
+# downloads/
+# eggs/
+# .eggs/
+# lib/
+# lib64/
+# parts/
+# sdist/
+# var/
+# wheels/
+# pip-wheel-metadata/
+# share/python-wheels/
+# *.egg-info/
+# .installed.cfg
+# *.egg
+# MANIFEST
+
+# # PyInstaller
+# *.manifest
+# *.spec
+
+# # Unit test / coverage reports
+# htmlcov/
+# .tox/
+# .nox/
+# .coverage
+# .coverage.*
+# .cache
+# nosetests.xml
+# coverage.xml
+# *.cover
+# *.py,cover
+# .hypothesis/
+# .pytest_cache/
+
+# # Jupyter Notebook
+# .ipynb_checkpoints
+# *.ipynb
+
+# # IPython
+# profile_default/
+# ipython_config.py
+
+# # Virtual environments
+# venv/
+# env/
+# ENV/
+# env.bak/
+# venv.bak/
+# .venv
+# .env/
+
+# # pyenv
+# .python-version
+
+# # pipenv
+# Pipfile.lock
+
+# # Poetry
+# poetry.lock
+
+# # ============================================================================
+# # ML MODELS (Optional - depends on Git LFS usage)
+# # ============================================================================
+# # If NOT using Git LFS, uncomment these to avoid committing large models:
+# # models/*.pth
+# # models/*.pt
+# # models/*.pkl
+# # models/*.bin
+# # models/*.h5
+
+# # If using Git LFS, comment out the above and models will be tracked by LFS
+
+# # Temporary model files
+# models/temp/
+# models/checkpoints/
+# *.ckpt
+
+# # ============================================================================
+# # DATA FILES
+# # ============================================================================
+# # Prevent committing large datasets
+# data/*.csv
+# data/*.json
+# data/*.jsonl
+# data/*.parquet
+# data/*.feather
+# data/raw/
+# data/processed/
+# data/temp/
+
+# # Keep .gitkeep files to preserve directory structure
+# !data/.gitkeep
+# !models/.gitkeep
+
+# # ============================================================================
+# # LOGS
+# # ============================================================================
+# # Log files
+# *.log
+# logs/
+# *.log.*
+
+# # ============================================================================
+# # DATABASES
+# # ============================================================================
+# # Local database files
+# *.db
+# *.sqlite
+# *.sqlite3
+
+# # ============================================================================
+# # IDEs & EDITORS
+# # ============================================================================
+# # VSCode
+# .vscode/
+# *.code-workspace
+
+# # PyCharm
+# .idea/
+# *.iml
+# *.iws
+
+# # Sublime Text
+# *.sublime-project
+# *.sublime-workspace
+
+# # Vim
+# *.swp
+# *.swo
+# *~
+
+# # Emacs
+# *~
+# \#*\#
+# .\#*
+
+# # ============================================================================
+# # MACOS
+# # ============================================================================
+# .DS_Store
+# .AppleDouble
+# .LSOverride
+
+# # Thumbnails
+# ._*
+
+# # ============================================================================
+# # WINDOWS
+# # ============================================================================
+# Thumbs.db
+# Thumbs.db:encryptable
+# ehthumbs.db
+# ehthumbs_vista.db
+# Desktop.ini
+
+# # ============================================================================
+# # LINUX
+# # ============================================================================
+# *~
+
+# # ============================================================================
+# # DOCKER
+# # ============================================================================
+# # Docker files (if we decide to add Docker later)
+# # .dockerignore is separate
+# docker-compose.override.yml
+# .docker/
+
+# # ============================================================================
+# # TEMPORARY FILES
+# # ============================================================================
+# *.tmp
+# *.temp
+# *.bak
+# *.backup
+# *.old
+# tmp/
+# temp/
+
+# # ============================================================================
+# # SECRETS & CREDENTIALS
+# # ============================================================================
+# # Any file with "secret" or "credential" in name
+# *secret*
+# *credential*
+# *password*
+# *.pem
+# *.key
+# *.crt
+# *.cer
+
+# # AWS credentials
+# .aws/
+
+# # Google Cloud credentials
+# *-credentials.json
+# service-account*.json
+
+# # ============================================================================
+# # CACHE
+# # ============================================================================
+# .cache/
+# *.cache
+
+# # Hugging Face cache
+# .huggingface/
+# transformers_cache/
+
+# # ============================================================================
+# # MONITORING & PROFILING
+# # ============================================================================
+# *.prof
+# *.lprof
+
+# # ============================================================================
+# # FRONTEND (if you add frontend to this repo)
+# # ============================================================================
+# node_modules/
+# .next/
+# out/
+# build/
+# dist/
+
+# # ============================================================================
+# # MISC
+# # ============================================================================
+# # System files
+# .Trash-*
+
+# # Backup files
+# *~
+# *.orig
+# *.rej
+
+# # Compressed files
+# *.zip
+# *.tar.gz
+# *.rar
+# *.7z
+
+# # ============================================================================
+# # PROJECT-SPECIFIC
+# # ============================================================================
+# # Add any project-specific patterns here
+# uploads/
+# downloads/
+# exports/
+
+# # Test outputs
+# test_outputs/
+# test_results/
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+# ================================================================================
+# BANKING RAG CHATBOT - Backend .gitignore
+# Prevents committing sensitive files, large models, and temporary files
+# ================================================================================
+
+# ============================================================================
+# ENVIRONMENT VARIABLES (MOST IMPORTANT - PROTECTS API KEYS)
+# ============================================================================
+.env
+.env.local
+.env.production
+.env.development
+.env.*.local
+
+# ============================================================================
+# PYTHON
+# ============================================================================
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# Virtual environments
+venv/
+env/
+ENV/
+env.bak/
+venv.bak/
+.venv
+.env/
+
+# Testing
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+*.ipynb
+
+# ============================================================================
+# ML MODELS & DATA (DO NOT COMMIT - TOO LARGE 1GB+)
+# ============================================================================
+# Model files (PyTorch, TensorFlow, etc.)
+app/models/*.pth
+app/models/*.pt
+app/models/*.pkl
+app/models/*.bin
+app/models/*.h5
+app/models/*.onnx
+app/models/*.safetensors
+models/*.pth
+models/*.pt
+models/*.pkl
+models/*.bin
+
+# FAISS indices (large vector databases)
+app/models/*.index
+app/models/*.faiss
+*.index
+*.faiss
+
+# Knowledge base and data files
+app/data/*.jsonl
+app/data/*.json
+app/data/*.csv
+app/data/*.parquet
+data/*.jsonl
+data/*.json
+data/*.csv
+
+# Temporary model files
+models/temp/
+models/checkpoints/
+*.ckpt
+
+# ============================================================================
+# LOGS
+# ============================================================================
+*.log
+logs/
+*.log.*
+
+# ============================================================================
+# DATABASES
+# ============================================================================
+*.db
+*.sqlite
+*.sqlite3
+
+# ============================================================================
+# IDEs & EDITORS
+# ============================================================================
+# VSCode
+.vscode/
+*.code-workspace
+
+# PyCharm
+.idea/
+*.iml
+*.iws
+
+# ============================================================================
+# OS FILES
+# ============================================================================
+# macOS
+.DS_Store
+.AppleDouble
+.LSOverride
+._*
+
+# Windows
+Thumbs.db
+ehthumbs.db
+Desktop.ini
+
+# Linux
+*~
+
+# ============================================================================
+# CACHE & TEMPORARY
+# ============================================================================
+.cache/
+*.cache
+.huggingface/
+transformers_cache/
+sentence_transformers/
+*.tmp
+*.temp
+*.bak
+tmp/
+temp/
+
+# ============================================================================
+# SECRETS & CREDENTIALS (DOUBLE CHECK)
+# ============================================================================
+*secret*
+*credential*
+*password*
+*.pem
+*.key
+*.crt
+*.cer
+.aws/
+*-credentials.json
+service-account*.json
+
+# ============================================================================
+# KEEP DIRECTORY STRUCTURE
+# ============================================================================
+# These .gitkeep files preserve empty directories
+!app/models/.gitkeep
+!app/data/.gitkeep
+!logs/.gitkeep
Dockerfile
ADDED
@@ -0,0 +1,27 @@
+# Use Python 3.12 slim
+FROM python:3.12-slim
+
+WORKDIR /app
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y \
+    git \
+    && rm -rf /var/lib/apt/lists/*
+
+# Copy requirements
+COPY requirements.txt .
+
+# Install Python packages
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Copy entire app
+COPY . .
+
+# Environment variables
+ENV PYTHONUNBUFFERED=1
+
+# Expose HuggingFace Spaces port (7860)
+EXPOSE 7860
+
+# Run FastAPI
+CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
README_SPACE.md
ADDED
@@ -0,0 +1,28 @@
+---
+title: QUESTRAG Backend
+emoji: 🏦
+colorFrom: blue
+colorTo: green
+sdk: docker
+pinned: false
+---
+
+# QUESTRAG Banking Chatbot Backend
+
+FastAPI backend for QUESTRAG - Banking RAG Chatbot with Reinforcement Learning.
+
+## Features
+- 🤖 RAG Pipeline with FAISS
+- 🧠 RL-based Policy Network
+- ⚡ Groq (Llama 3) + HuggingFace fallback
+- 🔐 JWT Authentication
+- 📊 MongoDB Atlas
+
+## API Documentation
+Visit `/docs` for interactive Swagger UI.
+
+## Endpoints
+- `POST /api/v1/auth/register` - Register new user
+- `POST /api/v1/auth/login` - Login
+- `POST /api/v1/chat/` - Send message (requires auth)
+- `GET /health` - Health check
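A hypothetical client sketch for the endpoints listed in the README (not part of the commit). `BASE_URL` is a placeholder for the deployed Space URL; the request/response field names follow the `UserRegister`/`Token` models used in `app/api/v1/auth.py` and the `ChatRequest`/`ChatResponse` models in `app/api/v1/chat.py` further down this diff.

```python
# Sketch of the register -> chat flow against the documented endpoints.
# BASE_URL and the credentials are placeholders, not real values.
import requests

BASE_URL = "https://YOUR_USERNAME-questrag-backend.hf.space"  # placeholder

# Register a new user; per auth.py, this returns a JWT immediately.
token = requests.post(
    f"{BASE_URL}/api/v1/auth/register",
    json={"email": "user@example.com", "password": "s3cret-pass", "full_name": "Test User"},
).json()["access_token"]

# Send an authenticated chat message (all chat endpoints require the bearer token).
reply = requests.post(
    f"{BASE_URL}/api/v1/chat/",
    json={"query": "What is my account balance?"},
    headers={"Authorization": f"Bearer {token}"},
).json()

print(reply["response"], reply["policy_action"], reply["documents_retrieved"])
```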
app/__init__.py
ADDED
@@ -0,0 +1,11 @@
+"""
+Banking RAG Chatbot API - Application Package
+This file marks the 'app' directory as a Python package.
+"""
+
+# Package version
+__version__ = "1.0.0"
+
+# Package metadata
+__author__ = "Eeshanya Joshi"
+__description__ = "Banking RAG Chatbot with RL-Optimized Retrieval"
app/api/__init__.py
ADDED
@@ -0,0 +1,6 @@
+"""
+API package - Contains all API endpoints
+"""
+
+__version__ = "1.0.0"
+__author__ = "Eeshanya Joshi"
app/api/v1/__init__.py
ADDED
@@ -0,0 +1,8 @@
+"""
+API Version 1 - All v1 endpoints
+"""
+
+# This makes it easier to import routers
+# Example: from app.api.v1 import chat, auth
+
+__version__ = "1.0.0"
app/api/v1/auth.py
ADDED
@@ -0,0 +1,193 @@
+"""
+Authentication API Endpoints
+User registration, login, and token management
+"""
+
+from fastapi import APIRouter, HTTPException, status, Depends
+from datetime import timedelta
+
+from app.models.user import UserRegister, UserLogin, Token, UserResponse, TokenData
+from app.db.repositories.user_repository import UserRepository
+from app.utils.security import verify_password, create_access_token
+from app.utils.dependencies import get_current_user
+from app.config import settings
+
+
+router = APIRouter()
+
+
+@router.post("/register", response_model=Token, status_code=status.HTTP_201_CREATED)
+async def register_user(user_data: UserRegister):
+    """
+    Register a new user.
+
+    Creates a new user account with hashed password and returns
+    an access token for immediate login.
+
+    Args:
+        user_data: User registration data (email, password, full_name)
+
+    Returns:
+        Token: JWT access token and user info
+
+    Raises:
+        HTTPException: If email already exists
+    """
+    user_repo = UserRepository()
+
+    try:
+        # Create user
+        user_id = await user_repo.create_user(
+            email=user_data.email,
+            password=user_data.password,
+            full_name=user_data.full_name
+        )
+
+        # Get created user
+        user = await user_repo.get_user_by_id(user_id)
+
+        # Generate access token
+        access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
+        access_token = create_access_token(
+            data={"user_id": user["user_id"], "email": user["email"]},
+            expires_delta=access_token_expires
+        )
+
+        # Return token and user info
+        return Token(
+            access_token=access_token,
+            token_type="bearer",
+            user=UserResponse(
+                user_id=user["user_id"],
+                email=user["email"],
+                full_name=user["full_name"],
+                created_at=user["created_at"]
+            )
+        )
+
+    except ValueError as e:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=str(e)
+        )
+    except Exception as e:
+        print(f"❌ Registration error: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Failed to register user"
+        )
+
+
+@router.post("/login", response_model=Token)
+async def login_user(user_data: UserLogin):
+    """
+    Login user and get access token.
+
+    Validates user credentials and returns JWT access token.
+
+    Args:
+        user_data: User login data (email, password)
+
+    Returns:
+        Token: JWT access token and user info
+
+    Raises:
+        HTTPException: If credentials are invalid
+    """
+    user_repo = UserRepository()
+
+    # Get user by email
+    user = await user_repo.get_user_by_email(user_data.email)
+
+    if not user:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="Invalid email or password",
+            headers={"WWW-Authenticate": "Bearer"},
+        )
+
+    # Verify password
+    if not verify_password(user_data.password, user["hashed_password"]):
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="Invalid email or password",
+            headers={"WWW-Authenticate": "Bearer"},
+        )
+
+    # Check if user is active
+    if not user.get("is_active", False):
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail="User account is inactive"
+        )
+
+    # Generate access token
+    access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
+    access_token = create_access_token(
+        data={"user_id": user["user_id"], "email": user["email"]},
+        expires_delta=access_token_expires
+    )
+
+    # Return token and user info
+    return Token(
+        access_token=access_token,
+        token_type="bearer",
+        user=UserResponse(
+            user_id=user["user_id"],
+            email=user["email"],
+            full_name=user["full_name"],
+            created_at=user["created_at"]
+        )
+    )
+
+
+@router.get("/me", response_model=UserResponse)
+async def get_current_user_info(current_user: TokenData = Depends(get_current_user)):
+    """
+    Get current authenticated user information.
+
+    Protected route that requires valid JWT token.
+
+    Args:
+        current_user: Current authenticated user (from token)
+
+    Returns:
+        UserResponse: Current user information
+    """
+    user_repo = UserRepository()
+    user = await user_repo.get_user_by_id(current_user.user_id)
+
+    if not user:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="User not found"
+        )
+
+    return UserResponse(
+        user_id=user["user_id"],
+        email=user["email"],
+        full_name=user["full_name"],
+        created_at=user["created_at"]
+    )
+
+
+@router.post("/logout")
+async def logout_user(current_user: TokenData = Depends(get_current_user)):
+    """
+    Logout user (client-side token deletion).
+
+    In JWT-based auth, logout is handled client-side by
+    deleting the token. This endpoint is for logging purposes.
+
+    Args:
+        current_user: Current authenticated user (from token)
+
+    Returns:
+        dict: Success message
+    """
+    print(f"👋 User logged out: {current_user.email}")
+
+    return {
+        "message": "Successfully logged out",
+        "user_id": current_user.user_id
+    }
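`auth.py` imports `verify_password` and `create_access_token` from `app/utils/security.py` (+190 lines in this commit, contents not shown in this excerpt). As a rough sketch only: such helpers are conventionally built on passlib's bcrypt context and a JWT library such as python-jose. The code below is an assumed shape consistent with how `auth.py` calls them, not the committed implementation, and the choice of passlib/python-jose is itself an assumption.

```python
# Sketch only -- NOT the committed app/utils/security.py.
# Assumes passlib (bcrypt) and python-jose; signatures match the calls in auth.py.
from datetime import datetime, timedelta
from typing import Optional

from jose import jwt
from passlib.context import CryptContext

from app.config import settings

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plaintext password against its bcrypt hash."""
    return pwd_context.verify(plain_password, hashed_password)


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Sign a JWT carrying `data`, expiring after `expires_delta` (default 15 min)."""
    to_encode = data.copy()
    to_encode["exp"] = datetime.utcnow() + (expires_delta or timedelta(minutes=15))
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
```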
app/api/v1/chat.py
ADDED
@@ -0,0 +1,1346 @@
+"""
+Chat API Endpoints (WITH AUTHENTICATION)
+RESTful API for the Banking RAG Chatbot
+
+NOW REQUIRES JWT TOKEN FOR ALL ENDPOINTS!
+
+Endpoints:
+- POST /chat - Send a message and get response (PROTECTED)
+- GET /chat/history/{conversation_id} - Get conversation history (PROTECTED)
+- POST /chat/conversation - Create new conversation (PROTECTED)
+- GET /chat/conversations - List user's conversations (PROTECTED)
+- DELETE /chat/conversation/{conversation_id} - Delete conversation (PROTECTED)
+- GET /chat/health - Health check (PUBLIC)
+"""
+
+from fastapi import APIRouter, HTTPException, status, Depends
+from pydantic import BaseModel, Field
+from typing import List, Dict, Optional
+from datetime import datetime
+
+from app.services.chat_service import chat_service
+from app.db.repositories.conversation_repository import ConversationRepository
+from app.utils.dependencies import get_current_user  # AUTH DEPENDENCY
+from app.models.user import TokenData  # USER DATA FROM TOKEN
+
+
+# ============================================================================
+# CREATE ROUTER
+# ============================================================================
+router = APIRouter()
+
+
+# ============================================================================
+# DEPENDENCY: Get ConversationRepository instance
+# ============================================================================
+def get_conversation_repo() -> ConversationRepository:
+    """
+    Dependency that provides ConversationRepository instance.
+    This ensures MongoDB is connected before repository is used.
+    """
+    return ConversationRepository()
+
+
+# ============================================================================
+# PYDANTIC MODELS (Request/Response schemas)
+# ============================================================================
+
+class ChatRequest(BaseModel):
+    """Request model for chat endpoint"""
+    query: str = Field(..., description="User query text", min_length=1, max_length=1000)
+    conversation_id: Optional[str] = Field(None, description="Optional conversation ID")
+
+    class Config:
+        json_schema_extra = {
+            "example": {
+                "query": "What is my account balance?",
+                "conversation_id": "conv-123"
+            }
+        }
+
+
+class ChatResponse(BaseModel):
+    """Response model for chat endpoint"""
+    response: str = Field(..., description="Generated response text")
+    conversation_id: str = Field(..., description="Conversation ID")
+    policy_action: str = Field(..., description="Policy decision: FETCH or NO_FETCH")
+    policy_confidence: float = Field(..., description="Policy confidence score (0-1)")
+    documents_retrieved: int = Field(..., description="Number of documents retrieved")
+    top_doc_score: Optional[float] = Field(None, description="Best document similarity score")
+    total_time_ms: float = Field(..., description="Total processing time in milliseconds")
+    timestamp: str = Field(..., description="Response timestamp (ISO format)")
+
+
+class ConversationCreateResponse(BaseModel):
+    """Response after creating a conversation"""
+    conversation_id: str = Field(..., description="Created conversation ID")
+    created_at: str = Field(..., description="Creation timestamp")
+
+
+class MessageModel(BaseModel):
+    """Single message in conversation history"""
+    role: str = Field(..., description="Message role: user or assistant")
+    content: str = Field(..., description="Message content")
+    timestamp: str = Field(..., description="Message timestamp")
+    metadata: Optional[Dict] = Field(None, description="Optional metadata")
+
+
+class ConversationHistoryResponse(BaseModel):
+    """Response containing conversation history"""
+    conversation_id: str
+    messages: List[MessageModel]
+    message_count: int
+
+
+# ============================================================================
+# ENDPOINTS (ALL PROTECTED WITH JWT)
+# ============================================================================
+
+@router.post("/", response_model=ChatResponse, status_code=status.HTTP_200_OK)
+async def chat(
+    request: ChatRequest,
+    current_user: TokenData = Depends(get_current_user),
+    repo: ConversationRepository = Depends(get_conversation_repo)  # ← INJECT REPO
+):
+    """
+    Main chat endpoint - Send a query and get a response.
+
+    **REQUIRES AUTHENTICATION** - JWT token must be provided in Authorization header.
+    """
+    try:
+        # Get user_id from token
+        user_id = current_user.user_id
+
+        # If no conversation_id provided, create a new conversation
+        conversation_id = request.conversation_id
+        if not conversation_id:
+            conversation_id = await repo.create_conversation(user_id=user_id)
+        else:
+            # Verify user owns this conversation
+            conversation = await repo.get_conversation(conversation_id)
+            if not conversation:
+                raise HTTPException(
+                    status_code=status.HTTP_404_NOT_FOUND,
+                    detail="Conversation not found"
+                )
+            if conversation["user_id"] != user_id:
+                raise HTTPException(
+                    status_code=status.HTTP_403_FORBIDDEN,
+                    detail="Access denied - you don't own this conversation"
+                )
+
+        # Get conversation history
+        history = await repo.get_conversation_history(
+            conversation_id=conversation_id,
+            max_messages=10
+        )
+
+        # Save user message
+        await repo.add_message(
+            conversation_id=conversation_id,
+            message={
+                'role': 'user',
+                'content': request.query,
+                'timestamp': datetime.now()
+            }
+        )
+
+        # Process query through RAG pipeline
+        result = await chat_service.process_query(
+            query=request.query,
+            conversation_history=history,
+            user_id=user_id
+        )
+
+        # Save assistant message
+        await repo.add_message(
+            conversation_id=conversation_id,
+            message={
+                'role': 'assistant',
+                'content': result['response'],
+                'timestamp': datetime.now(),
+                'metadata': {
+                    'policy_action': result['policy_action'],
+                    'policy_confidence': result['policy_confidence'],
+                    'documents_retrieved': result['documents_retrieved'],
+                    'top_doc_score': result['top_doc_score']
+                }
+            }
+        )
+
+        # Log retrieval data for RL training
+        await repo.log_retrieval({
+            'conversation_id': conversation_id,
+            'user_id': user_id,
+            'query': request.query,
+            'policy_action': result['policy_action'],
+            'policy_confidence': result['policy_confidence'],
+            'should_retrieve': result['should_retrieve'],
+            'documents_retrieved': result['documents_retrieved'],
+            'top_doc_score': result['top_doc_score'],
+            'response': result['response'],
+            'retrieval_time_ms': result['retrieval_time_ms'],
+            'generation_time_ms': result['generation_time_ms'],
+            'total_time_ms': result['total_time_ms'],
+            'retrieved_docs_metadata': result.get('retrieved_docs_metadata', []),
+            'timestamp': datetime.now()
+        })
+
+        # Return response
+        return ChatResponse(
+            response=result['response'],
+            conversation_id=conversation_id,
+            policy_action=result['policy_action'],
+            policy_confidence=result['policy_confidence'],
+            documents_retrieved=result['documents_retrieved'],
+            top_doc_score=result['top_doc_score'],
+            total_time_ms=result['total_time_ms'],
+            timestamp=result['timestamp']
+        )
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        print(f"❌ Chat endpoint error: {e}")
+        import traceback
+        traceback.print_exc()
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"Failed to process chat request: {str(e)}"
+        )
+
+
+@router.post("/conversation", response_model=ConversationCreateResponse, status_code=status.HTTP_201_CREATED)
+async def create_conversation(
+    current_user: TokenData = Depends(get_current_user),
+    repo: ConversationRepository = Depends(get_conversation_repo)
+):
+    """Create a new conversation"""
+    try:
+        conversation_id = await repo.create_conversation(user_id=current_user.user_id)
+        return ConversationCreateResponse(
+            conversation_id=conversation_id,
+            created_at=datetime.now().isoformat()
+        )
+    except Exception as e:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"Failed to create conversation: {str(e)}"
+        )
+
+
+@router.get("/history/{conversation_id}", response_model=ConversationHistoryResponse)
+async def get_conversation_history(
+    conversation_id: str,
+    current_user: TokenData = Depends(get_current_user),
+    repo: ConversationRepository = Depends(get_conversation_repo)
+):
+    """Get conversation history by ID"""
+    try:
+        conversation = await repo.get_conversation(conversation_id)
+
+        if not conversation:
+            raise HTTPException(
+                status_code=status.HTTP_404_NOT_FOUND,
+                detail=f"Conversation {conversation_id} not found"
+            )
+
+        if conversation["user_id"] != current_user.user_id:
+            raise HTTPException(
+                status_code=status.HTTP_403_FORBIDDEN,
+                detail="Access denied - you don't own this conversation"
+            )
+
+        messages = []
+        for msg in conversation.get('messages', []):
+            messages.append(MessageModel(
+                role=msg['role'],
+                content=msg['content'],
+                timestamp=msg['timestamp'].isoformat() if isinstance(msg['timestamp'], datetime) else msg['timestamp'],
+                metadata=msg.get('metadata')
+            ))
+
+        return ConversationHistoryResponse(
+            conversation_id=conversation_id,
+            messages=messages,
+            message_count=len(messages)
+        )
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"Failed to fetch conversation history: {str(e)}"
+        )
+
+
+@router.get("/conversations")
+async def list_user_conversations(
+    limit: int = 10,
+    skip: int = 0,
+    current_user: TokenData = Depends(get_current_user),
+    repo: ConversationRepository = Depends(get_conversation_repo)
+):
+    """List all conversations for the authenticated user"""
+    try:
+        conversations = await repo.get_user_conversations(
+            user_id=current_user.user_id,
+            limit=limit,
+            skip=skip
)
|
| 292 |
+
|
| 293 |
+
return {
|
| 294 |
+
"user_id": current_user.user_id,
|
| 295 |
+
"user_email": current_user.email,
|
| 296 |
+
"conversations": [
|
| 297 |
+
{
|
| 298 |
+
"conversation_id": conv['conversation_id'],
|
| 299 |
+
"created_at": conv['created_at'].isoformat() if isinstance(conv['created_at'], datetime) else conv['created_at'],
|
| 300 |
+
"updated_at": conv['updated_at'].isoformat() if isinstance(conv['updated_at'], datetime) else conv['updated_at'],
|
| 301 |
+
"message_count": len(conv.get('messages', []))
|
| 302 |
+
}
|
| 303 |
+
for conv in conversations
|
| 304 |
+
],
|
| 305 |
+
"total": len(conversations)
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
except Exception as e:
|
| 309 |
+
raise HTTPException(
|
| 310 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 311 |
+
detail=f"Failed to fetch conversations: {str(e)}"
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
@router.delete("/conversation/{conversation_id}")
|
| 316 |
+
async def delete_conversation(
|
| 317 |
+
conversation_id: str,
|
| 318 |
+
current_user: TokenData = Depends(get_current_user),
|
| 319 |
+
repo: ConversationRepository = Depends(get_conversation_repo)
|
| 320 |
+
):
|
| 321 |
+
"""Delete a conversation"""
|
| 322 |
+
try:
|
| 323 |
+
conversation = await repo.get_conversation(conversation_id)
|
| 324 |
+
|
| 325 |
+
if not conversation:
|
| 326 |
+
raise HTTPException(
|
| 327 |
+
status_code=status.HTTP_404_NOT_FOUND,
|
| 328 |
+
detail=f"Conversation {conversation_id} not found"
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
if conversation["user_id"] != current_user.user_id:
|
| 332 |
+
raise HTTPException(
|
| 333 |
+
status_code=status.HTTP_403_FORBIDDEN,
|
| 334 |
+
detail="Access denied - you don't own this conversation"
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
success = await repo.delete_conversation(conversation_id)
|
| 338 |
+
|
| 339 |
+
if success:
|
| 340 |
+
return {
|
| 341 |
+
"message": "Conversation deleted successfully",
|
| 342 |
+
"conversation_id": conversation_id
|
| 343 |
+
}
|
| 344 |
+
else:
|
| 345 |
+
raise HTTPException(
|
| 346 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 347 |
+
detail="Failed to delete conversation"
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
except HTTPException:
|
| 351 |
+
raise
|
| 352 |
+
except Exception as e:
|
| 353 |
+
raise HTTPException(
|
| 354 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 355 |
+
detail=f"Failed to delete conversation: {str(e)}"
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
@router.get("/health")
|
| 360 |
+
async def chat_health():
|
| 361 |
+
"""Health check for chat service (PUBLIC)"""
|
| 362 |
+
try:
|
| 363 |
+
health = await chat_service.health_check()
|
| 364 |
+
|
| 365 |
+
return {
|
| 366 |
+
"status": "healthy",
|
| 367 |
+
"service": "chat",
|
| 368 |
+
"components": health['components'],
|
| 369 |
+
"timestamp": datetime.now().isoformat()
|
| 370 |
+
}
|
| 371 |
+
|
| 372 |
+
except Exception as e:
|
| 373 |
+
return {
|
| 374 |
+
"status": "unhealthy",
|
| 375 |
+
"service": "chat",
|
| 376 |
+
"error": str(e),
|
| 377 |
+
"timestamp": datetime.now().isoformat()
|
| 378 |
+
}
|
| 379 |
+
# ============================================================================

# ============================================================================
# USAGE DOCUMENTATION
# ============================================================================
"""
=== API USAGE EXAMPLES (WITH AUTHENTICATION) ===

All endpoints (except /health) require a JWT token in the Authorization header.

1. Register user:
   POST /api/v1/auth/register
   Body: {
       "email": "user@example.com",
       "password": "SecurePass123",
       "full_name": "John Doe"
   }
   Response: { "access_token": "eyJ...", "user": {...} }

2. Login:
   POST /api/v1/auth/login
   Body: {
       "email": "user@example.com",
       "password": "SecurePass123"
   }
   Response: { "access_token": "eyJ...", "user": {...} }

3. Send chat message (WITH TOKEN):
   POST /api/v1/chat/
   Headers: { "Authorization": "Bearer eyJ..." }
   Body: {
       "query": "What is my account balance?",
       "conversation_id": "conv_abc"   // optional
   }

4. Get conversation history (WITH TOKEN):
   GET /api/v1/chat/history/conv_abc
   Headers: { "Authorization": "Bearer eyJ..." }

5. List conversations (WITH TOKEN):
   GET /api/v1/chat/conversations?limit=10
   Headers: { "Authorization": "Bearer eyJ..." }

=== TESTING WITH CURL ===

# 1. Register and capture the access token
TOKEN=$(curl -X POST "http://localhost:8000/api/v1/auth/register" \
    -H "Content-Type: application/json" \
    -d '{"email":"test@test.com","password":"test123","full_name":"Test User"}' \
    | jq -r '.access_token')

# 2. Send a chat message with the token
curl -X POST "http://localhost:8000/api/v1/chat/" \
    -H "Content-Type: application/json" \
    -H "Authorization: Bearer $TOKEN" \
    -d '{"query": "What is my balance?"}'
"""
# ============================================================================
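
# ----------------------------------------------------------------------------
# PYTHON CLIENT EXAMPLE (illustrative sketch)
# ----------------------------------------------------------------------------
"""
A minimal Python client for the flow documented above - a sketch, assuming the
API is running locally on port 8000 and that the `requests` package is
installed; credentials and queries are placeholders.

    import requests

    BASE = "http://localhost:8000/api/v1"

    # 1. Register (or login) to obtain a JWT access token
    auth = requests.post(f"{BASE}/auth/register", json={
        "email": "test@test.com",
        "password": "test123",
        "full_name": "Test User",
    }).json()
    headers = {"Authorization": f"Bearer {auth['access_token']}"}

    # 2. Send a chat message; omitting conversation_id starts a new conversation
    chat = requests.post(
        f"{BASE}/chat/",
        json={"query": "What is my balance?"},
        headers=headers,
    ).json()
    print(chat["response"], chat["conversation_id"])

    # 3. Read the conversation history back
    history = requests.get(
        f"{BASE}/chat/history/{chat['conversation_id']}",
        headers=headers,
    ).json()
    print(history["message_count"], "messages")
"""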
app/config.py
ADDED
@@ -0,0 +1,236 @@
"""
Application Configuration
Settings for Banking RAG Chatbot with JWT Authentication
Updated to support multiple Groq API keys and HuggingFace tokens with fallback logic
"""

import os
import shutil
from typing import List
from dotenv import load_dotenv


load_dotenv()

class Settings:
    """Application settings loaded from environment variables"""

    # ========================================================================
    # ENVIRONMENT
    # ========================================================================
    ENVIRONMENT: str = os.getenv("ENVIRONMENT", "development")
    DEBUG: bool = os.getenv("DEBUG", "True").lower() == "true"

    # ========================================================================
    # MONGODB
    # ========================================================================
    MONGODB_URI: str = os.getenv("MONGODB_URI", "")
    DATABASE_NAME: str = os.getenv("DATABASE_NAME", "aml_ia_db")

    # ========================================================================
    # JWT AUTHENTICATION
    # ========================================================================
    SECRET_KEY: str = os.getenv("SECRET_KEY", "your-secret-key-change-in-production")
    ALGORITHM: str = os.getenv("ALGORITHM", "HS256")
    ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "1440"))  # 1440 minutes = 24 hours

    # ========================================================================
    # CORS (for frontend)
    # ========================================================================
    ALLOWED_ORIGINS: str = os.getenv("ALLOWED_ORIGINS", "*")

    # ========================================================================
    # GROQ API KEYS (multiple, for fallback)
    # ========================================================================
    GROQ_API_KEY_1: str = os.getenv("GROQ_API_KEY_1", "")  # Primary
    GROQ_API_KEY_2: str = os.getenv("GROQ_API_KEY_2", "")  # Fallback 1
    GROQ_API_KEY_3: str = os.getenv("GROQ_API_KEY_3", "")  # Fallback 2

    # Model names for Groq (using correct GroqCloud naming)
    GROQ_CHAT_MODEL: str = os.getenv("GROQ_CHAT_MODEL", "llama3-8b-8192")    # For chat interface
    GROQ_EVAL_MODEL: str = os.getenv("GROQ_EVAL_MODEL", "llama3-70b-8192")   # For evaluation

    # ========================================================================
    # RATE LIMITING (disabled for now; re-enable if needed)
    # ========================================================================
    # GROQ_REQUESTS_PER_MINUTE: int = int(os.getenv("GROQ_REQUESTS_PER_MINUTE", "30"))

    # ========================================================================
    # HUGGING FACE TOKENS (multiple, for fallback)
    # ========================================================================
    HF_TOKEN_1: str = os.getenv("HF_TOKEN_1", "")  # Primary
    HF_TOKEN_2: str = os.getenv("HF_TOKEN_2", "")  # Fallback 1
    HF_TOKEN_3: str = os.getenv("HF_TOKEN_3", "")  # Fallback 2

    # HuggingFace models for inference (fallback from Groq)
    HF_CHAT_MODEL: str = os.getenv("HF_CHAT_MODEL", "meta-llama/Meta-Llama-3-8B-Instruct")
    HF_EVAL_MODEL: str = os.getenv("HF_EVAL_MODEL", "meta-llama/Meta-Llama-3-70B-Instruct")

    # ========================================================================
    # MODEL PATHS (for RL Policy Network and RAG models)
    # ========================================================================
    POLICY_MODEL_PATH: str = os.getenv("POLICY_MODEL_PATH", "app/models/best_policy_model.pth")
    RETRIEVER_MODEL_PATH: str = os.getenv("RETRIEVER_MODEL_PATH", "app/models/best_retriever_model.pth")
    FAISS_INDEX_PATH: str = os.getenv("FAISS_INDEX_PATH", "app/models/faiss_index.pkl")
    KB_PATH: str = os.getenv("KB_PATH", "app/data/final_knowledge_base.jsonl")

    # ========================================================================
    # DEVICE SETTINGS (for PyTorch/TensorFlow models)
    # ========================================================================
    DEVICE: str = os.getenv("DEVICE", "cpu")

    # ========================================================================
    # LLM PARAMETERS
    # ========================================================================
    LLM_TEMPERATURE: float = float(os.getenv("LLM_TEMPERATURE", "0.7"))
    LLM_MAX_TOKENS: int = int(os.getenv("LLM_MAX_TOKENS", "1024"))

    # ========================================================================
    # RAG PARAMETERS
    # ========================================================================
    TOP_K: int = int(os.getenv("TOP_K", "5"))
    SIMILARITY_THRESHOLD: float = float(os.getenv("SIMILARITY_THRESHOLD", "0.5"))
    MAX_CONTEXT_LENGTH: int = int(os.getenv("MAX_CONTEXT_LENGTH", "2000"))

    # ========================================================================
    # POLICY NETWORK PARAMETERS
    # ========================================================================
    POLICY_MAX_LEN: int = int(os.getenv("POLICY_MAX_LEN", "256"))
    CONFIDENCE_THRESHOLD: float = float(os.getenv("CONFIDENCE_THRESHOLD", "0.7"))

    # ========================================================================
    # HUGGING FACE MODEL REPOSITORY (for deployment)
    # ========================================================================
    HF_MODEL_REPO: str = os.getenv("HF_MODEL_REPO", "eeshanyaj/questrag_models")

    def download_model_if_needed(self, hf_filename: str, local_path: str):
        """
        Download a model from the HuggingFace Hub if it does not exist locally.
        This runs on startup for deployment.

        Args:
            hf_filename: Path in the HF repo (e.g., "models/best_policy_model.pth")
            local_path: Where to save locally (e.g., "app/models/best_policy_model.pth")
        """
        if not os.path.exists(local_path):
            print(f"📥 Downloading {hf_filename} from HuggingFace Hub...")
            os.makedirs(os.path.dirname(local_path), exist_ok=True)

            try:
                from huggingface_hub import hf_hub_download

                # Download from the HF Hub into a local cache
                downloaded_path = hf_hub_download(
                    repo_id=self.HF_MODEL_REPO,
                    filename=hf_filename,
                    repo_type="model",
                    cache_dir=".cache"
                )

                # Copy to the expected location (shutil is imported at module level)
                shutil.copy(downloaded_path, local_path)
                print(f"✅ Downloaded {hf_filename}")
            except Exception as e:
                print(f"❌ Error downloading {hf_filename}: {e}")
                raise
        else:
            print(f"✓ Model already exists: {local_path}")

        return local_path
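
    # Illustrative startup usage - a sketch, not the committed call site, which
    # may differ in app/main.py. The repo-side filenames follow the docstring
    # example above:
    #
    #   settings.download_model_if_needed("models/best_policy_model.pth", settings.POLICY_MODEL_PATH)
    #   settings.download_model_if_needed("models/best_retriever_model.pth", settings.RETRIEVER_MODEL_PATH)
    #   settings.download_model_if_needed("models/faiss_index.pkl", settings.FAISS_INDEX_PATH)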
| 142 |
+
|
| 143 |
+
|
| 144 |
+
# ========================================================================
|
| 145 |
+
# HELPER METHODS
|
| 146 |
+
# ========================================================================
|
| 147 |
+
def get_groq_api_keys(self) -> List[str]:
|
| 148 |
+
"""Get all configured Groq API keys in priority order"""
|
| 149 |
+
keys = []
|
| 150 |
+
if self.GROQ_API_KEY_1:
|
| 151 |
+
keys.append(self.GROQ_API_KEY_1)
|
| 152 |
+
if self.GROQ_API_KEY_2:
|
| 153 |
+
keys.append(self.GROQ_API_KEY_2)
|
| 154 |
+
if self.GROQ_API_KEY_3:
|
| 155 |
+
keys.append(self.GROQ_API_KEY_3)
|
| 156 |
+
return keys
|
| 157 |
+
|
| 158 |
+
def get_hf_tokens(self) -> List[str]:
|
| 159 |
+
"""Get all configured HuggingFace tokens in priority order"""
|
| 160 |
+
tokens = []
|
| 161 |
+
if self.HF_TOKEN_1:
|
| 162 |
+
tokens.append(self.HF_TOKEN_1)
|
| 163 |
+
if self.HF_TOKEN_2:
|
| 164 |
+
tokens.append(self.HF_TOKEN_2)
|
| 165 |
+
if self.HF_TOKEN_3:
|
| 166 |
+
tokens.append(self.HF_TOKEN_3)
|
| 167 |
+
return tokens
|
| 168 |
+
|
| 169 |
+
def is_groq_enabled(self) -> bool:
|
| 170 |
+
"""Check if at least one Groq API key is configured"""
|
| 171 |
+
return bool(self.get_groq_api_keys())
|
| 172 |
+
|
| 173 |
+
def is_hf_enabled(self) -> bool:
|
| 174 |
+
"""Check if at least one HuggingFace token is configured"""
|
| 175 |
+
return bool(self.get_hf_tokens())
|
| 176 |
+
|
| 177 |
+
def get_allowed_origins(self) -> List[str]:
|
| 178 |
+
"""Parse allowed origins from comma-separated string"""
|
| 179 |
+
if self.ALLOWED_ORIGINS == "*":
|
| 180 |
+
return ["*"]
|
| 181 |
+
return [origin.strip() for origin in self.ALLOWED_ORIGINS.split(",")]
|
| 182 |
+
|
| 183 |
+
def get_llm_for_task(self, task: str = "chat") -> str:
|
| 184 |
+
"""
|
| 185 |
+
Get LLM model name for a specific task.
|
| 186 |
+
|
| 187 |
+
Args:
|
| 188 |
+
task: Task type ('chat' or 'evaluation')
|
| 189 |
+
|
| 190 |
+
Returns:
|
| 191 |
+
str: Model name for the task
|
| 192 |
+
"""
|
| 193 |
+
if task == "evaluation":
|
| 194 |
+
return self.GROQ_EVAL_MODEL # llama3-70b-8192
|
| 195 |
+
else:
|
| 196 |
+
return self.GROQ_CHAT_MODEL # llama3-8b-8192
|
| 197 |
+
|
| 198 |
+
# ============================================================================
|
| 199 |
+
# CREATE GLOBAL SETTINGS INSTANCE
|
| 200 |
+
# ============================================================================
|
| 201 |
+
settings = Settings()
|
| 202 |
+
|
| 203 |
+
# ============================================================================
|
| 204 |
+
# PRINT CONFIGURATION ON LOAD
|
| 205 |
+
# ============================================================================
|
| 206 |
+
print("=" * 80)
|
| 207 |
+
print("✅ Configuration Loaded")
|
| 208 |
+
print("=" * 80)
|
| 209 |
+
print(f"Environment: {settings.ENVIRONMENT}")
|
| 210 |
+
print(f"Debug Mode: {settings.DEBUG}")
|
| 211 |
+
print(f"Database: {settings.DATABASE_NAME}")
|
| 212 |
+
print(f"Device: {settings.DEVICE}")
|
| 213 |
+
print(f"CORS Origins: {settings.ALLOWED_ORIGINS}")
|
| 214 |
+
print()
|
| 215 |
+
print("🔑 API Keys:")
|
| 216 |
+
groq_keys = settings.get_groq_api_keys()
|
| 217 |
+
print(f" Groq Keys: {len(groq_keys)} configured")
|
| 218 |
+
for i, key in enumerate(groq_keys, 1):
|
| 219 |
+
print(f" - Key {i}: {'✅ Set' if key else '❌ Missing'}")
|
| 220 |
+
hf_tokens = settings.get_hf_tokens()
|
| 221 |
+
print(f" HuggingFace Tokens: {len(hf_tokens)} configured")
|
| 222 |
+
for i, token in enumerate(hf_tokens, 1):
|
| 223 |
+
print(f" - Token {i}: {'✅ Set' if token else '❌ Missing'}")
|
| 224 |
+
print(f" MongoDB: {'✅ Configured' if settings.MONGODB_URI else '❌ Missing'}")
|
| 225 |
+
print(f" JWT Secret: {'✅ Configured' if settings.SECRET_KEY != 'your-secret-key-change-in-production' else '⚠️ Using default (CHANGE THIS!)'}")
|
| 226 |
+
print()
|
| 227 |
+
print("🤖 LLM Models:")
|
| 228 |
+
print(f" Chat Model: {settings.GROQ_CHAT_MODEL} (Llama 3 8B)")
|
| 229 |
+
print(f" Eval Model: {settings.GROQ_EVAL_MODEL} (Llama 3 70B)")
|
| 230 |
+
print()
|
| 231 |
+
print("🤖 Model Paths:")
|
| 232 |
+
print(f" Policy Model: {settings.POLICY_MODEL_PATH}")
|
| 233 |
+
print(f" Retriever Model: {settings.RETRIEVER_MODEL_PATH}")
|
| 234 |
+
print(f" FAISS Index: {settings.FAISS_INDEX_PATH}")
|
| 235 |
+
print(f" Knowledge Base: {settings.KB_PATH}")
|
| 236 |
+
print("=" * 80)
|
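# ============================================================================
# USAGE EXAMPLE (for reference)
# ============================================================================
"""
# A minimal sketch of the startup flow (not from the original source — the
# actual call sites live in app/main.py, which is not shown in this section).
# The hf_filename below is illustrative; POLICY_MODEL_PATH is the setting
# printed in the configuration banner above:

settings.download_model_if_needed(
    hf_filename="models/best_policy_model.pth",   # path inside HF_MODEL_REPO
    local_path=settings.POLICY_MODEL_PATH
)

# Select the right Groq model for a task:
model_name = settings.get_llm_for_task("evaluation")  # -> llama3-70b-8192
"""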
app/core/__init__.py
ADDED
@@ -0,0 +1,6 @@
"""
Core package - Core utilities and helpers
Contains security, authentication, API key management, and shared utilities
"""

__version__ = "1.0.0"
app/core/llm_manager.py
ADDED
@@ -0,0 +1,445 @@
"""
Multi-LLM Manager with Groq (ChatGroq) and HuggingFace Fallback Logic

Architecture:
- Primary: Groq API with 3 keys (sequential fallback)
- Fallback: HuggingFace Inference API with 3 tokens (sequential fallback)
- Llama 3 8B for the chat interface
- Llama 3 70B for evaluation

Fallback Logic:
1. Try GROQ_API_KEY_1
2. If it fails, try GROQ_API_KEY_2
3. If it fails, try GROQ_API_KEY_3
4. If all Groq keys fail, try HF_TOKEN_1
5. If it fails, try HF_TOKEN_2
6. If it fails, try HF_TOKEN_3
"""

import time
from typing import List, Dict, Optional, Literal
from langchain_groq import ChatGroq
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from huggingface_hub import InferenceClient
from app.config import settings

# ============================================================================
# GROQ MANAGER WITH FALLBACK
# ============================================================================
class GroqManager:
    """
    Groq API manager with multiple-API-key fallback support.
    Uses ChatGroq from langchain_groq.
    """

    def __init__(self):
        """Initialize Groq manager with all available API keys"""
        self.api_keys = settings.get_groq_api_keys()
        self.chat_model_name = settings.GROQ_CHAT_MODEL  # llama3-8b-8192
        self.eval_model_name = settings.GROQ_EVAL_MODEL  # llama3-70b-8192

        # Track current key index
        self.current_key_index = 0

        # Rate limiting tracking
        self.requests_this_minute = 0
        self.last_reset = time.time()

        if not self.api_keys:
            raise ValueError("No Groq API keys configured. Set GROQ_API_KEY_1 in .env")

        print(f"✅ Groq Manager initialized with {len(self.api_keys)} API key(s)")
        print(f"   Chat Model: {self.chat_model_name}")
        print(f"   Eval Model: {self.eval_model_name}")

    def _check_rate_limits(self):
        """
        Check and reset rate limit counters.
        Groq Free tier: 30 requests/min.
        """
        current_time = time.time()

        # Reset counters every minute
        if current_time - self.last_reset > 60:
            self.requests_this_minute = 0
            self.last_reset = current_time

        # Check if limits exceeded
        # =================================================================
        # Uncomment below if rate limiting enforcement is needed
        # =================================================================

        # if self.requests_this_minute >= settings.GROQ_REQUESTS_PER_MINUTE:
        #     wait_time = 60 - (current_time - self.last_reset)
        #     print(f"⚠️ Groq rate limit hit. Waiting {wait_time:.1f}s...")
        #     time.sleep(wait_time)
        #     self._check_rate_limits()

    def _create_llm(self, api_key: str, model_name: str) -> ChatGroq:
        """Create a ChatGroq instance with the given API key and model"""
        return ChatGroq(
            api_key=api_key,
            model_name=model_name,
            temperature=settings.LLM_TEMPERATURE,
            max_tokens=settings.LLM_MAX_TOKENS,
            max_retries=0  # Disable automatic retries; we handle fallback manually
        )

    async def generate(
        self,
        messages: List[Dict[str, str]],
        system_prompt: Optional[str] = None,
        task: Literal["chat", "evaluation"] = "chat"
    ) -> str:
        """
        Generate a response using Groq with fallback logic.

        Args:
            messages: List of conversation messages
            system_prompt: Optional system prompt
            task: Task type that selects the model (chat uses 8B, evaluation uses 70B)

        Returns:
            str: Generated response text

        Raises:
            Exception: If all Groq API keys fail
        """
        self._check_rate_limits()

        # Select model based on task
        model_name = self.eval_model_name if task == "evaluation" else self.chat_model_name

        # Format messages for LangChain
        formatted_messages = []

        # Add system message if provided
        if system_prompt:
            formatted_messages.append(SystemMessage(content=system_prompt))

        # Convert conversation messages
        for msg in messages:
            if msg['role'] == 'user':
                formatted_messages.append(HumanMessage(content=msg['content']))
            elif msg['role'] == 'assistant':
                formatted_messages.append(AIMessage(content=msg['content']))

        # Try each Groq API key sequentially
        for key_index, api_key in enumerate(self.api_keys, 1):
            try:
                print(f"🔑 Trying Groq API Key {key_index}/{len(self.api_keys)} with {model_name}...")

                # Create LLM instance with current key
                llm = self._create_llm(api_key, model_name)

                # Generate response
                response = await llm.ainvoke(formatted_messages)

                # Track rate limits
                self.requests_this_minute += 1

                print(f"✅ Groq API Key {key_index} succeeded")
                return response.content

            except Exception as e:
                print(f"❌ Groq API Key {key_index} failed: {e}")

                # If this was the last key, raise an exception
                if key_index == len(self.api_keys):
                    print(f"❌ All {len(self.api_keys)} Groq API keys exhausted")
                    raise Exception(f"All Groq API keys failed. Last error: {e}")

                # Otherwise, continue to the next key
                print("⏭️ Falling back to next Groq API key...")
                continue

# ============================================================================
# HUGGINGFACE MANAGER WITH FALLBACK
# ============================================================================
class HuggingFaceManager:
    """
    HuggingFace Inference API manager with multiple-token fallback support.
    Uses InferenceClient from huggingface_hub.
    """

    def __init__(self):
        """Initialize HuggingFace manager with all available tokens"""
        self.tokens = settings.get_hf_tokens()
        self.chat_model_name = settings.HF_CHAT_MODEL
        self.eval_model_name = settings.HF_EVAL_MODEL

        if not self.tokens:
            raise ValueError("No HuggingFace tokens configured. Set HF_TOKEN_1 in .env")

        print(f"✅ HuggingFace Manager initialized with {len(self.tokens)} token(s)")
        print(f"   Chat Model: {self.chat_model_name}")
        print(f"   Eval Model: {self.eval_model_name}")

    def _create_client(self, token: str, model_name: str) -> InferenceClient:
        """Create an InferenceClient instance with the given token and model"""
        return InferenceClient(
            model=model_name,
            token=token
        )

    async def generate(
        self,
        messages: List[Dict[str, str]],
        system_prompt: Optional[str] = None,
        task: Literal["chat", "evaluation"] = "chat"
    ) -> str:
        """
        Generate a response using the HuggingFace Inference API with fallback logic.

        Args:
            messages: List of conversation messages
            system_prompt: Optional system prompt
            task: Task type that selects the model

        Returns:
            str: Generated response text

        Raises:
            Exception: If all HuggingFace tokens fail
        """
        # Select model based on task
        model_name = self.eval_model_name if task == "evaluation" else self.chat_model_name

        # Format messages for the HuggingFace chat API
        formatted_messages = []

        # Add system message if provided
        if system_prompt:
            formatted_messages.append({
                "role": "system",
                "content": system_prompt
            })

        # Convert conversation messages
        for msg in messages:
            formatted_messages.append({
                "role": msg['role'],
                "content": msg['content']
            })

        # Try each HuggingFace token sequentially
        for token_index, token in enumerate(self.tokens, 1):
            try:
                print(f"🔑 Trying HuggingFace Token {token_index}/{len(self.tokens)} with {model_name}...")

                # Create client with current token
                client = self._create_client(token, model_name)

                # Generate response using chat completion
                response = client.chat_completion(
                    messages=formatted_messages,
                    max_tokens=settings.LLM_MAX_TOKENS,
                    temperature=settings.LLM_TEMPERATURE
                )

                # Extract content from response
                content = response.choices[0].message.content

                print(f"✅ HuggingFace Token {token_index} succeeded")
                return content

            except Exception as e:
                print(f"❌ HuggingFace Token {token_index} failed: {e}")

                # If this was the last token, raise an exception
                if token_index == len(self.tokens):
                    print(f"❌ All {len(self.tokens)} HuggingFace tokens exhausted")
                    raise Exception(f"All HuggingFace tokens failed. Last error: {e}")

                # Otherwise, continue to the next token
                print("⏭️ Falling back to next HuggingFace token...")
                continue

# ============================================================================
# UNIFIED LLM MANAGER (Groq Primary, HuggingFace Fallback)
# ============================================================================
class LLMManager:
    """
    Unified LLM manager with cascading fallback logic:
    1. Try all Groq API keys (primary)
    2. If all fail, try all HuggingFace tokens (fallback)

    Models:
    - Chat: Llama 3 8B (for user-facing chat responses)
    - Evaluation: Llama 3 70B (for response evaluation)
    """

    def __init__(self):
        """Initialize all LLM managers"""
        self.groq = None
        self.huggingface = None

        # Initialize Groq if configured
        if settings.is_groq_enabled():
            try:
                self.groq = GroqManager()
            except Exception as e:
                print(f"⚠️ Failed to initialize Groq: {e}")

        # Initialize HuggingFace if configured
        if settings.is_hf_enabled():
            try:
                self.huggingface = HuggingFaceManager()
            except Exception as e:
                print(f"⚠️ Failed to initialize HuggingFace: {e}")

        # Check that at least one provider is available
        if not self.groq and not self.huggingface:
            raise ValueError("No LLM provider configured. Set either Groq or HuggingFace credentials in .env")

        print("✅ LLM Manager initialized with fallback logic")

    async def generate(
        self,
        messages: List[Dict[str, str]],
        system_prompt: Optional[str] = None,
        task: Literal["chat", "evaluation"] = "chat"
    ) -> str:
        """
        Generate a response with cascading fallback logic.

        Fallback order:
        1. Try all Groq API keys (3 keys)
        2. If all Groq keys fail, try all HuggingFace tokens (3 tokens)

        Args:
            messages: Conversation messages
            system_prompt: Optional system prompt
            task: Task type - "chat" (8B) or "evaluation" (70B)

        Returns:
            str: Generated response

        Raises:
            ValueError: If all providers fail
        """
        # Try Groq first (if available)
        if self.groq:
            try:
                print("🚀 Attempting Groq API (Primary)...")
                response = await self.groq.generate(messages, system_prompt, task)
                return response
            except Exception as groq_error:
                print(f"❌ All Groq API keys failed: {groq_error}")

                # Fall back to HuggingFace if available
                if self.huggingface:
                    print("🔄 Falling back to HuggingFace Inference API...")
                else:
                    raise ValueError(f"Groq failed and no HuggingFace fallback configured: {groq_error}")

        # Try HuggingFace (if Groq failed or is not available)
        if self.huggingface:
            try:
                print("🚀 Attempting HuggingFace API (Fallback)...")
                response = await self.huggingface.generate(messages, system_prompt, task)
                return response
            except Exception as hf_error:
                raise ValueError(f"All LLM providers exhausted. HuggingFace error: {hf_error}")

        raise ValueError("No LLM provider available")

    async def generate_chat_response(
        self,
        query: str,
        context: str,
        history: List[Dict[str, str]]
    ) -> str:
        """
        Generate a chat response (uses Llama 3 8B).

        Args:
            query: User query
            context: Retrieved context (from FAISS)
            history: Conversation history

        Returns:
            str: Chat response
        """
        # Import the detailed prompt
        from app.services.chat_service import BANKING_SYSTEM_PROMPT

        # Build an enhanced system prompt with context
        system_prompt = BANKING_SYSTEM_PROMPT
        if context:
            system_prompt += f"\n\nRelevant Knowledge Base Context:\n{context}"
        else:
            system_prompt += "\n\nNo specific banking documents were retrieved for this query. Provide a helpful general response while acknowledging your banking specialization."

        # Build messages
        messages = history + [{'role': 'user', 'content': query}]

        # Generate using the chat task (Llama 3 8B)
        return await self.generate(messages, system_prompt, task="chat")

    async def evaluate_response(
        self,
        query: str,
        response: str,
        context: str = ""
    ) -> Dict:
        """
        Evaluate response quality (uses Llama 3 70B for better evaluation).
        Used during RL training.

        Args:
            query: User query
            response: Generated response
            context: Retrieved context (if any)

        Returns:
            dict: Evaluation results
                  {'quality': 'Good'/'Bad', 'explanation': '...'}
        """
        eval_prompt = f"""Evaluate this response:

Query: {query}
Response: {response}
Context used: {context if context else 'None'}

Is this response Good or Bad? Respond with just "Good" or "Bad" and a brief explanation."""

        messages = [{'role': 'user', 'content': eval_prompt}]

        # Generate using the evaluation task (Llama 3 70B)
        result = await self.generate(messages, task="evaluation")

        # Parse the result
        quality = "Good" if "Good" in result else "Bad"

        return {
            'quality': quality,
            'explanation': result
        }

# ============================================================================
# GLOBAL LLM MANAGER INSTANCE
# ============================================================================
llm_manager = LLMManager()

# ============================================================================
# USAGE EXAMPLE (for reference)
# ============================================================================
"""
# In your service file:
from app.core.llm_manager import llm_manager

# Generate chat response (uses Llama 3 8B with Groq → HF fallback)
response = await llm_manager.generate_chat_response(
    query="What is my account balance?",
    context="Your balance is $1000",
    history=[]
)

# Evaluate response (uses Llama 3 70B with Groq → HF fallback)
evaluation = await llm_manager.evaluate_response(
    query="What is my balance?",
    response="Your balance is $1000",
    context="Balance: $1000"
)
"""
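# ============================================================================
# ERROR HANDLING EXAMPLE (for reference)
# ============================================================================
"""
# A minimal sketch (not part of the original file): generate() raises
# ValueError only after every Groq key and HuggingFace token is exhausted,
# so callers such as the chat endpoint should catch it and degrade
# gracefully. The fallback message below is illustrative:

try:
    reply = await llm_manager.generate(
        messages=[{'role': 'user', 'content': 'How do I reset my PIN?'}],
        task="chat"
    )
except ValueError:
    reply = "Sorry, the assistant is temporarily unavailable. Please try again."
"""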
app/db/__init__.py
ADDED
@@ -0,0 +1,6 @@
"""
Database package - Database connections and repositories
Contains MongoDB connection and repository pattern implementations
"""

__version__ = "1.0.0"
app/db/mongodb.py
ADDED
@@ -0,0 +1,390 @@
"""
MongoDB Connection with Motor (Async Driver)
Handles the async connection to MongoDB Atlas for conversation storage
"""

import motor.motor_asyncio
from app.config import settings


# ============================================================================
# GLOBAL VARIABLES
# ============================================================================
mongodb_client = None
mongodb_database = None


# ============================================================================
# CONNECTION FUNCTIONS
# ============================================================================

async def connect_to_mongo():
    """
    Connect to MongoDB Atlas on application startup.

    This is called from main.py during FastAPI lifespan startup.
    Uses Motor for async MongoDB operations.

    Returns:
        database: MongoDB database instance, or None if the connection fails
    """
    global mongodb_client, mongodb_database

    try:
        print("\n🔌 Connecting to MongoDB Atlas...")

        # Hide password in logs
        # uri_display = settings.MONGODB_URI[:50] + "..." if len(settings.MONGODB_URI) > 50 else settings.MONGODB_URI
        # print(f"   URI: {uri_display}")
        print(f"   Database: {settings.DATABASE_NAME}")

        # Create Motor client (async MongoDB driver)
        mongodb_client = motor.motor_asyncio.AsyncIOMotorClient(
            settings.MONGODB_URI,
            serverSelectionTimeoutMS=5000,   # 5 second server selection timeout
            connectTimeoutMS=10000,          # 10 second connection timeout
            socketTimeoutMS=10000            # 10 second socket timeout
        )

        # Get database reference
        mongodb_database = mongodb_client[settings.DATABASE_NAME]

        # Test connection with a ping
        await mongodb_client.admin.command('ping')

        print("✅ MongoDB connected successfully!")
        print(f"   Database: {settings.DATABASE_NAME}")

        return mongodb_database

    except Exception as e:
        print("\n❌ MongoDB connection FAILED!")
        print(f"   Error: {str(e)}")
        print("\n💡 Troubleshooting:")
        print("   1. Check MONGODB_URI in the .env file")
        print("   2. Verify the MongoDB Atlas cluster is running")
        print("   3. Check network access settings (allow your IP)")
        print("   4. Verify database user credentials")
        print("\n⚠️ Backend will start, but MongoDB features won't work!\n")

        # Set to None (the app can still start for debugging)
        mongodb_database = None
        return None


async def close_mongo_connection():
    """
    Close the MongoDB connection on application shutdown.

    This is called from main.py during FastAPI lifespan shutdown.
    """
    global mongodb_client

    if mongodb_client:
        print("\n🔌 Closing MongoDB connection...")
        mongodb_client.close()
        print("✅ MongoDB connection closed")
    else:
        print("ℹ️ No MongoDB connection to close")


def get_database():
    """
    Get the MongoDB database instance.

    This is used by repositories to access the database.
    Returns None if MongoDB is not connected (for graceful degradation).

    Returns:
        database: MongoDB database instance or None
    """
    if mongodb_database is None:
        print("\n⚠️ WARNING: MongoDB database not available!")
        print("   Attempting to use database features without a connection")
        print("   Make sure the MongoDB connection succeeded during startup\n")

    return mongodb_database


# ============================================================================
# USAGE EXAMPLE (for reference)
# ============================================================================
"""
# In main.py (FastAPI lifespan):

from app.db.mongodb import connect_to_mongo, close_mongo_connection

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup
    await connect_to_mongo()
    yield
    # Shutdown
    await close_mongo_connection()

# In repositories:

from app.db.mongodb import get_database

class SomeRepository:
    def __init__(self):
        self.db = get_database()
        if self.db is not None:  # Motor databases don't support truth testing
            self.collection = self.db["my_collection"]
"""


# """
# MongoDB Connection Handler
# Manages connection to MongoDB Atlas (cloud database)

# This uses Motor - an async MongoDB driver for Python
# Works perfectly with FastAPI's async/await
# """

# from motor.motor_asyncio import AsyncIOMotorClient
# from app.config import settings


# # ============================================================================
# # MONGODB CLIENT SINGLETON
# # ============================================================================
# class MongoDB:
#     """
#     MongoDB client singleton.
#     Stores the connection and database instance.

#     Attributes:
#         client: Motor async client connection
#         db: Database instance
#     """
#     client: AsyncIOMotorClient = None
#     db = None


# # Create global instance
# mongodb = MongoDB()


# # ============================================================================
# # CONNECTION FUNCTIONS
# # ============================================================================

# async def connect_to_mongo():
#     """
#     Connect to MongoDB Atlas on application startup.

#     This establishes a connection pool that will be reused
#     for all database operations throughout the app's lifetime.

#     Connection string format (from .env):

#     Raises:
#         Exception: If connection fails
#     """
#     try:
#         # Create async MongoDB client
#         # serverSelectionTimeoutMS: How long to wait before giving up (5 seconds)
#         mongodb.client = AsyncIOMotorClient(
#             settings.MONGODB_URI,
#             serverSelectionTimeoutMS=5000
#         )

#         # Get database instance
#         mongodb.db = mongodb.client[settings.DATABASE_NAME]

#         # Verify connection by pinging the database
#         await mongodb.client.admin.command('ping')

#         print(f"✅ Connected to MongoDB Atlas")
#         print(f"   Database: {settings.DATABASE_NAME}")

#     except Exception as e:
#         print(f"❌ MongoDB connection failed: {e}")
#         raise


# async def close_mongo_connection():
#     """
#     Close MongoDB connection on application shutdown.

#     This properly closes the connection pool and releases resources.
#     """
#     if mongodb.client:
#         mongodb.client.close()
#         print("✅ MongoDB connection closed")


# def get_database():
#     """
#     Get the current database instance.

#     This function is used by repositories to access the database.

#     Returns:
#         AsyncIOMotorDatabase: Database instance

#     Example:
#         from app.db.mongodb import get_database

#         db = get_database()
#         collection = db["conversations"]
#         result = await collection.find_one({"user_id": "123"})
#     """
#     return mongodb.db


# # ============================================================================
# # HELPER FUNCTIONS
# # ============================================================================

# async def check_connection() -> bool:
#     """
#     Check if the MongoDB connection is alive.

#     Returns:
#         bool: True if connected, False otherwise
#     """
#     try:
#         if mongodb.client is None:
#             return False

#         # Try to ping the database
#         await mongodb.client.admin.command('ping')
#         return True
#     except Exception:
#         return False


# async def get_collection_names():
#     """
#     Get a list of all collection names in the database.
#     Useful for debugging and admin operations.

#     Returns:
#         list: List of collection names
#     """
#     if mongodb.db is None:
#         return []

#     return await mongodb.db.list_collection_names()


# async def create_indexes():
#     """
#     Create database indexes for better query performance.

#     This should be called once after first deployment.
#     Indexes speed up queries on specific fields.

#     Collections and their indexes:
#     1. conversations:
#        - conversation_id (unique)
#        - user_id (for user queries)
#        - created_at (for sorting)

#     2. users:
#        - user_id (unique)
#        - email (unique)

#     3. retrieval_logs:
#        - log_id (unique)
#        - timestamp (for time-series queries)
#     """
#     db = get_database()

#     # Conversations collection
#     conversations = db["conversations"]
#     await conversations.create_index("conversation_id", unique=True)
#     await conversations.create_index("user_id")
#     await conversations.create_index("created_at")
#     print("✅ Created indexes for 'conversations' collection")

#     # Users collection
#     users = db["users"]
#     await users.create_index("user_id", unique=True)
#     await users.create_index("email", unique=True)
#     print("✅ Created indexes for 'users' collection")

#     # Retrieval logs collection
#     retrieval_logs = db["retrieval_logs"]
#     await retrieval_logs.create_index("log_id", unique=True)
#     await retrieval_logs.create_index("timestamp")
#     await retrieval_logs.create_index("user_id")
#     print("✅ Created indexes for 'retrieval_logs' collection")

#     print("✅ All database indexes created successfully")


# # ============================================================================
# # USAGE EXAMPLES (for reference)
# # ============================================================================
# """
# # In your repository or service:

# from app.db.mongodb import get_database

# async def get_user_conversations(user_id: str):
#     db = get_database()
#     conversations = db["conversations"]

#     cursor = conversations.find({"user_id": user_id})
#     results = await cursor.to_list(length=10)

#     return results

# # In main.py startup:
# from app.db.mongodb import connect_to_mongo, create_indexes

# await connect_to_mongo()
# await create_indexes()  # Run once on first deployment
# """


# # Key Features:
# # Async/Await - Works with FastAPI's async nature
# # Connection Pooling - Reuses connections efficiently
# # Singleton Pattern - One connection for entire app
# # MongoDB Atlas Compatible - Works with cloud MongoDB
# # Index Creation - Optimizes query performance
# # Health Check - check_connection() for monitoring

# # How to Use:
# # python
# # # In any repository/service file:

# # from app.db.mongodb import get_database

# # async def save_conversation(data: dict):
# #     db = get_database()
# #     collection = db["conversations"]
# #     result = await collection.insert_one(data)
# #     return str(result.inserted_id)
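# ============================================================================
# HEALTH CHECK EXAMPLE (for reference)
# ============================================================================
"""
# A minimal sketch (not part of the original file): the active module keeps
# its client in the module-level `mongodb_client`, so a health probe can
# reuse the same `ping` command the startup code uses. The helper name is
# illustrative:

from app.db import mongodb

async def mongo_is_alive() -> bool:
    if mongodb.mongodb_client is None:
        return False
    try:
        await mongodb.mongodb_client.admin.command('ping')
        return True
    except Exception:
        return False
"""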
app/db/repositories/__init__.py
ADDED
@@ -0,0 +1,6 @@
"""
Repositories package - Data access layer
Contains repository classes for database operations (CRUD)
"""

__version__ = "1.0.0"
app/db/repositories/conversation_repository.py
ADDED
@@ -0,0 +1,1049 @@
# """
# Conversation Repository - MongoDB CRUD operations
# Handles storing and retrieving conversations from MongoDB Atlas

# Repository Pattern: Separates database logic from business logic
# This makes code cleaner and easier to test
# """

# import uuid
# from datetime import datetime
# from typing import List, Dict, Optional
# from bson import ObjectId

# from app.db.mongodb import get_database


# # ============================================================================
# # CONVERSATION REPOSITORY
# # ============================================================================

# class ConversationRepository:
#     """
#     Repository for conversation data in MongoDB.

#     Collections used:
#     - conversations: Stores complete conversations with messages
#     - retrieval_logs: Logs each retrieval operation (for RL training)
#     """

#     def __init__(self):
#         """Initialize repository with database connection"""
#         self.db = get_database()
#         self.conversations = self.db["conversations"]
#         self.retrieval_logs = self.db["retrieval_logs"]

#     # ========================================================================
#     # CONVERSATION CRUD OPERATIONS
#     # ========================================================================

#     async def create_conversation(
#         self,
#         user_id: str,
#         conversation_id: Optional[str] = None
#     ) -> str:
#         """
#         Create a new conversation.

#         Args:
#             user_id: User ID who owns this conversation
#             conversation_id: Optional custom conversation ID (auto-generated if None)

#         Returns:
#             str: Conversation ID
#         """
#         if conversation_id is None:
#             conversation_id = str(uuid.uuid4())

#         conversation = {
#             "conversation_id": conversation_id,
#             "user_id": user_id,
#             "messages": [],  # Will store all messages
#             "created_at": datetime.now(),
#             "updated_at": datetime.now(),
#             "status": "active"  # active, archived, deleted
#         }

#         await self.conversations.insert_one(conversation)

#         return conversation_id

#     async def get_conversation(self, conversation_id: str) -> Optional[Dict]:
#         """
#         Get a conversation by ID.

#         Args:
#             conversation_id: Conversation ID

#         Returns:
#             dict or None: Conversation document
#         """
#         conversation = await self.conversations.find_one(
#             {"conversation_id": conversation_id}
#         )

#         # Convert MongoDB ObjectId to string for JSON serialization
#         if conversation and "_id" in conversation:
#             conversation["_id"] = str(conversation["_id"])

#         return conversation

#     async def get_user_conversations(
#         self,
#         user_id: str,
#         limit: int = 10,
#         skip: int = 0
#     ) -> List[Dict]:
#         """
#         Get all conversations for a user.

#         Args:
#             user_id: User ID
#             limit: Maximum number of conversations to return
#             skip: Number of conversations to skip (for pagination)

#         Returns:
#             list: List of conversation documents
#         """
#         cursor = self.conversations.find(
#             {"user_id": user_id, "status": "active"}
#         ).sort("updated_at", -1).skip(skip).limit(limit)

#         conversations = await cursor.to_list(length=limit)

#         # Convert ObjectIds to strings
#         for conv in conversations:
#             if "_id" in conv:
#                 conv["_id"] = str(conv["_id"])

#         return conversations

#     async def add_message(
#         self,
#         conversation_id: str,
#         message: Dict
#     ) -> bool:
#         """
#         Add a message to a conversation.

#         Args:
#             conversation_id: Conversation ID
#             message: Message dict
#                 {
#                     'role': 'user' or 'assistant',
#                     'content': str,
#                     'timestamp': datetime,
#                     'metadata': dict (optional - policy_action, docs_retrieved, etc.)
#                 }

#         Returns:
#             bool: Success status
#         """
#         # Ensure timestamp exists
#         if "timestamp" not in message:
#             message["timestamp"] = datetime.now()

#         # Add message to conversation
#         result = await self.conversations.update_one(
#             {"conversation_id": conversation_id},
#             {
#                 "$push": {"messages": message},
#                 "$set": {"updated_at": datetime.now()}
#             }
#         )

#         return result.modified_count > 0

#     async def get_conversation_history(
#         self,
#         conversation_id: str,
#         max_messages: int = None
#     ) -> List[Dict]:
#         """
#         Get conversation history (messages only).

#         Args:
#             conversation_id: Conversation ID
#             max_messages: Optional limit on number of messages

#         Returns:
#             list: List of messages
#         """
#         conversation = await self.get_conversation(conversation_id)

#         if not conversation:
#             return []

#         messages = conversation.get("messages", [])

#         if max_messages:
#             messages = messages[-max_messages:]

#         return messages

#     async def delete_conversation(self, conversation_id: str) -> bool:
#         """
#         Soft delete a conversation (mark as deleted, don't actually delete).

#         Args:
#             conversation_id: Conversation ID

#         Returns:
#             bool: Success status
#         """
#         result = await self.conversations.update_one(
#             {"conversation_id": conversation_id},
#             {
#                 "$set": {
#                     "status": "deleted",
#                     "deleted_at": datetime.now()
#                 }
#             }
#         )

#         return result.modified_count > 0

#     # ========================================================================
#     # RETRIEVAL LOGS (for RL training)
#     # ========================================================================

#     async def log_retrieval(
#         self,
#         log_data: Dict
#     ) -> str:
#         """
#         Log a retrieval operation (for RL training and analysis).

#         Args:
#             log_data: Log data dict
#                 {
#                     'conversation_id': str,
#                     'user_id': str,
#                     'query': str,
#                     'policy_action': 'FETCH' or 'NO_FETCH',
#                     'policy_confidence': float,
#                     'documents_retrieved': int,
#                     'top_doc_score': float or None,
#                     'retrieved_docs_metadata': list,
#                     'response': str,
#                     'retrieval_time_ms': float,
#                     'generation_time_ms': float,
#                     'total_time_ms': float,
#                     'timestamp': datetime
#                 }

#         Returns:
#             str: Log ID
#         """
#         # Add timestamp if not present
#         if "timestamp" not in log_data:
#             log_data["timestamp"] = datetime.now()

#         # Generate log ID
#         log_id = str(uuid.uuid4())
#         log_data["log_id"] = log_id

#         # Insert log
#         await self.retrieval_logs.insert_one(log_data)

#         return log_id

#     async def get_retrieval_logs(
#         self,
#         conversation_id: Optional[str] = None,
#         user_id: Optional[str] = None,
#         limit: int = 100,
#         skip: int = 0
#     ) -> List[Dict]:
#         """
#         Get retrieval logs (for analysis and RL training).

#         Args:
#             conversation_id: Optional filter by conversation
#             user_id: Optional filter by user
#             limit: Maximum number of logs
#             skip: Number of logs to skip

#         Returns:
#             list: List of log documents
#         """
#         # Build query
#         query = {}
#         if conversation_id:
#             query["conversation_id"] = conversation_id
#         if user_id:
#             query["user_id"] = user_id

#         # Fetch logs
#         cursor = self.retrieval_logs.find(query).sort("timestamp", -1).skip(skip).limit(limit)
#         logs = await cursor.to_list(length=limit)

#         # Convert ObjectIds to strings
#         for log in logs:
#             if "_id" in log:
#                 log["_id"] = str(log["_id"])

#         return logs

#     async def get_logs_for_rl_training(
#         self,
#         min_date: Optional[datetime] = None,
#         limit: int = 1000
#     ) -> List[Dict]:
#         """
#         Get logs specifically for RL training.
#         Filters for logs with both a policy decision and retrieval results.

#         Args:
#             min_date: Optional minimum date for logs
#             limit: Maximum number of logs

#         Returns:
#             list: List of log documents suitable for RL training
#         """
#         # Build query
#         query = {
#             "policy_action": {"$exists": True},
#             "response": {"$exists": True}
#         }

#         if min_date:
#             query["timestamp"] = {"$gte": min_date}

#         # Fetch logs
#         cursor = self.retrieval_logs.find(query).sort("timestamp", -1).limit(limit)
#         logs = await cursor.to_list(length=limit)

#         # Convert ObjectIds
#         for log in logs:
#             if "_id" in log:
#                 log["_id"] = str(log["_id"])

#         return logs

#     # ========================================================================
#     # ANALYTICS QUERIES
#     # ========================================================================

#     async def get_conversation_stats(self, user_id: str) -> Dict:
#         """
#         Get conversation statistics for a user.

#         Args:
#             user_id: User ID

#         Returns:
#             dict: Statistics
#         """
#         # Count total conversations
#         total_conversations = await self.conversations.count_documents({
#             "user_id": user_id,
#             "status": "active"
#         })

#         # Count total messages
#         pipeline = [
#             {"$match": {"user_id": user_id, "status": "active"}},
#             {"$project": {"message_count": {"$size": "$messages"}}}
#         ]

#         result = await self.conversations.aggregate(pipeline).to_list(length=None)
#         total_messages = sum(doc.get("message_count", 0) for doc in result)

#         return {
#             "total_conversations": total_conversations,
#             "total_messages": total_messages,
#             "avg_messages_per_conversation": total_messages / total_conversations if total_conversations > 0 else 0
#         }

#     async def get_policy_stats(self, user_id: Optional[str] = None) -> Dict:
#         """
#         Get policy decision statistics.

#         Args:
#             user_id: Optional user ID filter

#         Returns:
#             dict: Policy statistics
#         """
#         # Build query
#         query = {}
#         if user_id:
#             query["user_id"] = user_id

#         # Count FETCH vs NO_FETCH
#         fetch_count = await self.retrieval_logs.count_documents({
#             **query,
#             "policy_action": "FETCH"
+
# })
|
| 379 |
+
|
| 380 |
+
# no_fetch_count = await self.retrieval_logs.count_documents({
|
| 381 |
+
# **query,
|
| 382 |
+
# "policy_action": "NO_FETCH"
|
| 383 |
+
# })
|
| 384 |
+
|
| 385 |
+
# total = fetch_count + no_fetch_count
|
| 386 |
+
|
| 387 |
+
# return {
|
| 388 |
+
# "fetch_count": fetch_count,
|
| 389 |
+
# "no_fetch_count": no_fetch_count,
|
| 390 |
+
# "total": total,
|
| 391 |
+
# "fetch_rate": fetch_count / total if total > 0 else 0,
|
| 392 |
+
# "no_fetch_rate": no_fetch_count / total if total > 0 else 0
|
| 393 |
+
# }
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
# # ============================================================================
|
| 397 |
+
# # USAGE EXAMPLE (for reference)
|
| 398 |
+
# # ============================================================================
|
| 399 |
+
# """
|
| 400 |
+
# # In your service or API endpoint:
|
| 401 |
+
|
| 402 |
+
# from app.db.repositories.conversation_repository import ConversationRepository
|
| 403 |
+
|
| 404 |
+
# repo = ConversationRepository()
|
| 405 |
+
|
| 406 |
+
# # Create conversation
|
| 407 |
+
# conv_id = await repo.create_conversation(user_id="user_123")
|
| 408 |
+
|
| 409 |
+
# # Add user message
|
| 410 |
+
# await repo.add_message(conv_id, {
|
| 411 |
+
# 'role': 'user',
|
| 412 |
+
# 'content': 'What is my balance?',
|
| 413 |
+
# 'timestamp': datetime.now()
|
| 414 |
+
# })
|
| 415 |
+
|
| 416 |
+
# # Add assistant message
|
| 417 |
+
# await repo.add_message(conv_id, {
|
| 418 |
+
# 'role': 'assistant',
|
| 419 |
+
# 'content': 'Your balance is $1000',
|
| 420 |
+
# 'timestamp': datetime.now(),
|
| 421 |
+
# 'metadata': {
|
| 422 |
+
# 'policy_action': 'FETCH',
|
| 423 |
+
# 'documents_retrieved': 3
|
| 424 |
+
# }
|
| 425 |
+
# })
|
| 426 |
+
|
| 427 |
+
# # Get conversation history
|
| 428 |
+
# history = await repo.get_conversation_history(conv_id)
|
| 429 |
+
|
| 430 |
+
# # Log retrieval for RL training
|
| 431 |
+
# await repo.log_retrieval({
|
| 432 |
+
# 'conversation_id': conv_id,
|
| 433 |
+
# 'user_id': 'user_123',
|
| 434 |
+
# 'query': 'What is my balance?',
|
| 435 |
+
# 'policy_action': 'FETCH',
|
| 436 |
+
# 'documents_retrieved': 3,
|
| 437 |
+
# 'response': 'Your balance is $1000'
|
| 438 |
+
# })
|
| 439 |
+
# """
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
|
"""
Conversation Repository - MongoDB CRUD operations
Handles storing and retrieving conversations from MongoDB Atlas

Repository Pattern: Separates database logic from business logic
This makes code cleaner and easier to test

Collections:
- conversations: Stores complete conversations with messages
- retrieval_logs: Logs each retrieval operation (for RL training data)
"""

import uuid
from datetime import datetime
from typing import List, Dict, Optional

from app.db.mongodb import get_database


# ============================================================================
# CONVERSATION REPOSITORY
# ============================================================================

class ConversationRepository:
    """
    Repository for conversation data in MongoDB.

    Provides CRUD operations for:
    1. Conversations (user chat sessions)
    2. Retrieval logs (for RL training and analytics)
    """

    def __init__(self):
        """
        Initialize repository with database connection.

        Gracefully handles the case where MongoDB is not connected.
        """
        self.db = get_database()

        # Graceful handling if MongoDB not connected
        if self.db is None:
            print("⚠️ ConversationRepository: MongoDB not connected")
            print("   Repository will not function until database is connected")
            self.conversations = None
            self.retrieval_logs = None
        else:
            self.conversations = self.db["conversations"]
            self.retrieval_logs = self.db["retrieval_logs"]
            print("✅ ConversationRepository initialized with MongoDB")

    def _check_connection(self):
        """
        Check if MongoDB is connected.

        Raises:
            RuntimeError: If MongoDB is not connected
        """
        if self.db is None or self.conversations is None:
            raise RuntimeError(
                "MongoDB not connected. Cannot perform database operations. "
                "Check MONGODB_URI in .env file."
            )

    # ========================================================================
    # CONVERSATION CRUD OPERATIONS
    # ========================================================================

    async def create_conversation(
        self,
        user_id: str,
        conversation_id: Optional[str] = None
    ) -> str:
        """
        Create a new conversation.

        Args:
            user_id: User ID who owns this conversation
            conversation_id: Optional custom conversation ID (auto-generated if None)

        Returns:
            str: Conversation ID

        Raises:
            RuntimeError: If MongoDB not connected
        """
        self._check_connection()

        if conversation_id is None:
            conversation_id = str(uuid.uuid4())

        conversation = {
            "conversation_id": conversation_id,
            "user_id": user_id,
            "messages": [],  # Will store all messages
            "created_at": datetime.now(),
            "updated_at": datetime.now(),
            "status": "active"  # active, archived, deleted
        }

        await self.conversations.insert_one(conversation)

        return conversation_id

    async def get_conversation(self, conversation_id: str) -> Optional[Dict]:
        """
        Get a conversation by ID.

        Args:
            conversation_id: Conversation ID

        Returns:
            dict or None: Conversation document

        Raises:
            RuntimeError: If MongoDB not connected
        """
        self._check_connection()

        conversation = await self.conversations.find_one(
            {"conversation_id": conversation_id}
        )

        # Convert MongoDB ObjectId to string for JSON serialization
        if conversation and "_id" in conversation:
            conversation["_id"] = str(conversation["_id"])

        return conversation

    async def get_user_conversations(
        self,
        user_id: str,
        limit: int = 10,
        skip: int = 0
    ) -> List[Dict]:
        """Get all conversations for a user."""
        # Gracefully return empty list if not connected
        if self.db is None or self.conversations is None:
            print("⚠️ MongoDB not connected - returning empty conversations list")
            return []

        cursor = self.conversations.find(
            {"user_id": user_id, "status": "active"}
        ).sort("updated_at", -1).skip(skip).limit(limit)

        conversations = await cursor.to_list(length=limit)

        # Convert ObjectIds to strings
        for conv in conversations:
            if "_id" in conv:
                conv["_id"] = str(conv["_id"])

        return conversations

    async def add_message(
        self,
        conversation_id: str,
        message: Dict
    ) -> bool:
        """
        Add a message to a conversation.

        Args:
            conversation_id: Conversation ID
            message: Message dict
                {
                    'role': 'user' or 'assistant',
                    'content': str,
                    'timestamp': datetime,
                    'metadata': dict (optional - policy_action, docs_retrieved, etc.)
                }

        Returns:
            bool: Success status

        Raises:
            RuntimeError: If MongoDB not connected
        """
        self._check_connection()

        # Ensure timestamp exists
        if "timestamp" not in message:
            message["timestamp"] = datetime.now()

        # Add message to conversation
        result = await self.conversations.update_one(
            {"conversation_id": conversation_id},
            {
                "$push": {"messages": message},
                "$set": {"updated_at": datetime.now()}
            }
        )

        return result.modified_count > 0

    async def get_conversation_history(
        self,
        conversation_id: str,
        max_messages: Optional[int] = None
    ) -> List[Dict]:
        """
        Get conversation history (messages only).

        Args:
            conversation_id: Conversation ID
            max_messages: Optional limit on number of messages

        Returns:
            list: List of messages

        Raises:
            RuntimeError: If MongoDB not connected
        """
        self._check_connection()

        conversation = await self.get_conversation(conversation_id)

        if not conversation:
            return []

        messages = conversation.get("messages", [])

        if max_messages:
            messages = messages[-max_messages:]

        return messages

    async def delete_conversation(self, conversation_id: str) -> bool:
        """
        Soft delete a conversation (mark as deleted, don't actually delete).

        Args:
            conversation_id: Conversation ID

        Returns:
            bool: Success status

        Raises:
            RuntimeError: If MongoDB not connected
        """
        self._check_connection()

        result = await self.conversations.update_one(
            {"conversation_id": conversation_id},
            {
                "$set": {
                    "status": "deleted",
                    "deleted_at": datetime.now()
                }
            }
        )

        return result.modified_count > 0

    # ========================================================================
    # RETRIEVAL LOGS (for RL training)
    # ========================================================================

    async def log_retrieval(
        self,
        log_data: Dict
    ) -> str:
        """
        Log a retrieval operation (for RL training and analysis).

        Args:
            log_data: Log data dict
                {
                    'conversation_id': str,
                    'user_id': str,
                    'query': str,
                    'policy_action': 'FETCH' or 'NO_FETCH',
                    'policy_confidence': float,
                    'documents_retrieved': int,
                    'top_doc_score': float or None,
                    'retrieved_docs_metadata': list,
                    'response': str,
                    'retrieval_time_ms': float,
                    'generation_time_ms': float,
                    'total_time_ms': float,
                    'timestamp': datetime
                }

        Returns:
            str: Log ID

        Raises:
            RuntimeError: If MongoDB not connected
        """
        self._check_connection()

        # Add timestamp if not present
        if "timestamp" not in log_data:
            log_data["timestamp"] = datetime.now()

        # Generate log ID
        log_id = str(uuid.uuid4())
        log_data["log_id"] = log_id

        # Insert log
        await self.retrieval_logs.insert_one(log_data)

        return log_id

    async def get_retrieval_logs(
        self,
        conversation_id: Optional[str] = None,
        user_id: Optional[str] = None,
        limit: int = 100,
        skip: int = 0
    ) -> List[Dict]:
        """
        Get retrieval logs (for analysis and RL training).

        Args:
            conversation_id: Optional filter by conversation
            user_id: Optional filter by user
            limit: Maximum number of logs
            skip: Number of logs to skip

        Returns:
            list: List of log documents

        Raises:
            RuntimeError: If MongoDB not connected
        """
        self._check_connection()

        # Build query
        query = {}
        if conversation_id:
            query["conversation_id"] = conversation_id
        if user_id:
            query["user_id"] = user_id

        # Fetch logs
        cursor = self.retrieval_logs.find(query).sort("timestamp", -1).skip(skip).limit(limit)
        logs = await cursor.to_list(length=limit)

        # Convert ObjectIds to strings
        for log in logs:
            if "_id" in log:
                log["_id"] = str(log["_id"])

        return logs

    async def get_logs_for_rl_training(
        self,
        min_date: Optional[datetime] = None,
        limit: int = 1000
    ) -> List[Dict]:
        """
        Get logs specifically for RL training.
        Filters for logs with both policy decision and retrieval results.

        Args:
            min_date: Optional minimum date for logs
            limit: Maximum number of logs

        Returns:
            list: List of log documents suitable for RL training

        Raises:
            RuntimeError: If MongoDB not connected
        """
        self._check_connection()

        # Build query
        query = {
            "policy_action": {"$exists": True},
            "response": {"$exists": True}
        }

        if min_date:
            query["timestamp"] = {"$gte": min_date}

        # Fetch logs
        cursor = self.retrieval_logs.find(query).sort("timestamp", -1).limit(limit)
        logs = await cursor.to_list(length=limit)

        # Convert ObjectIds
        for log in logs:
            if "_id" in log:
                log["_id"] = str(log["_id"])

        return logs

    # ========================================================================
    # ANALYTICS QUERIES
    # ========================================================================

    async def get_conversation_stats(self, user_id: str) -> Dict:
        """
        Get conversation statistics for a user.

        Args:
            user_id: User ID

        Returns:
            dict: Statistics

        Raises:
            RuntimeError: If MongoDB not connected
        """
        self._check_connection()

        # Count total conversations
        total_conversations = await self.conversations.count_documents({
            "user_id": user_id,
            "status": "active"
        })

        # Count total messages
        pipeline = [
            {"$match": {"user_id": user_id, "status": "active"}},
            {"$project": {"message_count": {"$size": "$messages"}}}
        ]

        result = await self.conversations.aggregate(pipeline).to_list(length=None)
        total_messages = sum(doc.get("message_count", 0) for doc in result)

        return {
            "total_conversations": total_conversations,
            "total_messages": total_messages,
            "avg_messages_per_conversation": total_messages / total_conversations if total_conversations > 0 else 0
        }

    async def get_policy_stats(self, user_id: Optional[str] = None) -> Dict:
        """
        Get policy decision statistics.

        Args:
            user_id: Optional user ID filter

        Returns:
            dict: Policy statistics

        Raises:
            RuntimeError: If MongoDB not connected
        """
        self._check_connection()

        # Build query
        query = {}
        if user_id:
            query["user_id"] = user_id

        # Count FETCH vs NO_FETCH
        fetch_count = await self.retrieval_logs.count_documents({
            **query,
            "policy_action": "FETCH"
        })

        no_fetch_count = await self.retrieval_logs.count_documents({
            **query,
            "policy_action": "NO_FETCH"
        })

        total = fetch_count + no_fetch_count

        return {
            "fetch_count": fetch_count,
            "no_fetch_count": no_fetch_count,
            "total": total,
            "fetch_rate": fetch_count / total if total > 0 else 0,
            "no_fetch_rate": no_fetch_count / total if total > 0 else 0
        }


# ============================================================================
# USAGE EXAMPLE (for reference)
# ============================================================================
"""
# In your service or API endpoint:

from app.db.repositories.conversation_repository import ConversationRepository

repo = ConversationRepository()

# Create conversation
conv_id = await repo.create_conversation(user_id="user_123")

# Add user message
await repo.add_message(conv_id, {
    'role': 'user',
    'content': 'What is my balance?',
    'timestamp': datetime.now()
})

# Add assistant message
await repo.add_message(conv_id, {
    'role': 'assistant',
    'content': 'Your balance is $1000',
    'timestamp': datetime.now(),
    'metadata': {
        'policy_action': 'FETCH',
        'documents_retrieved': 3
    }
})

# Get conversation history
history = await repo.get_conversation_history(conv_id)

# Log retrieval for RL training
await repo.log_retrieval({
    'conversation_id': conv_id,
    'user_id': 'user_123',
    'query': 'What is my balance?',
    'policy_action': 'FETCH',
    'documents_retrieved': 3,
    'response': 'Your balance is $1000'
})
"""
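
The queries above hit conversations by conversation_id and by (user_id, status) sorted on updated_at, and retrieval_logs by conversation_id, user_id, and timestamp. A one-time setup coroutine along these lines would keep those lookups indexed; this is an illustrative sketch (ensure_indexes is not part of this commit), and Motor's create_index mirrors PyMongo's:

    from app.db.mongodb import get_database

    async def ensure_indexes():
        # Assumes connect_to_mongo() has already run, so get_database() is not None
        db = get_database()
        await db["conversations"].create_index("conversation_id", unique=True)
        await db["conversations"].create_index(
            [("user_id", 1), ("status", 1), ("updated_at", -1)]
        )
        await db["retrieval_logs"].create_index([("timestamp", -1)])
        await db["retrieval_logs"].create_index("conversation_id")
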
app/db/repositories/user_repository.py
ADDED
@@ -0,0 +1,155 @@
"""
User Repository - MongoDB CRUD for Users
Handles user registration, retrieval, and management
"""

import uuid
from datetime import datetime
from typing import Optional, Dict
from app.db.mongodb import get_database
from app.utils.security import hash_password


class UserRepository:
    """Repository for user data in MongoDB"""

    def __init__(self):
        """Initialize repository with database connection"""
        self.db = get_database()

        if self.db is None:
            print("⚠️ UserRepository: MongoDB not connected")
            self.users = None
        else:
            self.users = self.db["users"]
            print("✅ UserRepository initialized")

    def _check_connection(self):
        """Check if MongoDB is connected"""
        if self.db is None or self.users is None:
            raise RuntimeError("MongoDB not connected")

    async def create_user(
        self,
        email: str,
        password: str,
        full_name: str
    ) -> str:
        """
        Create a new user.

        Args:
            email: User email (unique)
            password: Plain text password (will be hashed)
            full_name: User's full name

        Returns:
            str: User ID

        Raises:
            ValueError: If email already exists
        """
        self._check_connection()

        # Check if user already exists
        existing_user = await self.users.find_one({"email": email})
        if existing_user:
            raise ValueError("Email already registered")

        # Create user document
        user_id = str(uuid.uuid4())
        user = {
            "user_id": user_id,
            "email": email,
            "hashed_password": hash_password(password),
            "full_name": full_name,
            "created_at": datetime.now(),
            "is_active": True
        }

        await self.users.insert_one(user)
        print(f"✅ Created user: {email}")

        return user_id

    async def get_user_by_email(self, email: str) -> Optional[Dict]:
        """
        Get user by email.

        Args:
            email: User email

        Returns:
            dict or None: User document
        """
        self._check_connection()

        user = await self.users.find_one({"email": email})

        if user and "_id" in user:
            user["_id"] = str(user["_id"])

        return user

    async def get_user_by_id(self, user_id: str) -> Optional[Dict]:
        """
        Get user by ID.

        Args:
            user_id: User ID

        Returns:
            dict or None: User document
        """
        self._check_connection()

        user = await self.users.find_one({"user_id": user_id})

        if user and "_id" in user:
            user["_id"] = str(user["_id"])

        return user

    async def update_user(self, user_id: str, updates: Dict) -> bool:
        """
        Update user information.

        Args:
            user_id: User ID
            updates: Dictionary of fields to update

        Returns:
            bool: Success status
        """
        self._check_connection()

        # Don't allow updating certain fields
        forbidden_fields = ["user_id", "email", "hashed_password", "created_at"]
        for field in forbidden_fields:
            updates.pop(field, None)

        result = await self.users.update_one(
            {"user_id": user_id},
            {"$set": updates}
        )

        return result.modified_count > 0

    async def delete_user(self, user_id: str) -> bool:
        """
        Soft delete a user (mark as inactive).

        Args:
            user_id: User ID

        Returns:
            bool: Success status
        """
        self._check_connection()

        result = await self.users.update_one(
            {"user_id": user_id},
            {"$set": {"is_active": False, "deleted_at": datetime.now()}}
        )

        return result.modified_count > 0
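
Registration with this repository is a check-then-insert, so concurrent signups with the same email can race past the find_one check; a unique index on email closes that gap at the database level. A hedged usage sketch (the endpoint shape and error mapping are illustrative, not copied from this project's auth router):

    from fastapi import HTTPException
    from app.db.repositories.user_repository import UserRepository

    async def register(email: str, password: str, full_name: str) -> dict:
        repo = UserRepository()
        try:
            user_id = await repo.create_user(email, password, full_name)
        except ValueError:
            # Raised by create_user when the email already exists
            raise HTTPException(status_code=400, detail="Email already registered")
        return {"user_id": user_id}

    # One-time hardening (optional, not in this commit):
    #     await repo.users.create_index("email", unique=True)
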
app/main.py
ADDED
@@ -0,0 +1,301 @@
"""
FastAPI Main Application Entry Point

Banking RAG Chatbot API with JWT Authentication

This file:
1. Creates the FastAPI app
2. Configures CORS middleware
3. Connects to MongoDB on startup/shutdown
4. Includes API routers (auth + chat)
5. Provides health check endpoints
"""

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from contextlib import asynccontextmanager

from app.config import settings
from app.db.mongodb import connect_to_mongo, close_mongo_connection

# ============================================================================
# LIFESPAN MANAGER (Startup & Shutdown)
# ============================================================================

@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Manage application lifespan events.

    Startup:
    - Connect to MongoDB Atlas
    - ML models load lazily on first use

    Shutdown:
    - Close MongoDB connection
    - Cleanup resources
    """
    # ========================================================================
    # STARTUP
    # ========================================================================
    print("\n" + "=" * 80)
    print("🚀 STARTING BANKING RAG CHATBOT API")
    print("=" * 80)
    print(f"Environment: {settings.ENVIRONMENT}")
    print(f"Debug Mode: {settings.DEBUG}")
    print("=" * 80)

    # Connect to MongoDB
    await connect_to_mongo()

    print("\n💡 ML Models Info:")
    print("   Policy Network: Loads on first chat request (lazy loading)")
    print("   Retriever Model: Loads on first retrieval (lazy loading)")
    print("   LLM: Groq (ChatGroq) with HuggingFace fallback")
    print("\n🤖 LLM Configuration:")
    print(f"   Chat Model: {settings.GROQ_CHAT_MODEL} (Llama 3 8B)")
    print(f"   Eval Model: {settings.GROQ_EVAL_MODEL} (Llama 3 70B)")
    print(f"   Groq API Keys: {len(settings.get_groq_api_keys())} configured")
    print(f"   HuggingFace Tokens: {len(settings.get_hf_tokens())} configured")
    print("   Fallback: Groq → HuggingFace")

    print("\n✅ Backend startup complete!")
    print("=" * 80)
    print("📖 API Docs: http://localhost:8000/docs")
    print("🏥 Health Check: http://localhost:8000/health")
    print("🔐 Register: POST http://localhost:8000/api/v1/auth/register")
    print("🔑 Login: POST http://localhost:8000/api/v1/auth/login")
    print("=" * 80 + "\n")

    yield  # Application runs here

    # ========================================================================
    # SHUTDOWN
    # ========================================================================
    print("\n" + "=" * 80)
    print("🛑 SHUTTING DOWN API")
    print("=" * 80)

    # Close MongoDB connection
    await close_mongo_connection()

    print("✅ Shutdown complete")
    print("=" * 80 + "\n")

# ============================================================================
# CREATE FASTAPI APPLICATION
# ============================================================================

app = FastAPI(
    title="Banking RAG Chatbot API",
    description="""
    🤖 AI-powered Banking Assistant with:

    **Features:**
    - 🔐 JWT Authentication (Sign up, Login, Protected routes)
    - 💬 RAG (Retrieval-Augmented Generation)
    - 🧠 RL-based Policy Network (BERT)
    - 🔍 Custom E5 Retriever
    - ⚡ Groq LLM with HuggingFace Fallback (Llama 3 models)

    **Capabilities:**
    - Intelligent document retrieval
    - Context-aware responses
    - Conversation history
    - Real-time chat
    - User authentication & authorization
    - Multi-provider LLM with automatic fallback
    """,
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
    lifespan=lifespan
)

# ============================================================================
# CORS MIDDLEWARE
# ============================================================================

allowed_origins = settings.get_allowed_origins()
print("\n🌐 CORS Configuration:")
print(f"   Allowed Origins: {allowed_origins}")

app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ============================================================================
# INCLUDE API ROUTERS
# ============================================================================

from app.api.v1 import chat, auth

# Auth router (public endpoints - register, login)
app.include_router(
    auth.router,
    prefix="/api/v1/auth",
    tags=["🔐 Authentication"]
)

# Chat router (protected endpoints - requires JWT token)
app.include_router(
    chat.router,
    prefix="/api/v1/chat",
    tags=["💬 Chat"]
)

# ============================================================================
# ROOT ENDPOINTS
# ============================================================================

@app.get("/", tags=["📍 Root"])
async def root():
    """
    Root endpoint - API information and available endpoints
    """
    return {
        "message": "Banking RAG Chatbot API with Authentication",
        "version": "1.0.0",
        "status": "online",
        "authentication": "JWT Bearer Token Required for chat endpoints",
        "llm_provider": "Groq (ChatGroq) with HuggingFace fallback",
        "models": {
            "chat": settings.GROQ_CHAT_MODEL,
            "evaluation": settings.GROQ_EVAL_MODEL
        },
        "documentation": {
            "swagger_ui": "/docs",
            "redoc": "/redoc"
        },
        "endpoints": {
            "auth": {
                "register": "POST /api/v1/auth/register",
                "login": "POST /api/v1/auth/login",
                "me": "GET /api/v1/auth/me (requires token)",
                "logout": "POST /api/v1/auth/logout (requires token)"
            },
            "chat": {
                "send_message": "POST /api/v1/chat/ (requires token)",
                "get_history": "GET /api/v1/chat/history/{conversation_id} (requires token)",
                "list_conversations": "GET /api/v1/chat/conversations (requires token)",
                "delete_conversation": "DELETE /api/v1/chat/conversation/{conversation_id} (requires token)"
            },
            "health": "GET /health"
        }
    }

@app.get("/health", tags=["🏥 Health"])
async def health_check():
    """
    Comprehensive health check endpoint

    Checks status of:
    - API service
    - MongoDB connection
    - ML models (lazy loaded)
    - Authentication system
    - LLM providers (Groq & HuggingFace)

    Returns:
        dict: Health status of all components
    """
    from app.db.mongodb import get_database

    # Check MongoDB
    mongodb_status = "connected" if get_database() is not None else "disconnected"

    # Check ML models (don't load them, just check readiness)
    ml_models_status = {
        "policy_network": "ready (lazy load)",
        "retriever": "ready (lazy load)",
        "llm": "ready (API-based)"
    }

    # Check LLM providers
    llm_providers = {
        "groq": {
            "enabled": settings.is_groq_enabled(),
            "api_keys_configured": len(settings.get_groq_api_keys()),
            "chat_model": settings.GROQ_CHAT_MODEL,
            "eval_model": settings.GROQ_EVAL_MODEL
        },
        "huggingface": {
            "enabled": settings.is_hf_enabled(),
            "tokens_configured": len(settings.get_hf_tokens()),
            "chat_model": settings.HF_CHAT_MODEL,
            "eval_model": settings.HF_EVAL_MODEL
        }
    }

    # Check authentication
    auth_status = {
        "jwt_enabled": bool(settings.SECRET_KEY and settings.SECRET_KEY != "your-secret-key-change-in-production"),
        "algorithm": settings.ALGORITHM,
        "token_expiry_minutes": settings.ACCESS_TOKEN_EXPIRE_MINUTES
    }

    # Overall health
    is_healthy = (
        mongodb_status == "connected" and
        auth_status["jwt_enabled"] and
        (llm_providers["groq"]["enabled"] or llm_providers["huggingface"]["enabled"])
    )

    return {
        "status": "healthy" if is_healthy else "degraded",
        "api": "online",
        "mongodb": mongodb_status,
        "authentication": auth_status,
        "llm_providers": llm_providers,
        "ml_models": ml_models_status,
        "environment": settings.ENVIRONMENT,
        "debug_mode": settings.DEBUG
    }

# ============================================================================
# GLOBAL EXCEPTION HANDLER
# ============================================================================

@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """
    Global exception handler for unhandled errors
    """
    print("\n❌ Unhandled Exception:")
    print(f"   Path: {request.url.path}")
    print(f"   Error: {str(exc)}")

    if settings.DEBUG:
        import traceback
        traceback.print_exc()

    return JSONResponse(
        status_code=500,
        content={
            "error": "Internal Server Error",
            "detail": str(exc) if settings.DEBUG else "An unexpected error occurred",
            "path": str(request.url.path)
        }
    )

# ============================================================================
# MAIN ENTRY POINT (for direct execution)
# ============================================================================

if __name__ == "__main__":
    import uvicorn

    print("\n🚀 Starting server directly...")
    print("   Note: For production, use: uvicorn app.main:app --host 0.0.0.0 --port 8000")

    uvicorn.run(
        "app.main:app",
        host="0.0.0.0",
        port=8000,
        reload=settings.DEBUG  # Auto-reload only in debug mode
    )
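
Once the server is running, the /health endpoint doubles as a smoke test for the whole stack (MongoDB, JWT config, LLM keys). A minimal client-side check against a local instance, illustrative only:

    import httpx

    resp = httpx.get("http://localhost:8000/health", timeout=10.0)
    body = resp.json()
    assert resp.status_code == 200
    print(body["status"])   # "healthy" or "degraded"
    print(body["mongodb"])  # "connected" or "disconnected"
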
app/ml/__init__.py
ADDED
@@ -0,0 +1,6 @@
"""
Machine Learning package - ML models and inference
Contains retriever, policy network, and model loading utilities
"""

__version__ = "1.0.0"
app/ml/policy_network.py
ADDED
@@ -0,0 +1,610 @@
"""
BERT-based Policy Network for FETCH/NO_FETCH decisions
Trained with Reinforcement Learning (Policy Gradient + Entropy Regularization)

Adapted from the original RL.py training script, with:
- PolicyNetwork class (BERT-based)
- State encoding from conversation history
- Action prediction (FETCH vs NO_FETCH)
- Module-level caching (load once on startup)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import List, Dict, Optional, Tuple
from transformers import AutoTokenizer, AutoModel

from app.config import settings


# ============================================================================
# POLICY NETWORK (From RL.py)
# ============================================================================

class PolicyNetwork(nn.Module):
    """
    BERT-based Policy Network for deciding FETCH vs NO_FETCH actions.

    Architecture:
    - Base: BERT-base-uncased (pre-trained)
    - Input: Current query + conversation history + previous actions
    - Output: 2-class softmax (FETCH=0, NO_FETCH=1)
    - Special tokens: [FETCH], [NO_FETCH] for action encoding

    Training Details:
    - Loss: Policy Gradient + Entropy Regularization
    - Optimizer: AdamW
    - Reward structure:
      * FETCH: +0.5 (always)
      * NO_FETCH + Good: +2.0
      * NO_FETCH + Bad: -0.5
    """

    def __init__(self, model_name: str = "bert-base-uncased", dropout_rate: float = 0.1):
        super(PolicyNetwork, self).__init__()

        # Load pre-trained BERT
        self.bert = AutoModel.from_pretrained(model_name)

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Add special tokens for actions: [FETCH] and [NO_FETCH]
        special_tokens = {"additional_special_tokens": ["[FETCH]", "[NO_FETCH]"]}
        self.tokenizer.add_special_tokens(special_tokens)

        # Resize BERT embeddings to accommodate new tokens
        self.bert.resize_token_embeddings(len(self.tokenizer))

        # Initialize random embeddings for special tokens
        self._init_action_embeddings()

        # Classification head: BERT hidden size (768) → 2 classes
        self.classifier = nn.Linear(self.bert.config.hidden_size, 2)

        # Dropout for regularization
        self.dropout = nn.Dropout(dropout_rate)

    def _init_action_embeddings(self):
        """
        Initialize random embeddings for [FETCH] and [NO_FETCH] tokens.
        These are learned during training.
        """
        with torch.no_grad():
            # Get token IDs for special tokens
            fetch_id = self.tokenizer.convert_tokens_to_ids("[FETCH]")
            no_fetch_id = self.tokenizer.convert_tokens_to_ids("[NO_FETCH]")

            # Get embedding dimension
            embedding_dim = self.bert.config.hidden_size

            # Initialize with small random values (same as BERT initialization)
            self.bert.embeddings.word_embeddings.weight[fetch_id] = torch.randn(embedding_dim) * 0.02
            self.bert.embeddings.word_embeddings.weight[no_fetch_id] = torch.randn(embedding_dim) * 0.02

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through BERT + classifier.

        Args:
            input_ids: Tokenized input IDs (shape: [batch_size, seq_len])
            attention_mask: Attention mask (shape: [batch_size, seq_len])

        Returns:
            logits: Raw logits (shape: [batch_size, 2])
            probs: Softmax probabilities (shape: [batch_size, 2])
        """
        # Pass through BERT
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)

        # Extract [CLS] token representation (first token)
        cls_output = outputs.last_hidden_state[:, 0, :]

        # Apply dropout
        cls_output = self.dropout(cls_output)

        # Classification
        logits = self.classifier(cls_output)

        # Softmax for probabilities
        probs = F.softmax(logits, dim=-1)

        return logits, probs

    def encode_state(
        self,
        state: Dict,
        max_length: Optional[int] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Encode conversation state into BERT input format.

        State structure:
        {
            'previous_queries': [query1, query2, ...],
            'previous_actions': ['FETCH', 'NO_FETCH', ...],
            'current_query': 'user query'
        }

        Encoding format:
        "Previous query 1: <text> [Action: [FETCH]] Previous query 2: <text> [Action: [NO_FETCH]] Current query: <text>"

        Args:
            state: State dictionary
            max_length: Maximum sequence length (default from config)

        Returns:
            dict: Tokenized inputs (input_ids, attention_mask)
        """
        if max_length is None:
            max_length = settings.POLICY_MAX_LEN

        # Build state text from conversation history
        state_text = ""

        # Add previous queries and their actions
        prev_queries = state.get('previous_queries', [])
        prev_actions = state.get('previous_actions', [])

        if prev_queries and prev_actions:
            for i, (prev_query, prev_action) in enumerate(zip(prev_queries, prev_actions)):
                state_text += f"Previous query {i+1}: {prev_query} [Action: [{prev_action}]] "

        # Add current query
        current_query = state.get('current_query', '')
        state_text += f"Current query: {current_query}"

        # Tokenize
        encoding = self.tokenizer(
            state_text,
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors='pt'
        )

        return encoding
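
    # Illustrative trace of encode_state (example values are made up): for
    #     {'previous_queries': ['How do I block my card?'],
    #      'previous_actions': ['FETCH'],
    #      'current_query': 'And how do I order a new one?'}
    # the text handed to the tokenizer is
    #     "Previous query 1: How do I block my card? [Action: [FETCH]] Current query: And how do I order a new one?"
    # where [FETCH] tokenizes to a single special token, and the whole string is
    # truncated/padded to POLICY_MAX_LEN.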
| 169 |
+
|
| 170 |
+
def predict_action(
|
| 171 |
+
self,
|
| 172 |
+
state: Dict,
|
| 173 |
+
use_dropout: bool = False,
|
| 174 |
+
num_samples: int = 10
|
| 175 |
+
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
|
| 176 |
+
"""
|
| 177 |
+
Predict action probabilities for a given state.
|
| 178 |
+
|
| 179 |
+
Args:
|
| 180 |
+
state: Conversation state dictionary
|
| 181 |
+
use_dropout: Whether to use MC Dropout for uncertainty estimation
|
| 182 |
+
num_samples: Number of MC Dropout samples (if use_dropout=True)
|
| 183 |
+
|
| 184 |
+
Returns:
|
| 185 |
+
probs: Action probabilities (shape: [1, 2]) - [P(FETCH), P(NO_FETCH)]
|
| 186 |
+
uncertainty: Standard deviation across samples (if use_dropout=True)
|
| 187 |
+
"""
|
| 188 |
+
device = next(self.parameters()).device
|
| 189 |
+
|
| 190 |
+
if use_dropout:
|
| 191 |
+
# MC Dropout for uncertainty estimation
|
| 192 |
+
self.train() # Enable dropout during inference
|
| 193 |
+
all_probs = []
|
| 194 |
+
|
| 195 |
+
for _ in range(num_samples):
|
| 196 |
+
with torch.no_grad():
|
| 197 |
+
encoding = self.encode_state(state)
|
| 198 |
+
input_ids = encoding['input_ids'].to(device)
|
| 199 |
+
attention_mask = encoding['attention_mask'].to(device)
|
| 200 |
+
|
| 201 |
+
_, probs = self.forward(input_ids, attention_mask)
|
| 202 |
+
all_probs.append(probs.cpu().numpy())
|
| 203 |
+
|
| 204 |
+
# Average probabilities across samples
|
| 205 |
+
avg_probs = np.mean(all_probs, axis=0)
|
| 206 |
+
|
| 207 |
+
# Calculate uncertainty (standard deviation)
|
| 208 |
+
uncertainty = np.std(all_probs, axis=0)
|
| 209 |
+
|
| 210 |
+
return avg_probs, uncertainty
|
| 211 |
+
|
| 212 |
+
else:
|
| 213 |
+
# Standard inference (no uncertainty estimation)
|
| 214 |
+
self.eval()
|
| 215 |
+
|
| 216 |
+
with torch.no_grad():
|
| 217 |
+
encoding = self.encode_state(state)
|
| 218 |
+
input_ids = encoding['input_ids'].to(device)
|
| 219 |
+
attention_mask = encoding['attention_mask'].to(device)
|
| 220 |
+
|
| 221 |
+
_, probs = self.forward(input_ids, attention_mask)
|
| 222 |
+
|
| 223 |
+
return probs.cpu().numpy(), None
|
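
# A minimal usage sketch (not part of the original file) showing how
# predict_action's MC Dropout mode yields an uncertainty estimate; the
# state dict follows the format documented in encode_state:
#
#   model = load_policy_model()
#   state = {
#       'previous_queries': ['What is my balance?'],
#       'previous_actions': ['FETCH'],
#       'current_query': 'And how do I close my account?'
#   }
#   probs, uncertainty = model.predict_action(state, use_dropout=True, num_samples=10)
#   # probs[0] = [P(FETCH), P(NO_FETCH)] averaged over 10 stochastic passes;
#   # uncertainty[0] is the per-class standard deviation across those passes.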

# ============================================================================
# MODULE-LEVEL CACHING (Load once on import)
# ============================================================================

# Global variables for caching
POLICY_MODEL: Optional[PolicyNetwork] = None
POLICY_TOKENIZER: Optional[AutoTokenizer] = None


# =============================================================================================
# Latest version (suggested by Perplexity); should work — if not, use one of the versions below.
# =============================================================================================

def load_policy_model() -> PolicyNetwork:
    """
    Load trained policy model (called once on startup).
    Downloads from HuggingFace Hub if not present locally.
    Uses module-level caching - model stays in RAM.

    Returns:
        PolicyNetwork: Loaded policy model
    """
    global POLICY_MODEL, POLICY_TOKENIZER

    if POLICY_MODEL is None:
        # Download model from HF Hub if needed (for deployment)
        settings.download_model_if_needed(
            hf_filename="models/best_policy_model.pth",
            local_path=settings.POLICY_MODEL_PATH
        )

        print(f"Loading policy network from {settings.POLICY_MODEL_PATH}...")

        try:
            # Load checkpoint first to get vocab size
            checkpoint = torch.load(settings.POLICY_MODEL_PATH, map_location=settings.DEVICE)

            # Support both raw state dicts and full checkpoints with metadata
            # (unwrapping here also avoids a KeyError below when the weights
            # are nested under 'model_state_dict')
            state_dict = (
                checkpoint['model_state_dict']
                if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint
                else checkpoint
            )

            # Create model instance
            POLICY_MODEL = PolicyNetwork(
                model_name="bert-base-uncased",
                dropout_rate=0.1
            )

            # **KEY FIX**: Resize model embeddings to match saved checkpoint BEFORE loading weights
            saved_vocab_size = state_dict['bert.embeddings.word_embeddings.weight'].shape[0]
            current_vocab_size = len(POLICY_MODEL.tokenizer)

            if saved_vocab_size != current_vocab_size:
                print(f"⚠️ Vocab size mismatch: saved={saved_vocab_size}, current={current_vocab_size}")
                print(f"✅ Resizing tokenizer and embeddings to match saved model...")
                # Resize model to match saved checkpoint
                POLICY_MODEL.bert.resize_token_embeddings(saved_vocab_size)

            # Move to device
            POLICY_MODEL = POLICY_MODEL.to(settings.DEVICE)

            # Now load trained weights (sizes will match!)
            POLICY_MODEL.load_state_dict(state_dict)

            # Set to evaluation mode
            POLICY_MODEL.eval()

            # Cache tokenizer
            POLICY_TOKENIZER = POLICY_MODEL.tokenizer

            print("✅ Policy network loaded and cached")

        except FileNotFoundError:
            print(f"❌ Policy model file not found: {settings.POLICY_MODEL_PATH}")
            print(f"⚠️ Make sure models are uploaded to HuggingFace Hub: {settings.HF_MODEL_REPO}")
            raise
        except Exception as e:
            print(f"❌ Failed to load policy model: {e}")
            raise

    return POLICY_MODEL
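
# A hypothetical startup hook (illustrative; the hook name and app object are
# assumptions, not from this file) showing where load_policy_model() would
# typically be called so the checkpoint is downloaded and cached once before
# the first request:
#
#   @app.on_event("startup")
#   async def warm_up_models():
#       load_policy_model()  # BERT policy network stays cached in RAM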


# ===========================================================================
# This version is used in the code, at least for localhost testing
# ===========================================================================

# def load_policy_model() -> PolicyNetwork:
#     """
#     Load trained policy model (called once on startup).
#     Uses module-level caching - model stays in RAM.
#
#     Returns:
#         PolicyNetwork: Loaded policy model
#     """
#     global POLICY_MODEL, POLICY_TOKENIZER
#
#     if POLICY_MODEL is None:
#         print(f"Loading policy network from {settings.POLICY_MODEL_PATH}...")
#
#         try:
#             # Load checkpoint first to get vocab size
#             checkpoint = torch.load(settings.POLICY_MODEL_PATH, map_location=settings.DEVICE)
#
#             # Create model instance
#             POLICY_MODEL = PolicyNetwork(
#                 model_name="bert-base-uncased",
#                 dropout_rate=0.1
#             )
#
#             # **KEY FIX**: Resize model embeddings to match saved checkpoint BEFORE loading weights
#             saved_vocab_size = checkpoint['bert.embeddings.word_embeddings.weight'].shape[0]
#             current_vocab_size = len(POLICY_MODEL.tokenizer)
#
#             if saved_vocab_size != current_vocab_size:
#                 print(f"⚠️ Vocab size mismatch: saved={saved_vocab_size}, current={current_vocab_size}")
#                 print(f"✅ Resizing tokenizer and embeddings to match saved model...")
#
#                 # Resize model to match saved checkpoint
#                 POLICY_MODEL.bert.resize_token_embeddings(saved_vocab_size)
#
#             # Move to device
#             POLICY_MODEL = POLICY_MODEL.to(settings.DEVICE)
#
#             # Now load trained weights (sizes will match!)
#             if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
#                 POLICY_MODEL.load_state_dict(checkpoint['model_state_dict'])
#             else:
#                 POLICY_MODEL.load_state_dict(checkpoint)
#
#             # Set to evaluation mode
#             POLICY_MODEL.eval()
#
#             # Cache tokenizer
#             POLICY_TOKENIZER = POLICY_MODEL.tokenizer
#
#             print("✅ Policy network loaded and cached")
#
#         except FileNotFoundError:
#             print(f"❌ Policy model file not found: {settings.POLICY_MODEL_PATH}")
#             print("⚠️ You need to train the policy network first!")
#             raise
#
#         except Exception as e:
#             print(f"❌ Failed to load policy model: {e}")
#             raise
#
#     return POLICY_MODEL


# =====================================================================================
# This is the older version (or probably a different version); potentially still useful
# =====================================================================================

# def load_policy_model() -> PolicyNetwork:
#     """
#     Load trained policy model (called once on startup).
#     Uses module-level caching - model stays in RAM.
#
#     Returns:
#         PolicyNetwork: Loaded policy model
#     """
#     global POLICY_MODEL, POLICY_TOKENIZER
#
#     if POLICY_MODEL is None:
#         print(f"Loading policy network from {settings.POLICY_MODEL_PATH}...")
#
#         try:
#             # Create model instance
#             POLICY_MODEL = PolicyNetwork(
#                 model_name="bert-base-uncased",
#                 dropout_rate=0.1
#             ).to(settings.DEVICE)
#
#             # Load trained weights
#             checkpoint = torch.load(settings.POLICY_MODEL_PATH, map_location=settings.DEVICE)
#
#             # Handle different checkpoint formats
#             if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
#                 # Full checkpoint with metadata
#                 POLICY_MODEL.load_state_dict(checkpoint['model_state_dict'])
#             else:
#                 # Just state dict
#                 POLICY_MODEL.load_state_dict(checkpoint)
#
#             # Set to evaluation mode
#             POLICY_MODEL.eval()
#
#             # Cache tokenizer
#             POLICY_TOKENIZER = POLICY_MODEL.tokenizer
#
#             print("✅ Policy network loaded and cached")
#
#         except FileNotFoundError:
#             print(f"❌ Policy model file not found: {settings.POLICY_MODEL_PATH}")
#             print("⚠️ You need to train the policy network first!")
#             raise
#
#         except Exception as e:
#             print(f"❌ Failed to load policy model: {e}")
#             raise
#
#     return POLICY_MODEL

# ============================================================================
# PREDICTION FUNCTIONS
# ============================================================================

def create_state_from_history(
    current_query: str,
    conversation_history: List[Dict],
    max_history: int = 2
) -> Dict:
    """
    Create state dictionary from conversation history.
    Extracts last N query-action pairs.

    Args:
        current_query: Current user query
        conversation_history: List of conversation turns
            Each turn: {'role': 'user'/'assistant', 'content': '...', 'metadata': {...}}
        max_history: Maximum number of previous turns to include (default: 2)

    Returns:
        dict: State dictionary for policy network
    """
    state = {
        'current_query': current_query,
        'previous_queries': [],
        'previous_actions': []
    }

    if not conversation_history:
        return state

    # Extract last N conversation turns (user + assistant pairs)
    relevant_history = conversation_history[-(max_history * 2):]

    for i, turn in enumerate(relevant_history):
        # User turns
        if turn.get('role') == 'user':
            query = turn.get('content', '')
            state['previous_queries'].append(query)

            # Look for corresponding assistant turn
            if i + 1 < len(relevant_history):
                bot_turn = relevant_history[i + 1]
                if bot_turn.get('role') == 'assistant':
                    metadata = bot_turn.get('metadata', {})
                    action = metadata.get('policy_action', 'FETCH')
                    state['previous_actions'].append(action)

    return state
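
# A worked example (values illustrative) of the extraction above with
# max_history=2, i.e. the last two user/assistant pairs:
#
#   history = [
#       {'role': 'user', 'content': 'What is my balance?'},
#       {'role': 'assistant', 'content': '$1000', 'metadata': {'policy_action': 'FETCH'}},
#       {'role': 'user', 'content': 'Thanks!'},
#       {'role': 'assistant', 'content': 'Welcome!', 'metadata': {'policy_action': 'NO_FETCH'}},
#   ]
#   create_state_from_history('Can I get a loan?', history)
#   # -> {'current_query': 'Can I get a loan?',
#   #     'previous_queries': ['What is my balance?', 'Thanks!'],
#   #     'previous_actions': ['FETCH', 'NO_FETCH']}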


def predict_policy_action(
    query: str,
    history: List[Dict] = None,
    return_probs: bool = False
) -> Dict:
    """
    Predict FETCH/NO_FETCH action for a query.

    Args:
        query: User query text
        history: Conversation history (optional)
        return_probs: Whether to return full probability distribution

    Returns:
        dict: Prediction results
        {
            'action': 'FETCH' or 'NO_FETCH',
            'confidence': float (0-1),
            'fetch_prob': float,
            'no_fetch_prob': float,
            'should_retrieve': bool
        }
    """
    # Load model (cached after first call)
    model = load_policy_model()

    # Create state from history
    if history is None:
        history = []

    state = create_state_from_history(query, history)

    # Predict action
    probs, _ = model.predict_action(state, use_dropout=False)

    # Extract probabilities
    fetch_prob = float(probs[0][0])
    no_fetch_prob = float(probs[0][1])

    # Determine action (argmax)
    action_idx = np.argmax(probs[0])
    action = "FETCH" if action_idx == 0 else "NO_FETCH"
    confidence = float(probs[0][action_idx])

    # Check confidence threshold
    should_retrieve = (action == "FETCH") or (action == "NO_FETCH" and confidence < settings.CONFIDENCE_THRESHOLD)

    result = {
        'action': action,
        'confidence': confidence,
        'should_retrieve': should_retrieve,
        'policy_decision': action
    }

    if return_probs:
        result['fetch_prob'] = fetch_prob
        result['no_fetch_prob'] = no_fetch_prob

    return result
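
# Worked example of the threshold gate above (threshold value hypothetical):
# with settings.CONFIDENCE_THRESHOLD = 0.7, a NO_FETCH prediction at
# confidence 0.55 still sets should_retrieve=True (the policy is unsure, so
# we retrieve anyway), while NO_FETCH at 0.95 skips retrieval entirely; any
# FETCH prediction retrieves regardless of confidence.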


# ============================================================================
# USAGE EXAMPLE (for reference)
# ============================================================================
"""
# In your service file:

from app.ml.policy_network import predict_policy_action

# Predict action
history = [
    {'role': 'user', 'content': 'What is my balance?'},
    {'role': 'assistant', 'content': '$1000', 'metadata': {'policy_action': 'FETCH'}}
]

result = predict_policy_action(
    query="Thank you!",
    history=history,
    return_probs=True
)

print(result)
# {
#     'action': 'NO_FETCH',
#     'confidence': 0.95,
#     'should_retrieve': False,
#     'fetch_prob': 0.05,
#     'no_fetch_prob': 0.95
# }
"""
app/ml/retriever.py
ADDED
@@ -0,0 +1,522 @@
"""
Custom Retriever with E5-Base-V2 and FAISS
Trained with InfoNCE + Triplet Loss for banking domain

This is adapted from your RAG.py with:
- CustomSentenceTransformer (e5-base-v2)
- Mean pooling + L2 normalization
- FAISS vector search
- Module-level caching (load once on startup)
"""

import os
import json
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
import faiss
import numpy as np
from typing import List, Dict, Optional
from transformers import AutoTokenizer, AutoModel

from app.config import settings


# ============================================================================
# CUSTOM SENTENCE TRANSFORMER (From RAG.py)
# ============================================================================

class CustomSentenceTransformer(nn.Module):
    """
    Custom SentenceTransformer matching your training code.
    Uses e5-base-v2 with mean pooling and L2 normalization.

    Training Details:
    - Base model: intfloat/e5-base-v2
    - Loss: InfoNCE + Triplet Loss
    - Pooling: Mean pooling on last hidden state
    - Normalization: L2 normalization
    """

    def __init__(self, model_name: str = "intfloat/e5-base-v2"):
        super().__init__()
        # Load pre-trained e5-base-v2 encoder
        self.encoder = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.config = self.encoder.config

    def forward(self, input_ids, attention_mask):
        """
        Forward pass through BERT encoder.

        Args:
            input_ids: Tokenized input IDs
            attention_mask: Attention mask for padding

        Returns:
            torch.Tensor: L2-normalized embeddings (shape: [batch_size, 768])
        """
        # Get BERT outputs
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

        # Mean pooling - same as training
        # Take hidden states from last layer
        token_embeddings = outputs.last_hidden_state

        # Expand attention mask to match token embeddings shape
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()

        # Sum embeddings (weighted by attention mask) and divide by sum of mask
        embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

        # L2 normalize embeddings - same as training
        embeddings = F.normalize(embeddings, p=2, dim=1)

        return embeddings

    def encode(
        self,
        sentences: List[str],
        batch_size: int = 32,
        convert_to_numpy: bool = True,
        show_progress_bar: bool = False
    ) -> np.ndarray:
        """
        Encode sentences using the same method as training.
        Adds 'query: ' prefix for e5-base-v2 compatibility.

        Args:
            sentences: List of sentences to encode
            batch_size: Batch size for encoding
            convert_to_numpy: Whether to convert to numpy array
            show_progress_bar: Whether to show progress bar

        Returns:
            np.ndarray: Encoded embeddings (shape: [num_sentences, 768])
        """
        self.eval()  # Set model to evaluation mode

        # Handle single string input
        if isinstance(sentences, str):
            sentences = [sentences]

        # Add 'query: ' prefix for e5-base-v2 (required by model)
        # Handle None values and empty strings
        processed_sentences = []
        for sentence in sentences:
            if sentence is None:
                processed_sentences.append("query: ")  # Default empty query
            elif isinstance(sentence, str):
                processed_sentences.append(f"query: {sentence.strip()}")
            else:
                processed_sentences.append(f"query: {str(sentence)}")

        all_embeddings = []

        # Encode in batches
        with torch.no_grad():  # No gradient computation
            for i in range(0, len(processed_sentences), batch_size):
                batch_sentences = processed_sentences[i:i + batch_size]

                # Tokenize batch
                tokens = self.tokenizer(
                    batch_sentences,
                    truncation=True,
                    padding=True,
                    max_length=128,  # Same as training
                    return_tensors='pt'
                ).to(next(self.parameters()).device)

                # Get embeddings
                embeddings = self.forward(tokens['input_ids'], tokens['attention_mask'])

                # Convert to numpy if requested
                if convert_to_numpy:
                    embeddings = embeddings.cpu().numpy()

                all_embeddings.append(embeddings)

        # Combine all batches
        if convert_to_numpy:
            all_embeddings = np.vstack(all_embeddings)
        else:
            all_embeddings = torch.cat(all_embeddings, dim=0)

        return all_embeddings
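
# A quick sanity-check sketch (not part of the original file): because
# forward() L2-normalizes, every embedding has unit norm, so a plain dot
# product between two embeddings equals their cosine similarity:
#
#   model = CustomSentenceTransformer()
#   embs = model.encode(["how do I reset my PIN?", "PIN reset procedure"])
#   assert np.allclose(np.linalg.norm(embs, axis=1), 1.0, atol=1e-5)
#   cosine = float(embs[0] @ embs[1])  # inner product == cosine similarity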


# ============================================================================
# CUSTOM RETRIEVER MODEL (Wrapper)
# ============================================================================

class CustomRetrieverModel:
    """
    Wrapper for your custom trained retriever model.
    Handles both knowledge base documents and query encoding.
    """

    def __init__(self, model_path: str, device: str = "cpu"):
        """
        Initialize retriever model.

        Args:
            model_path: Path to trained model weights (.pth file)
            device: Device to load model on ('cpu' or 'cuda')
        """
        self.device = device

        # Create model instance
        self.model = CustomSentenceTransformer("intfloat/e5-base-v2").to(device)

        # Load your trained weights
        try:
            state_dict = torch.load(model_path, map_location=device)
            self.model.load_state_dict(state_dict)
            print(f"✅ Custom retriever model loaded from {model_path}")
        except Exception as e:
            print(f"❌ Failed to load custom model: {e}")
            print("🔄 Using base e5-base-v2 model (not trained)...")

        # Set to evaluation mode
        self.model.eval()

    def encode_documents(self, documents: List[str], batch_size: int = 32) -> np.ndarray:
        """
        Encode knowledge base documents.
        These are the responses/instructions we're retrieving.

        Args:
            documents: List of document texts
            batch_size: Batch size for encoding

        Returns:
            np.ndarray: Document embeddings (shape: [num_docs, 768])
        """
        return self.model.encode(documents, batch_size=batch_size, convert_to_numpy=True)

    def encode_query(self, query: str) -> np.ndarray:
        """
        Encode user query for retrieval.

        Args:
            query: User query text

        Returns:
            np.ndarray: Query embedding (shape: [1, 768])
        """
        return self.model.encode([query], convert_to_numpy=True)
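
# A minimal sketch (paths illustrative) of scoring a query against a few
# documents with this wrapper; since embeddings are unit-normalized, the
# matrix product gives cosine similarities directly:
#
#   retriever = CustomRetrieverModel("models/best_retriever_model.pth")
#   doc_embs = retriever.encode_documents(["How to block a card", "Loan eligibility"])
#   q_emb = retriever.encode_query("my card was stolen")
#   scores = (q_emb @ doc_embs.T)[0]  # cosine similarity per document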

# ============================================================================
# MODULE-LEVEL CACHING (Load once on import)
# ============================================================================

# Global variables for caching
RETRIEVER_MODEL: Optional[CustomRetrieverModel] = None
FAISS_INDEX: Optional[faiss.Index] = None
KB_DATA: Optional[List[Dict]] = None


# =============================================================================================
# Latest version (suggested by Perplexity); should work — if not, use one of the versions below.
# =============================================================================================

def load_retriever() -> CustomRetrieverModel:
    """
    Load custom retriever model (called once on startup).
    Downloads from HuggingFace Hub if not present locally.
    Uses module-level caching - model stays in RAM.

    Returns:
        CustomRetrieverModel: Loaded retriever model
    """
    global RETRIEVER_MODEL

    if RETRIEVER_MODEL is None:
        # Download model from HF Hub if needed (for deployment)
        settings.download_model_if_needed(
            hf_filename="models/best_retriever_model.pth",
            local_path=settings.RETRIEVER_MODEL_PATH
        )

        print(f"Loading custom retriever from {settings.RETRIEVER_MODEL_PATH}...")

        RETRIEVER_MODEL = CustomRetrieverModel(
            model_path=settings.RETRIEVER_MODEL_PATH,
            device=settings.DEVICE
        )

        print("✅ Retriever model loaded and cached")

    return RETRIEVER_MODEL


# ===========================================================================
# This version is used in the code, at least for localhost testing
# ===========================================================================

# def load_retriever() -> CustomRetrieverModel:
#     """
#     Load custom retriever model (called once on startup).
#     Uses module-level caching - model stays in RAM.
#
#     Returns:
#         CustomRetrieverModel: Loaded retriever model
#     """
#     global RETRIEVER_MODEL
#
#     if RETRIEVER_MODEL is None:
#         print(f"Loading custom retriever from {settings.RETRIEVER_MODEL_PATH}...")
#         RETRIEVER_MODEL = CustomRetrieverModel(
#             model_path=settings.RETRIEVER_MODEL_PATH,
#             device=settings.DEVICE
#         )
#         print("✅ Retriever model loaded and cached")
#
#     return RETRIEVER_MODEL


# =============================================================================================
# Latest version (suggested by Perplexity); should work — if not, use one of the versions below.
# =============================================================================================

def load_faiss_index():
    """
    Load FAISS index + knowledge base from pickle file.
    Downloads from HuggingFace Hub if not present locally.
    Uses module-level caching - loaded once on startup.

    Returns:
        tuple: (faiss.Index, List[Dict]) - FAISS index and KB data
    """
    global FAISS_INDEX, KB_DATA

    if FAISS_INDEX is None or KB_DATA is None:
        # Download FAISS index from HF Hub if needed (for deployment)
        settings.download_model_if_needed(
            hf_filename="models/faiss_index.pkl",
            local_path=settings.FAISS_INDEX_PATH
        )

        # Download knowledge base from HF Hub if needed (for deployment)
        settings.download_model_if_needed(
            hf_filename="data/final_knowledge_base.jsonl",
            local_path=settings.KB_PATH
        )

        print(f"Loading FAISS index from {settings.FAISS_INDEX_PATH}...")

        try:
            # Load pickled FAISS index + KB data
            with open(settings.FAISS_INDEX_PATH, 'rb') as f:
                FAISS_INDEX, KB_DATA = pickle.load(f)

            print(f"✅ FAISS index loaded: {FAISS_INDEX.ntotal} vectors")
            print(f"✅ Knowledge base loaded: {len(KB_DATA)} documents")

        except FileNotFoundError:
            print(f"❌ FAISS index file not found: {settings.FAISS_INDEX_PATH}")
            print(f"⚠️ Make sure models are uploaded to HuggingFace Hub: {settings.HF_MODEL_REPO}")
            raise
        except Exception as e:
            print(f"❌ Failed to load FAISS index: {e}")
            raise

    return FAISS_INDEX, KB_DATA
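
# A minimal sketch of how such a pickle could be produced (presumably what
# build_faiss_index.py does; the IndexFlatIP choice and field names here are
# assumptions, not confirmed by this file):
#
#   kb = [json.loads(line) for line in open(settings.KB_PATH, encoding='utf-8')]
#   embs = load_retriever().encode_documents([d['instruction'] for d in kb])
#   faiss.normalize_L2(embs)                  # unit vectors -> inner product == cosine
#   index = faiss.IndexFlatIP(embs.shape[1])  # exact inner-product search
#   index.add(embs)
#   with open(settings.FAISS_INDEX_PATH, 'wb') as f:
#       pickle.dump((index, kb), f)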


# ===========================================================================
# This version is used in the code, at least for localhost testing
# ===========================================================================

# def load_faiss_index():
#     """
#     Load FAISS index + knowledge base from pickle file.
#     Uses module-level caching - loaded once on startup.
#
#     Returns:
#         tuple: (faiss.Index, List[Dict]) - FAISS index and KB data
#     """
#     global FAISS_INDEX, KB_DATA
#
#     if FAISS_INDEX is None or KB_DATA is None:
#         print(f"Loading FAISS index from {settings.FAISS_INDEX_PATH}...")
#
#         try:
#             # Load pickled FAISS index + KB data
#             with open(settings.FAISS_INDEX_PATH, 'rb') as f:
#                 FAISS_INDEX, KB_DATA = pickle.load(f)
#
#             print(f"✅ FAISS index loaded: {FAISS_INDEX.ntotal} vectors")
#             print(f"✅ Knowledge base loaded: {len(KB_DATA)} documents")
#
#         except FileNotFoundError:
#             print(f"❌ FAISS index file not found: {settings.FAISS_INDEX_PATH}")
#             print("⚠️ You need to create the FAISS index first!")
#             raise
#
#         except Exception as e:
#             print(f"❌ Failed to load FAISS index: {e}")
#             raise
#
#     return FAISS_INDEX, KB_DATA


# ============================================================================
# RETRIEVAL FUNCTIONS
# ============================================================================

def retrieve_documents(
    query: str,
    top_k: int = None,
    min_similarity: float = None
) -> List[Dict]:
    """
    Retrieve top-k documents for a query using custom retriever + FAISS.

    Args:
        query: User query text
        top_k: Number of documents to retrieve (default from config)
        min_similarity: Minimum similarity threshold (default from config)

    Returns:
        List[Dict]: Retrieved documents with scores
            Each dict contains:
            - instruction: FAQ question
            - response: FAQ answer
            - category: Document category
            - intent: Document intent
            - score: Similarity score (0-1)
            - rank: Rank in results (1-indexed)
            - faq_id: Document ID
    """
    # Use config defaults if not provided
    if top_k is None:
        top_k = settings.TOP_K
    if min_similarity is None:
        min_similarity = settings.SIMILARITY_THRESHOLD

    # Validate query
    if not query or query.strip() == "":
        print("⚠️ Empty query provided")
        return []

    # Load models (cached, no overhead after first call)
    retriever = load_retriever()
    index, kb = load_faiss_index()

    try:
        # Step 1: Encode query
        query_embedding = retriever.encode_query(query)

        # Step 2: Normalize for cosine similarity
        faiss.normalize_L2(query_embedding)

        # Step 3: Search in FAISS index
        similarities, indices = index.search(query_embedding, top_k)

        # Step 4: Check similarity threshold for top result
        if similarities[0][0] < min_similarity:
            print(f"🚫 NO_FETCH (similarity: {similarities[0][0]:.3f} < {min_similarity})")
            return []

        print(f"✅ FETCH (similarity: {similarities[0][0]:.3f} >= {min_similarity})")

        # Step 5: Format results
        results = []
        for rank, (similarity, idx) in enumerate(zip(similarities[0], indices[0])):
            if idx < len(kb):
                doc = kb[idx]
                results.append({
                    'instruction': doc.get('instruction', ''),
                    'response': doc.get('response', ''),
                    'category': doc.get('category', 'Unknown'),
                    'intent': doc.get('intent', 'Unknown'),
                    'score': float(similarity),
                    'rank': rank + 1,
                    'faq_id': doc.get('faq_id', f'doc_{idx}')
                })

        return results

    except Exception as e:
        print(f"❌ Retrieval error: {e}")
        import traceback
        traceback.print_exc()
        return []
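
# Worked example of the Step 4 gate (values hypothetical): with
# settings.SIMILARITY_THRESHOLD = 0.45, a top-1 similarity of 0.38 returns []
# (treated as NO_FETCH), while a top-1 similarity of 0.62 returns all top_k
# hits ranked by score. Note the threshold is applied to the best hit only;
# lower-ranked hits below the threshold are still included once it passes.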


def format_context(retrieved_docs: List[Dict], max_context_length: int = None) -> str:
    """
    Format retrieved documents into context string for LLM.
    Prioritizes by score and limits total length.

    Args:
        retrieved_docs: List of retrieved documents
        max_context_length: Maximum context length in characters

    Returns:
        str: Formatted context string
    """
    if max_context_length is None:
        max_context_length = settings.MAX_CONTEXT_LENGTH

    if not retrieved_docs:
        return ""

    context_parts = []
    current_length = 0

    for doc in retrieved_docs:
        # Create context entry with None checks
        instruction = doc.get('instruction', '') or ''
        response = doc.get('response', '') or ''
        category = doc.get('category', 'N/A') or 'N/A'

        context_entry = f"[Rank {doc['rank']}, Score: {doc['score']:.3f}]\n"
        context_entry += f"Q: {instruction}\n"
        context_entry += f"A: {response}\n"
        context_entry += f"Category: {category}\n\n"

        # Check length limit
        if current_length + len(context_entry) > max_context_length:
            break

        context_parts.append(context_entry)
        current_length += len(context_entry)

    return "".join(context_parts)
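
# Example of the resulting context string (content illustrative):
#
#   [Rank 1, Score: 0.812]
#   Q: How do I block a lost debit card?
#   A: Call the 24x7 helpline or block it from the mobile app...
#   Category: CARD
#
# followed by the rank-2 entry, and so on until max_context_length is hit.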


# ============================================================================
# USAGE EXAMPLE (for reference)
# ============================================================================
"""
# In your service file:

from app.ml.retriever import retrieve_documents, format_context

# Retrieve documents
docs = retrieve_documents("What is my account balance?", top_k=5)

# Format context for LLM
context = format_context(docs)

# Use context in LLM prompt
prompt = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
"""
app/models/__init__.py
ADDED
@@ -0,0 +1,3 @@
"""
Pydantic models package
"""
app/models/user.py
ADDED
@@ -0,0 +1,69 @@
"""
User Models for Authentication
Pydantic models for user registration, login, and responses
"""

from pydantic import BaseModel, EmailStr, Field
from typing import Optional
from datetime import datetime


class UserRegister(BaseModel):
    """User registration request"""
    email: EmailStr
    password: str = Field(..., min_length=6, max_length=100)
    full_name: str = Field(..., min_length=2, max_length=100)

    class Config:
        json_schema_extra = {
            "example": {
                "email": "user@example.com",
                "password": "SecurePass123",
                "full_name": "John Doe"
            }
        }


class UserLogin(BaseModel):
    """User login request"""
    email: EmailStr
    password: str

    class Config:
        json_schema_extra = {
            "example": {
                "email": "user@example.com",
                "password": "SecurePass123"
            }
        }


class UserResponse(BaseModel):
    """User response (without password)"""
    user_id: str
    email: str
    full_name: str
    created_at: datetime

    class Config:
        json_schema_extra = {
            "example": {
                "user_id": "abc-123",
                "email": "user@example.com",
                "full_name": "John Doe",
                "created_at": "2025-10-28T20:00:00"
            }
        }


class Token(BaseModel):
    """JWT Token response"""
    access_token: str
    token_type: str = "bearer"
    user: UserResponse


class TokenData(BaseModel):
    """Data stored in JWT token"""
    user_id: Optional[str] = None
    email: Optional[str] = None
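
# A hypothetical login response shaped by Token (values illustrative only):
#
#   {
#     "access_token": "eyJhbGciOiJIUzI1NiIs...",
#     "token_type": "bearer",
#     "user": {
#       "user_id": "abc-123",
#       "email": "user@example.com",
#       "full_name": "John Doe",
#       "created_at": "2025-10-28T20:00:00"
#     }
#   }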
app/services/__init__.py
ADDED
@@ -0,0 +1,6 @@
"""
Services package - Business logic layer
Contains all service classes that handle core application logic
"""

__version__ = "1.0.0"
app/services/chat_service.py
ADDED
@@ -0,0 +1,335 @@
"""
Chat Service - Main RAG Pipeline

Combines: Policy Network → Retriever → LLM Generator

This is the core service that orchestrates:
1. Policy decision (FETCH vs NO_FETCH)
2. Document retrieval (if FETCH)
3. Response generation (Groq/HuggingFace with Llama 3)
4. Logging to MongoDB

Adapted from your RAG.py workflow
"""

import time
from datetime import datetime
from typing import List, Dict, Any, Optional

from app.config import settings
from app.ml.policy_network import predict_policy_action
from app.ml.retriever import retrieve_documents, format_context
from app.core.llm_manager import llm_manager

# ============================================================================
# SYSTEM PROMPTS
# ============================================================================

BANKING_SYSTEM_PROMPT = """You are an expert banking assistant specialized in Indian financial regulations and banking practices. You have access to a comprehensive knowledge base of banking policies, procedures, and RBI regulations.

Instructions:
- Answer the user query accurately using the provided context when available
- If context is insufficient or query is outside banking domain, still respond helpfully but mention your banking specialization
- If no banking context is available, provide a general helpful response but acknowledge your expertise is in banking
- Never refuse to answer - always be helpful while being transparent about your specialization
- Cite relevant policy numbers or document references when available in context
- Never fabricate specific policies, rates, or eligibility criteria
- If uncertain about current rates or policies, acknowledge the limitation
- Maintain a helpful and professional tone
- Keep responses concise, clear, and actionable
"""

EVALUATION_PROMPT = """You are evaluating a banking assistant's response for quality and accuracy.

Criteria:
1. Accuracy: Is the response factually correct?
2. Relevance: Does it address the user's question?
3. Completeness: Are all aspects of the question covered?
4. Clarity: Is the response easy to understand?
5. Context Usage: Does it properly use the retrieved context?

Rate the response as:
- "Good": Accurate, relevant, complete, and clear
- "Bad": Inaccurate, irrelevant, incomplete, or unclear

Provide your rating and brief explanation."""

# ============================================================================
# CHAT SERVICE
# ============================================================================

class ChatService:
    """
    Main chat service that handles the complete RAG pipeline.

    Pipeline:
    1. User query comes in
    2. Policy network decides: FETCH or NO_FETCH
    3. If FETCH: Retrieve documents from FAISS
    4. Generate response using Groq/HuggingFace (with or without context)
    5. Return response + metadata
    """

    def __init__(self):
        """Initialize chat service"""
        print("🤖 ChatService initialized")

    async def process_query(
        self,
        query: str,
        conversation_history: List[Dict[str, str]] = None,
        user_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Process a user query through the complete RAG pipeline.

        This is the MAIN function that combines everything:
        - Policy decision
        - Retrieval
        - Generation

        Args:
            query: User query text
            conversation_history: Previous conversation turns
                Format: [{'role': 'user'/'assistant', 'content': '...', 'metadata': {...}}]
            user_id: Optional user ID for logging

        Returns:
            dict: Complete response with metadata
            {
                'response': str,                 # Generated response
                'policy_action': str,            # FETCH or NO_FETCH
                'policy_confidence': float,      # Confidence score
                'should_retrieve': bool,         # Whether retrieval was done
                'documents_retrieved': int,      # Number of docs retrieved
                'top_doc_score': float or None,  # Best similarity score
                'retrieval_time_ms': float,      # Time spent on retrieval
                'generation_time_ms': float,     # Time spent on generation
                'total_time_ms': float,          # Total processing time
                'timestamp': str                 # ISO timestamp
            }
        """
        start_time = time.time()

        # Initialize history if None
        if conversation_history is None:
            conversation_history = []

        # Validate query
        if not query or query.strip() == "":
            return {
                'response': "I didn't receive a valid question. Could you please try again?",
                'policy_action': 'NO_FETCH',
                'policy_confidence': 1.0,
                'should_retrieve': False,
                'documents_retrieved': 0,
                'top_doc_score': None,
                'retrieval_time_ms': 0,
                'generation_time_ms': 0,
                'total_time_ms': 0,
                'timestamp': datetime.now().isoformat()
            }

        # ====================================================================
        # STEP 1: POLICY DECISION (Local BERT model)
        # ====================================================================
        print(f"\n{'='*80}")
        print(f"🔍 Processing Query: {query[:50]}...")
        print(f"{'='*80}")

        policy_start = time.time()

        # Predict action using policy network
        policy_result = predict_policy_action(
            query=query,
            history=conversation_history,
            return_probs=True
        )

        policy_time = (time.time() - policy_start) * 1000

        print(f"\n📊 Policy Decision:")
        print(f"   Action: {policy_result['action']}")
        print(f"   Confidence: {policy_result['confidence']:.3f}")
        print(f"   Should Retrieve: {policy_result['should_retrieve']}")
        print(f"   Time: {policy_time:.2f}ms")

        # ====================================================================
        # STEP 2: RETRIEVAL (if FETCH or low confidence NO_FETCH)
        # ====================================================================
        retrieved_docs = []
        context = ""
        retrieval_time = 0

        if policy_result['should_retrieve']:
            print(f"\n🔎 Retrieving documents...")
            retrieval_start = time.time()

            try:
                # Retrieve documents using custom retriever + FAISS
                retrieved_docs = retrieve_documents(
                    query=query,
                    top_k=settings.TOP_K,
                    min_similarity=settings.SIMILARITY_THRESHOLD
                )

                retrieval_time = (time.time() - retrieval_start) * 1000

                if retrieved_docs:
                    print(f"   ✅ Retrieved {len(retrieved_docs)} documents")
                    print(f"   Top score: {retrieved_docs[0]['score']:.3f}")

                    # Format context for LLM
                    context = format_context(
                        retrieved_docs,
                        max_context_length=settings.MAX_CONTEXT_LENGTH
                    )
                else:
                    print(f"   ⚠️ No documents above threshold")

            except Exception as e:
                print(f"   ❌ Retrieval error: {e}")
                # Continue without retrieval

        else:
            print(f"\n🚫 Skipping retrieval (Policy: {policy_result['action']})")

        # ====================================================================
        # STEP 3: GENERATE RESPONSE (Groq/HuggingFace with fallback)
        # ====================================================================
        print(f"\n💬 Generating response...")
        generation_start = time.time()

        try:
            # Generate response using LLM manager (Groq → HuggingFace fallback)
            response = await llm_manager.generate_chat_response(
                query=query,
                context=context,
                history=conversation_history
            )

            generation_time = (time.time() - generation_start) * 1000

            print(f"   ✅ Response generated")
            print(f"   Length: {len(response)} chars")
            print(f"   Time: {generation_time:.2f}ms")

        except Exception as e:
            print(f"   ❌ Generation error: {e}")
            response = "I apologize, but I encountered an error generating a response. Please try again."
            generation_time = (time.time() - generation_start) * 1000

        # ====================================================================
        # STEP 4: COMPILE RESULTS
        # ====================================================================
        total_time = (time.time() - start_time) * 1000

        result = {
            'response': response,
            'policy_action': policy_result['action'],
            'policy_confidence': policy_result['confidence'],
            'should_retrieve': policy_result['should_retrieve'],
            'documents_retrieved': len(retrieved_docs),
            'top_doc_score': retrieved_docs[0]['score'] if retrieved_docs else None,
            'retrieval_time_ms': round(retrieval_time, 2),
            'generation_time_ms': round(generation_time, 2),
            'total_time_ms': round(total_time, 2),
            'timestamp': datetime.now().isoformat()
        }

        # Add retrieved docs metadata (for logging, not sent to user)
        if retrieved_docs:
            result['retrieved_docs_metadata'] = [
                {
                    'faq_id': doc['faq_id'],
                    'score': doc['score'],
                    'category': doc['category'],
                    'rank': doc['rank']
                }
                for doc in retrieved_docs
            ]

        print(f"\n{'='*80}")
        print(f"✅ Query processed successfully")
        print(f"   Total time: {total_time:.2f}ms")
        print(f"{'='*80}\n")

        return result

    async def health_check(self) -> Dict[str, Any]:
        """
        Check health of all service components.

        Returns:
            dict: Health status
        """
        health = {
            'service': 'chat_service',
            'status': 'healthy',
            'components': {}
        }

        # Check policy network
        try:
            from app.ml.policy_network import POLICY_MODEL
            health['components']['policy_network'] = 'loaded' if POLICY_MODEL else 'not_loaded'
        except Exception as e:
            health['components']['policy_network'] = f'error: {str(e)}'

        # Check retriever
        try:
            from app.ml.retriever import RETRIEVER_MODEL, FAISS_INDEX
            health['components']['retriever'] = 'loaded' if RETRIEVER_MODEL else 'not_loaded'
            health['components']['faiss_index'] = 'loaded' if FAISS_INDEX else 'not_loaded'
        except Exception as e:
            health['components']['retriever'] = f'error: {str(e)}'

        # Check LLM manager
        try:
            from app.core.llm_manager import llm_manager as llm
            health['components']['groq'] = 'enabled' if llm.groq else 'disabled'
            health['components']['huggingface'] = 'enabled' if llm.huggingface else 'disabled'
        except Exception as e:
            health['components']['llm_manager'] = f'error: {str(e)}'

        # Overall status
        failed_components = [k for k, v in health['components'].items() if 'error' in str(v)]
        if failed_components:
            health['status'] = 'degraded'
            health['failed_components'] = failed_components

        return health
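
# Example health_check() output when everything is loaded (values illustrative):
#
#   {
#     'service': 'chat_service',
#     'status': 'healthy',
#     'components': {
#       'policy_network': 'loaded',
#       'retriever': 'loaded',
#       'faiss_index': 'loaded',
#       'groq': 'enabled',
#       'huggingface': 'enabled'
#     }
#   }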

# ============================================================================
# GLOBAL CHAT SERVICE INSTANCE
# ============================================================================

chat_service = ChatService()

# ============================================================================
# USAGE EXAMPLE (for reference)
# ============================================================================
"""
# In your API endpoint (chat.py):
from app.services.chat_service import chat_service

# Process user query
result = await chat_service.process_query(
    query="What is my account balance?",
    conversation_history=[
        {'role': 'user', 'content': 'Hello'},
        {'role': 'assistant', 'content': 'Hi! How can I help?', 'metadata': {'policy_action': 'NO_FETCH'}}
    ],
    user_id="user_123"
)

# Result contains:
# - response: "Your account balance is $1,234.56"
# - policy_action: "FETCH"
# - documents_retrieved: 3
# - total_time_ms: 450.23
# etc.

# Get service health
health = await chat_service.health_check()
"""
app/utils/dependencies.py
ADDED
@@ -0,0 +1,87 @@
"""
FastAPI Dependencies
Authentication and authorization dependencies
"""

from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from typing import Optional

from app.utils.security import decode_access_token
from app.db.repositories.user_repository import UserRepository
from app.models.user import TokenData


# HTTP Bearer token scheme (auto_error=True: missing credentials -> 403)
security = HTTPBearer()

# Non-erroring scheme for optional authentication. With the default
# HTTPBearer(), a request without an Authorization header is rejected
# before the dependency body runs, so the Optional[...] branch in
# get_optional_current_user would never be reached.
optional_security = HTTPBearer(auto_error=False)


async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security)
) -> TokenData:
    """
    Get current authenticated user from JWT token.

    This dependency extracts and validates the JWT token from the
    Authorization header and returns the user data.

    Args:
        credentials: HTTP Bearer credentials

    Returns:
        TokenData: User data from token

    Raises:
        HTTPException: If token is invalid or expired
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Decode token
    token = credentials.credentials
    payload = decode_access_token(token)

    if payload is None:
        raise credentials_exception

    # Extract user data
    user_id: str = payload.get("user_id")
    email: str = payload.get("email")

    if user_id is None or email is None:
        raise credentials_exception

    # Verify user exists and is active
    user_repo = UserRepository()
    user = await user_repo.get_user_by_id(user_id)

    if user is None or not user.get("is_active", False):
        raise credentials_exception

    return TokenData(user_id=user_id, email=email)


async def get_optional_current_user(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(optional_security)
) -> Optional[TokenData]:
    """
    Get current user if authenticated, None otherwise.

    This is a non-required version of get_current_user for optional auth.

    Args:
        credentials: Optional HTTP Bearer credentials

    Returns:
        TokenData or None: User data from token or None
    """
    if credentials is None:
        return None

    try:
        return await get_current_user(credentials)
    except HTTPException:
        return None
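In use, get_current_user sits behind any protected route as a Depends() parameter. A short illustrative sketch (the real routes live in app/api/v1/, so the /me path and response shape are assumptions, not the commit's actual endpoints):

# Hypothetical protected endpoint — path and response shape are illustrative.
from fastapi import APIRouter, Depends

from app.models.user import TokenData
from app.utils.dependencies import get_current_user

router = APIRouter()

@router.get("/me")
async def read_me(current_user: TokenData = Depends(get_current_user)):
    # By the time this body runs, get_current_user has already validated
    # the Bearer token and confirmed the user exists and is active.
    return {"user_id": current_user.user_id, "email": current_user.email}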
app/utils/security.py
ADDED
@@ -0,0 +1,190 @@
# """
# Security Utilities
# Password hashing and JWT token management
# """

# from datetime import datetime, timedelta
# from typing import Optional
# from jose import JWTError, jwt
# from passlib.context import CryptContext

# from app.config import settings


# # Password hashing context
# pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


# # ============================================================================
# # PASSWORD HASHING
# # ============================================================================

# def hash_password(password: str) -> str:
#     """
#     Hash a password using bcrypt.

#     Args:
#         password: Plain text password

#     Returns:
#         str: Hashed password
#     """
#     return pwd_context.hash(password)


# def verify_password(plain_password: str, hashed_password: str) -> bool:
#     """
#     Verify a password against a hash.

#     Args:
#         plain_password: Plain text password to verify
#         hashed_password: Hashed password from database

#     Returns:
#         bool: True if password matches, False otherwise
#     """
#     return pwd_context.verify(plain_password, hashed_password)


# # ============================================================================
# # JWT TOKEN MANAGEMENT
# # ============================================================================

# def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
#     """
#     Create a JWT access token.

#     Args:
#         data: Data to encode in token (user_id, email, etc.)
#         expires_delta: Optional custom expiration time

#     Returns:
#         str: Encoded JWT token
#     """
#     to_encode = data.copy()

#     if expires_delta:
#         expire = datetime.utcnow() + expires_delta
#     else:
#         expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)

#     to_encode.update({"exp": expire})
#     encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)

#     return encoded_jwt


# def decode_access_token(token: str) -> Optional[dict]:
#     """
#     Decode and verify a JWT token.

#     Args:
#         token: JWT token to decode

#     Returns:
#         dict: Decoded token data or None if invalid
#     """
#     try:
#         payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
#         return payload
#     except JWTError:
#         return None


"""
Security utilities for password hashing and JWT tokens
"""

from passlib.context import CryptContext
from datetime import datetime, timedelta
from typing import Optional
from jose import JWTError, jwt
from app.config import settings

# Password hashing context
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


def hash_password(password: str) -> str:
    """
    Hash a password using bcrypt

    Args:
        password: Plain text password

    Returns:
        Hashed password
    """
    # Bcrypt has a 72 byte limit; truncate on a byte boundary so a
    # multi-byte UTF-8 character is never split mid-sequence.
    password_bytes = password.encode('utf-8')
    if len(password_bytes) > 72:
        password = password_bytes[:72].decode('utf-8', errors='ignore')

    return pwd_context.hash(password)


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """
    Verify a password against its hash

    Args:
        plain_password: Plain text password to verify
        hashed_password: Stored hashed password

    Returns:
        True if password matches, False otherwise
    """
    # Truncate to 72 bytes for bcrypt, mirroring hash_password above
    plain_bytes = plain_password.encode('utf-8')
    if len(plain_bytes) > 72:
        plain_password = plain_bytes[:72].decode('utf-8', errors='ignore')

    return pwd_context.verify(plain_password, hashed_password)


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """
    Create a JWT access token

    Args:
        data: Data to encode in the token (usually user_id, email)
        expires_delta: Token expiration time (default: from settings)

    Returns:
        Encoded JWT token string
    """
    to_encode = data.copy()

    if expires_delta:
        expire = datetime.utcnow() + expires_delta
    else:
        expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)

    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)

    return encoded_jwt


def decode_access_token(token: str) -> Optional[dict]:
    """
    Decode and verify a JWT token

    Args:
        token: JWT token string

    Returns:
        Decoded token data (dict), or None if the token is invalid or expired
    """
    try:
        payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
        return payload
    except JWTError:
        return None
backups/backup_chat_service.py
ADDED
@@ -0,0 +1,340 @@
# """
# Chat Service - Main RAG Pipeline
# Combines: Policy Network → Retriever → LLM Generator

# This is the core service that orchestrates:
# 1. Policy decision (FETCH vs NO_FETCH)
# 2. Document retrieval (if FETCH)
# 3. Response generation (Gemini)
# 4. Logging to MongoDB

# Adapted from your RAG.py workflow
# """

# import time
# from datetime import datetime
# from typing import List, Dict, Any, Optional

# from app.config import settings
# from app.ml.policy_network import predict_policy_action
# from app.ml.retriever import retrieve_documents, format_context
# from app.core.llm_manager import llm_manager

# # ============================================================================
# # SYSTEM PROMPTS
# # ============================================================================

# BANKING_SYSTEM_PROMPT = """You are an expert banking assistant specialized in Indian financial regulations and banking practices. You have access to a comprehensive knowledge base of banking policies, procedures, and RBI regulations.

# Instructions:
# - Answer the user query accurately using the provided context when available
# - If context is insufficient or query is outside banking domain, still respond helpfully but mention your banking specialization
# - If no banking context is available, provide a general helpful response but acknowledge your expertise is in banking
# - Never refuse to answer - always be helpful while being transparent about your specialization
# - Cite relevant policy numbers or document references when available in context
# - Never fabricate specific policies, rates, or eligibility criteria
# - If uncertain about current rates or policies, acknowledge the limitation
# - Maintain a helpful and professional tone
# - Keep responses concise, clear, and actionable
# """

# EVALUATION_PROMPT = """You are evaluating a banking assistant's response for quality and accuracy.

# Criteria:
# 1. Accuracy: Is the response factually correct?
# 2. Relevance: Does it address the user's question?
# 3. Completeness: Are all aspects of the question covered?
# 4. Clarity: Is the response easy to understand?
# 5. Context Usage: Does it properly use the retrieved context?

# Rate the response as:
# - "Good": Accurate, relevant, complete, and clear
# - "Bad": Inaccurate, irrelevant, incomplete, or unclear

# Provide your rating and brief explanation."""


# # ============================================================================
# # CHAT SERVICE
# # ============================================================================

# class ChatService:
#     """
#     Main chat service that handles the complete RAG pipeline.

#     Pipeline:
#     1. User query comes in
#     2. Policy network decides: FETCH or NO_FETCH
#     3. If FETCH: Retrieve documents from FAISS
#     4. Generate response using Gemini (with or without context)
#     5. Return response + metadata
#     """

#     def __init__(self):
#         """Initialize chat service"""
#         print("🤖 ChatService initialized")

#     async def process_query(
#         self,
#         query: str,
#         conversation_history: List[Dict[str, str]] = None,
#         user_id: Optional[str] = None
#     ) -> Dict[str, Any]:
#         """
#         Process a user query through the complete RAG pipeline.

#         This is the MAIN function that combines everything:
#         - Policy decision
#         - Retrieval
#         - Generation

#         Args:
#             query: User query text
#             conversation_history: Previous conversation turns
#                 Format: [{'role': 'user'/'assistant', 'content': '...', 'metadata': {...}}]
#             user_id: Optional user ID for logging

#         Returns:
#             dict: Complete response with metadata
#             {
#                 'response': str,                 # Generated response
#                 'policy_action': str,            # FETCH or NO_FETCH
#                 'policy_confidence': float,      # Confidence score
#                 'should_retrieve': bool,         # Whether retrieval was done
#                 'documents_retrieved': int,      # Number of docs retrieved
#                 'top_doc_score': float or None,  # Best similarity score
#                 'retrieval_time_ms': float,      # Time spent on retrieval
#                 'generation_time_ms': float,     # Time spent on generation
#                 'total_time_ms': float,          # Total processing time
#                 'timestamp': str                 # ISO timestamp
#             }
#         """
#         start_time = time.time()

#         # Initialize history if None
#         if conversation_history is None:
#             conversation_history = []

#         # Validate query
#         if not query or query.strip() == "":
#             return {
#                 'response': "I didn't receive a valid question. Could you please try again?",
#                 'policy_action': 'NO_FETCH',
#                 'policy_confidence': 1.0,
#                 'should_retrieve': False,
#                 'documents_retrieved': 0,
#                 'top_doc_score': None,
#                 'retrieval_time_ms': 0,
#                 'generation_time_ms': 0,
#                 'total_time_ms': 0,
#                 'timestamp': datetime.now().isoformat()
#             }

#         # ====================================================================
#         # STEP 1: POLICY DECISION (Local BERT model)
#         # ====================================================================
#         print(f"\n{'='*80}")
#         print(f"🔍 Processing Query: {query[:50]}...")
#         print(f"{'='*80}")

#         policy_start = time.time()

#         # Predict action using policy network
#         policy_result = predict_policy_action(
#             query=query,
#             history=conversation_history,
#             return_probs=True
#         )

#         policy_time = (time.time() - policy_start) * 1000

#         print(f"\n📊 Policy Decision:")
#         print(f"   Action: {policy_result['action']}")
#         print(f"   Confidence: {policy_result['confidence']:.3f}")
#         print(f"   Should Retrieve: {policy_result['should_retrieve']}")
#         print(f"   Time: {policy_time:.2f}ms")

#         # ====================================================================
#         # STEP 2: RETRIEVAL (if FETCH or low confidence NO_FETCH)
#         # ====================================================================
#         retrieved_docs = []
#         context = ""
#         retrieval_time = 0

#         if policy_result['should_retrieve']:
#             print(f"\n🔎 Retrieving documents...")
#             retrieval_start = time.time()

#             try:
#                 # Retrieve documents using custom retriever + FAISS
#                 retrieved_docs = retrieve_documents(
#                     query=query,
#                     top_k=settings.TOP_K,
#                     min_similarity=settings.SIMILARITY_THRESHOLD
#                 )

#                 retrieval_time = (time.time() - retrieval_start) * 1000

#                 if retrieved_docs:
#                     print(f"   ✅ Retrieved {len(retrieved_docs)} documents")
#                     print(f"   Top score: {retrieved_docs[0]['score']:.3f}")

#                     # Format context for LLM
#                     context = format_context(
#                         retrieved_docs,
#                         max_context_length=settings.MAX_CONTEXT_LENGTH
#                     )
#                 else:
#                     print(f"   ⚠️ No documents above threshold")

#             except Exception as e:
#                 print(f"   ❌ Retrieval error: {e}")
#                 # Continue without retrieval

#         else:
#             print(f"\n🚫 Skipping retrieval (Policy: {policy_result['action']})")

#         # ====================================================================
#         # STEP 3: GENERATE RESPONSE (Gemini)
#         # ====================================================================
#         print(f"\n💬 Generating response...")
#         generation_start = time.time()

#         try:
#             # Generate response using LLM manager (Gemini)
#             response = await llm_manager.generate_chat_response(
#                 query=query,
#                 context=context,
#                 history=conversation_history
#             )

#             generation_time = (time.time() - generation_start) * 1000

#             print(f"   ✅ Response generated")
#             print(f"   Length: {len(response)} chars")
#             print(f"   Time: {generation_time:.2f}ms")

#         except Exception as e:
#             print(f"   ❌ Generation error: {e}")
#             response = "I apologize, but I encountered an error generating a response. Please try again."
#             generation_time = (time.time() - generation_start) * 1000

#         # ====================================================================
#         # STEP 4: COMPILE RESULTS
#         # ====================================================================
#         total_time = (time.time() - start_time) * 1000

#         result = {
#             'response': response,
#             'policy_action': policy_result['action'],
#             'policy_confidence': policy_result['confidence'],
#             'should_retrieve': policy_result['should_retrieve'],
#             'documents_retrieved': len(retrieved_docs),
#             'top_doc_score': retrieved_docs[0]['score'] if retrieved_docs else None,
#             'retrieval_time_ms': round(retrieval_time, 2),
#             'generation_time_ms': round(generation_time, 2),
#             'total_time_ms': round(total_time, 2),
#             'timestamp': datetime.now().isoformat()
#         }

#         # Add retrieved docs metadata (for logging, not sent to user)
#         if retrieved_docs:
#             result['retrieved_docs_metadata'] = [
#                 {
#                     'faq_id': doc['faq_id'],
#                     'score': doc['score'],
#                     'category': doc['category'],
#                     'rank': doc['rank']
#                 }
#                 for doc in retrieved_docs
#             ]

#         print(f"\n{'='*80}")
#         print(f"✅ Query processed successfully")
#         print(f"   Total time: {total_time:.2f}ms")
#         print(f"{'='*80}\n")

#         return result

#     async def health_check(self) -> Dict[str, Any]:
#         """
#         Check health of all service components.

#         Returns:
#             dict: Health status
#         """
#         health = {
#             'service': 'chat_service',
#             'status': 'healthy',
#             'components': {}
#         }

#         # Check policy network
#         try:
#             from app.ml.policy_network import POLICY_MODEL
#             health['components']['policy_network'] = 'loaded' if POLICY_MODEL else 'not_loaded'
#         except Exception as e:
#             health['components']['policy_network'] = f'error: {str(e)}'

#         # Check retriever
#         try:
#             from app.ml.retriever import RETRIEVER_MODEL, FAISS_INDEX
#             health['components']['retriever'] = 'loaded' if RETRIEVER_MODEL else 'not_loaded'
#             health['components']['faiss_index'] = 'loaded' if FAISS_INDEX else 'not_loaded'
#         except Exception as e:
#             health['components']['retriever'] = f'error: {str(e)}'

#         # Check LLM manager
#         try:
#             from app.core.llm_manager import llm_manager as llm
#             health['components']['gemini'] = 'enabled' if llm.gemini else 'disabled'
#             health['components']['groq'] = 'enabled' if llm.groq else 'disabled'
#         except Exception as e:
#             health['components']['llm_manager'] = f'error: {str(e)}'

#         # Overall status
#         failed_components = [k for k, v in health['components'].items() if 'error' in str(v)]
#         if failed_components:
#             health['status'] = 'degraded'
#             health['failed_components'] = failed_components

#         return health


# # ============================================================================
# # GLOBAL CHAT SERVICE INSTANCE
# # ============================================================================
# chat_service = ChatService()


# # ============================================================================
# # USAGE EXAMPLE (for reference)
# # ============================================================================
# """
# # In your API endpoint (chat.py):

# from app.services.chat_service import chat_service

# # Process user query
# result = await chat_service.process_query(
#     query="What is my account balance?",
#     conversation_history=[
#         {'role': 'user', 'content': 'Hello'},
#         {'role': 'assistant', 'content': 'Hi! How can I help?', 'metadata': {'policy_action': 'NO_FETCH'}}
#     ],
#     user_id="user_123"
# )

# # Result contains:
# # - response: "Your account balance is $1,234.56"
# # - policy_action: "FETCH"
# # - documents_retrieved: 3
# # - total_time_ms: 450.23
# # etc.

# # Get service health
# health = await chat_service.health_check()
# """
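Putting the auth and chat pieces together, an end-to-end client call against the deployed Space might look like the sketch below. The exact endpoint paths and field names are assumptions: the authoritative schemas live in app/api/v1/auth.py and app/api/v1/chat.py, which this excerpt does not show.

# Hypothetical client sketch — endpoint paths and payload fields are
# assumptions; check app/api/v1/auth.py and app/api/v1/chat.py for the
# real schemas.
import requests

BASE = "https://<your-space>.hf.space"  # placeholder URL

# 1. Log in to obtain a JWT access token
login = requests.post(
    f"{BASE}/api/v1/auth/login",
    json={"email": "a@b.com", "password": "s3cret-pass"},
)
token = login.json()["access_token"]

# 2. Call the chat endpoint with the Bearer token
resp = requests.post(
    f"{BASE}/api/v1/chat",
    json={"query": "What documents do I need to open a savings account?"},
    headers={"Authorization": f"Bearer {token}"},
)
print(resp.json())  # response text plus policy/retrieval/timing metadata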
backups/backup_config.py
ADDED
@@ -0,0 +1,640 @@
# LINE 80 VERY IMP CHANGE OF LLM MAX TOKENS FROM 512 TO 1024

"""
Application Configuration
Settings for Banking RAG Chatbot with JWT Authentication
Includes all settings needed by existing llm_manager.py
"""

import os
from typing import List
from dotenv import load_dotenv

load_dotenv()


class Settings:
    """Application settings loaded from environment variables"""

    # ========================================================================
    # ENVIRONMENT
    # ========================================================================
    ENVIRONMENT: str = os.getenv("ENVIRONMENT", "development")
    DEBUG: bool = os.getenv("DEBUG", "True").lower() == "true"

    # ========================================================================
    # MONGODB
    # ========================================================================
    MONGODB_URI: str = os.getenv("MONGODB_URI", "")
    DATABASE_NAME: str = os.getenv("DATABASE_NAME", "aml_ia_db")

    # ========================================================================
    # JWT AUTHENTICATION
    # ========================================================================
    SECRET_KEY: str = os.getenv("SECRET_KEY", "your-secret-key-change-in-production")
    ALGORITHM: str = os.getenv("ALGORITHM", "HS256")
    ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "1440"))

    # ========================================================================
    # CORS (for frontend)
    # ========================================================================
    ALLOWED_ORIGINS: str = os.getenv("ALLOWED_ORIGINS", "*")

    # ========================================================================
    # GOOGLE GEMINI API
    # ========================================================================
    GOOGLE_API_KEY: str = os.getenv("GOOGLE_API_KEY", "")
    GEMINI_MODEL: str = os.getenv("GEMINI_MODEL", "gemini-2.0-flash-lite")
    GEMINI_REQUESTS_PER_MINUTE: int = int(os.getenv("GEMINI_REQUESTS_PER_MINUTE", "60"))

    # ========================================================================
    # GROQ API (Optional - for evaluation)
    # ========================================================================
    GROQ_API_KEY: str = os.getenv("GROQ_API_KEY", "")
    GROQ_MODEL: str = os.getenv("GROQ_MODEL", "llama3-70b-8192")
    GROQ_REQUESTS_PER_MINUTE: int = int(os.getenv("GROQ_REQUESTS_PER_MINUTE", "30"))

    # ========================================================================
    # HUGGING FACE (Optional - for model downloads)
    # ========================================================================
    HF_TOKEN: str = os.getenv("HF_TOKEN", "")

    # ========================================================================
    # MODEL PATHS (for RL Policy Network and RAG models)
    # ========================================================================
    POLICY_MODEL_PATH: str = os.getenv("POLICY_MODEL_PATH", "app/models/best_policy_model.pth")
    RETRIEVER_MODEL_PATH: str = os.getenv("RETRIEVER_MODEL_PATH", "app/models/best_retriever_model.pth")
    FAISS_INDEX_PATH: str = os.getenv("FAISS_INDEX_PATH", "app/models/faiss_index.pkl")
    KB_PATH: str = os.getenv("KB_PATH", "app/data/final_knowledge_base.jsonl")

    # ========================================================================
    # DEVICE SETTINGS (for PyTorch/TensorFlow models)
    # ========================================================================
    DEVICE: str = os.getenv("DEVICE", "cpu")

    # ========================================================================
    # LLM PARAMETERS
    # ========================================================================
    LLM_TEMPERATURE: float = float(os.getenv("LLM_TEMPERATURE", "0.7"))
    LLM_MAX_TOKENS: int = int(os.getenv("LLM_MAX_TOKENS", "1024"))  # VERY IMPORTANT CHANGE: raised from 512 to 1024

    # ========================================================================
    # RAG PARAMETERS
    # ========================================================================
    TOP_K: int = int(os.getenv("TOP_K", "5"))
    SIMILARITY_THRESHOLD: float = float(os.getenv("SIMILARITY_THRESHOLD", "0.5"))
    MAX_CONTEXT_LENGTH: int = int(os.getenv("MAX_CONTEXT_LENGTH", "2000"))

    # ========================================================================
    # POLICY NETWORK PARAMETERS
    # ========================================================================
    POLICY_MAX_LEN: int = int(os.getenv("POLICY_MAX_LEN", "256"))
    CONFIDENCE_THRESHOLD: float = float(os.getenv("CONFIDENCE_THRESHOLD", "0.7"))

    # ========================================================================
    # HELPER METHODS (Required by llm_manager.py)
    # ========================================================================

    def is_gemini_enabled(self) -> bool:
        """Check if Google Gemini API is configured"""
        return bool(self.GOOGLE_API_KEY and self.GOOGLE_API_KEY != "")

    def is_groq_enabled(self) -> bool:
        """Check if Groq API is configured"""
        return bool(self.GROQ_API_KEY and self.GROQ_API_KEY != "")

    def is_hf_enabled(self) -> bool:
        """Check if HuggingFace token is configured"""
        return bool(self.HF_TOKEN and self.HF_TOKEN != "")

    def get_allowed_origins(self) -> List[str]:
        """Parse allowed origins from comma-separated string"""
        if self.ALLOWED_ORIGINS == "*":
            return ["*"]
        return [origin.strip() for origin in self.ALLOWED_ORIGINS.split(",")]

    def get_llm_for_task(self, task: str = "qa") -> str:
        """
        Get LLM name for a specific task.

        Args:
            task: Task type ('chat', 'evaluation', etc.)

        Returns:
            str: LLM name ('gemini' or 'groq')
        """
        # Use Gemini for chat, Groq for evaluation
        if task == "evaluation":
            return "groq" if self.is_groq_enabled() else "gemini"
        else:
            return "gemini"  # Default to Gemini for all tasks


# ============================================================================
# CREATE GLOBAL SETTINGS INSTANCE
# ============================================================================
settings = Settings()


# ============================================================================
# PRINT CONFIGURATION ON LOAD
# ============================================================================
print("=" * 80)
print("✅ Configuration Loaded")
print("=" * 80)
print(f"Environment: {settings.ENVIRONMENT}")
print(f"Debug Mode: {settings.DEBUG}")
print(f"Database: {settings.DATABASE_NAME}")
print(f"Device: {settings.DEVICE}")
print(f"CORS Origins: {settings.ALLOWED_ORIGINS}")
print()
print("🔑 API Keys:")
print(f"  Google Gemini: {'✅ Configured' if settings.is_gemini_enabled() else '❌ Missing'}")
print(f"  Groq API: {'✅ Configured' if settings.is_groq_enabled() else '⚠️ Optional (not set)'}")
print(f"  HuggingFace: {'✅ Configured' if settings.is_hf_enabled() else '⚠️ Optional (not set)'}")
print(f"  MongoDB: {'✅ Configured' if settings.MONGODB_URI else '❌ Missing'}")
print(f"  JWT Secret: {'✅ Configured' if settings.SECRET_KEY != 'your-secret-key-change-in-production' else '⚠️ Using default (CHANGE THIS!)'}")
print()
print("🤖 Model Paths:")
print(f"  Policy Model: {settings.POLICY_MODEL_PATH}")
print(f"  Retriever Model: {settings.RETRIEVER_MODEL_PATH}")
print(f"  FAISS Index: {settings.FAISS_INDEX_PATH}")
print(f"  Knowledge Base: {settings.KB_PATH}")
print("=" * 80)
# ============================================================================


# (First commented-out earlier revision of this config file follows.)
# """
# Application Configuration
# Settings for Banking RAG Chatbot with JWT Authentication
# Includes all settings needed by existing llm_manager.py
# """

# import os
# from typing import List
# from dotenv import load_dotenv

# load_dotenv()


# class Settings:
#     """Application settings loaded from environment variables"""

#     # ========================================================================
#     # ENVIRONMENT
#     # ========================================================================
#     ENVIRONMENT: str = os.getenv("ENVIRONMENT", "development")
#     DEBUG: bool = os.getenv("DEBUG", "True").lower() == "true"

#     # ========================================================================
#     # MONGODB
#     # ========================================================================
#     MONGODB_URI: str = os.getenv("MONGODB_URI", "")
#     DATABASE_NAME: str = os.getenv("DATABASE_NAME", "aml_ia_db")

#     # ========================================================================
#     # JWT AUTHENTICATION
#     # ========================================================================
#     SECRET_KEY: str = os.getenv("SECRET_KEY", "your-secret-key-change-in-production")
#     ALGORITHM: str = os.getenv("ALGORITHM", "HS256")
#     ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "1440"))

#     # ========================================================================
#     # CORS (for frontend)
#     # ========================================================================
#     ALLOWED_ORIGINS: str = os.getenv("ALLOWED_ORIGINS", "*")

#     # ========================================================================
#     # GOOGLE GEMINI API
#     # ========================================================================
#     GOOGLE_API_KEY: str = os.getenv("GOOGLE_API_KEY", "")
#     GEMINI_MODEL: str = os.getenv("GEMINI_MODEL", "gemini-2.0-flash-lite")

#     # ========================================================================
#     # GROQ API (Optional - for your llm_manager)
#     # ========================================================================
#     GROQ_API_KEY: str = os.getenv("GROQ_API_KEY", "")
#     GROQ_MODEL: str = os.getenv("GROQ_MODEL", "llama3-70b-8192")

#     # ========================================================================
#     # HUGGING FACE (Optional - for model downloads)
#     # ========================================================================
#     HF_TOKEN: str = os.getenv("HF_TOKEN", "")

#     # ========================================================================
#     # MODEL PATHS (for RL Policy Network and RAG models)
#     # ========================================================================
#     POLICY_MODEL_PATH: str = os.getenv("POLICY_MODEL_PATH", "models/best_policy_model.pth")
#     RETRIEVER_MODEL_PATH: str = os.getenv("RETRIEVER_MODEL_PATH", "models/best_retriever_model.pth")
#     FAISS_INDEX_PATH: str = os.getenv("FAISS_INDEX_PATH", "models/faiss_index.pkl")
#     KB_PATH: str = os.getenv("KB_PATH", "data/final_knowledge_base.jsonl")

#     # ========================================================================
#     # DEVICE SETTINGS (for PyTorch/TensorFlow models)
#     # ========================================================================
#     DEVICE: str = os.getenv("DEVICE", "cpu")

#     # ========================================================================
#     # LLM PARAMETERS
#     # ========================================================================
#     LLM_TEMPERATURE: float = float(os.getenv("LLM_TEMPERATURE", "0.7"))
#     LLM_MAX_TOKENS: int = int(os.getenv("LLM_MAX_TOKENS", "512"))

#     # ========================================================================
#     # RAG PARAMETERS
#     # ========================================================================
#     TOP_K: int = int(os.getenv("TOP_K", "5"))
#     SIMILARITY_THRESHOLD: float = float(os.getenv("SIMILARITY_THRESHOLD", "0.5"))
#     MAX_CONTEXT_LENGTH: int = int(os.getenv("MAX_CONTEXT_LENGTH", "2000"))

#     # ========================================================================
#     # POLICY NETWORK PARAMETERS
#     # ========================================================================
#     POLICY_MAX_LEN: int = int(os.getenv("POLICY_MAX_LEN", "256"))
#     CONFIDENCE_THRESHOLD: float = float(os.getenv("CONFIDENCE_THRESHOLD", "0.7"))

#     # ========================================================================
#     # HELPER METHODS (Required by llm_manager.py)
#     # ========================================================================

#     def is_gemini_enabled(self) -> bool:
#         """Check if Google Gemini API is configured"""
#         return bool(self.GOOGLE_API_KEY and self.GOOGLE_API_KEY != "")

#     def is_groq_enabled(self) -> bool:
#         """Check if Groq API is configured"""
#         return bool(self.GROQ_API_KEY and self.GROQ_API_KEY != "")

#     def is_hf_enabled(self) -> bool:
#         """Check if HuggingFace token is configured"""
#         return bool(self.HF_TOKEN and self.HF_TOKEN != "")

#     def get_allowed_origins(self) -> List[str]:
#         """Parse allowed origins from comma-separated string"""
#         if self.ALLOWED_ORIGINS == "*":
#             return ["*"]
#         return [origin.strip() for origin in self.ALLOWED_ORIGINS.split(",")]

#     # def get_llm_for_task(self, task: str = "qa"):
#     #     """
#     #     Get LLM configuration for a specific task.
#     #     Returns a dict with model settings.

#     #     Args:
#     #         task: Task type ('qa', 'retrieval', 'summary', etc.)

#     #     Returns:
#     #         dict: LLM configuration
#     #     """
#     #     return {
#     #         'api_key': self.GOOGLE_API_KEY,
#     #         'model': self.GEMINI_MODEL,
#     #         'temperature': self.LLM_TEMPERATURE,
#     #         'max_tokens': self.LLM_MAX_TOKENS,
#     #         'task': task
#     #     }
#     def get_llm_for_task(self, task: str = "qa") -> str:
#         """
#         Get LLM name for a specific task.

#         Args:
#             task: Task type ('chat', 'evaluation', etc.)

#         Returns:
#             str: LLM name ('gemini' or 'groq')
#         """
#         # Use Gemini for chat, Groq for evaluation
#         if task == "evaluation":
#             return "groq" if self.is_groq_enabled() else "gemini"
#         else:
#             return "gemini"  # Default to Gemini for all other tasks


# # ============================================================================
# # CREATE GLOBAL SETTINGS INSTANCE
# # ============================================================================
# settings = Settings()


# # ============================================================================
# # PRINT CONFIGURATION ON LOAD
# # ============================================================================
# print("=" * 80)
# print("✅ Configuration Loaded")
# print("=" * 80)
# print(f"Environment: {settings.ENVIRONMENT}")
# print(f"Debug Mode: {settings.DEBUG}")
# print(f"Database: {settings.DATABASE_NAME}")
# print(f"Device: {settings.DEVICE}")
# print(f"CORS Origins: {settings.ALLOWED_ORIGINS}")
# print()
# print("🔑 API Keys:")
# print(f"  Google Gemini: {'✅ Configured' if settings.is_gemini_enabled() else '❌ Missing'}")
# print(f"  Groq API: {'✅ Configured' if settings.is_groq_enabled() else '⚠️ Optional (not set)'}")
# print(f"  HuggingFace: {'✅ Configured' if settings.is_hf_enabled() else '⚠️ Optional (not set)'}")
# print(f"  MongoDB: {'✅ Configured' if settings.MONGODB_URI else '❌ Missing'}")
# print(f"  JWT Secret: {'✅ Configured' if settings.SECRET_KEY != 'your-secret-key-change-in-production' else '⚠️ Using default (CHANGE THIS!)'}")
# print("=" * 80)
# # # ============================================================================


# (Second, doubly commented-out revision of the same config follows.)
# # """
# # Application Configuration
# # Settings for Banking RAG Chatbot with JWT Authentication
# # Includes all settings needed by existing llm_manager.py
# # """

# # import os
# # from typing import List
# # from dotenv import load_dotenv

# # load_dotenv()


# # class Settings:
# #     """Application settings loaded from environment variables"""

# #     # ========================================================================
# #     # ENVIRONMENT
# #     # ========================================================================
# #     ENVIRONMENT: str = os.getenv("ENVIRONMENT", "development")
# #     DEBUG: bool = os.getenv("DEBUG", "True").lower() == "true"

# #     # ========================================================================
# #     # MONGODB
# #     # ========================================================================
# #     MONGODB_URI: str = os.getenv("MONGODB_URI", "")
# #     DATABASE_NAME: str = os.getenv("DATABASE_NAME", "aml_ia_db")

# #     # ========================================================================
# #     # JWT AUTHENTICATION
# #     # ========================================================================
# #     SECRET_KEY: str = os.getenv("SECRET_KEY", "your-secret-key-change-in-production")
# #     ALGORITHM: str = os.getenv("ALGORITHM", "HS256")
# #     ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "1440"))

# #     # ========================================================================
# #     # CORS (for frontend)
# #     # ========================================================================
# #     ALLOWED_ORIGINS: str = os.getenv("ALLOWED_ORIGINS", "*")

# #     # ========================================================================
# #     # GOOGLE GEMINI API
# #     # ========================================================================
# #     GOOGLE_API_KEY: str = os.getenv("GOOGLE_API_KEY", "")
# #     GEMINI_MODEL: str = os.getenv("GEMINI_MODEL", "gemini-2.0-flash-lite")

# #     # ========================================================================
# #     # GROQ API (Optional - for your llm_manager)
# #     # ========================================================================
# #     GROQ_API_KEY: str = os.getenv("GROQ_API_KEY", "")
# #     GROQ_MODEL: str = os.getenv("GROQ_MODEL", "llama3-70b-8192")

# #     # ========================================================================
# #     # HUGGING FACE (Optional - for model downloads)
# #     # ========================================================================
# #     HF_TOKEN: str = os.getenv("HF_TOKEN", "")

# #     # ========================================================================
# #     # HELPER METHODS (Required by llm_manager.py)
# #     # ========================================================================

# #     def is_gemini_enabled(self) -> bool:
# #         """Check if Google Gemini API is configured"""
# #         return bool(self.GOOGLE_API_KEY and self.GOOGLE_API_KEY != "")

# #     def is_groq_enabled(self) -> bool:
# #         """Check if Groq API is configured"""
# #         return bool(self.GROQ_API_KEY and self.GROQ_API_KEY != "")

# #     def is_hf_enabled(self) -> bool:
# #         """Check if HuggingFace token is configured"""
# #         return bool(self.HF_TOKEN and self.HF_TOKEN != "")

# #     def get_allowed_origins(self) -> List[str]:
# #         """Parse allowed origins from comma-separated string"""
# #         if self.ALLOWED_ORIGINS == "*":
# #             return ["*"]
# #         return [origin.strip() for origin in self.ALLOWED_ORIGINS.split(",")]


# # # ============================================================================
# # # CREATE GLOBAL SETTINGS INSTANCE
# # # ============================================================================
# # settings = Settings()

# # # ============================================================================
# # # PRINT CONFIGURATION ON LOAD
# # # ============================================================================
# # print("=" * 80)
# # print("✅ Configuration Loaded")
# # print("=" * 80)
# # print(f"Environment: {settings.ENVIRONMENT}")
# # print(f"Debug Mode: {settings.DEBUG}")
# # print(f"Database: {settings.DATABASE_NAME}")
# # # print(f"JWT Algorithm: {settings.ALGORITHM}")
# # # print(f"Token Expiry: {settings.ACCESS_TOKEN_EXPIRE_MINUTES} minutes")
# # print(f"CORS Origins: {settings.ALLOWED_ORIGINS}")
# # print()
# # print("🔑 API Keys:")
# # print(f"  Google Gemini: {'✅ Configured' if settings.is_gemini_enabled() else '❌ Missing'}")
# # print(f"  Groq API: {'✅ Configured' if settings.is_groq_enabled() else '⚠️ Optional (not set)'}")
# # print(f"  HuggingFace: {'✅ Configured' if settings.is_hf_enabled() else '⚠️ Optional (not set)'}")
# # print(f"  MongoDB: {'✅ Configured' if settings.MONGODB_URI else '❌ Missing'}")
# # print(f"  JWT Secret: {'✅ Configured' if settings.SECRET_KEY != 'your-secret-key-change-in-production' else '⚠️ Using default (CHANGE THIS!)'}")
# # print("=" * 80)


# (Third commented-out revision follows; the source excerpt is truncated here.)
# """
# Application Configuration
# Settings for Banking RAG Chatbot with JWT Authentication
# Includes all settings needed by existing llm_manager.py
# """

# import os
# from typing import List
# from dotenv import load_dotenv

# load_dotenv()


# class Settings:
#     """Application settings loaded from environment variables"""

#     # ========================================================================
#     # ENVIRONMENT
#     # ========================================================================
#     ENVIRONMENT: str = os.getenv("ENVIRONMENT", "development")
#     DEBUG: bool = os.getenv("DEBUG", "True").lower() == "true"

#     # ========================================================================
#     # MONGODB
#     # ========================================================================
#     MONGODB_URI: str = os.getenv("MONGODB_URI", "")
#     DATABASE_NAME: str = os.getenv("DATABASE_NAME", "aml_ia_db")

#     # ========================================================================
#     # JWT AUTHENTICATION
#     # ========================================================================
#     SECRET_KEY: str = os.getenv("SECRET_KEY", "your-secret-key-change-in-production")
#     ALGORITHM: str = os.getenv("ALGORITHM", "HS256")
#     ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "1440"))

#     # ========================================================================
#     # CORS (for frontend)
#     # ========================================================================
#     ALLOWED_ORIGINS: str = os.getenv("ALLOWED_ORIGINS", "*")

#     # ========================================================================
#     # GOOGLE GEMINI API
+
# # ========================================================================
|
| 559 |
+
# GOOGLE_API_KEY: str = os.getenv("GOOGLE_API_KEY", "")
|
| 560 |
+
# GEMINI_MODEL: str = os.getenv("GEMINI_MODEL", "gemini-2.0-flash-lite")
|
| 561 |
+
|
| 562 |
+
# # ========================================================================
|
| 563 |
+
# # GROQ API (Optional - for your llm_manager)
|
| 564 |
+
# # ========================================================================
|
| 565 |
+
# GROQ_API_KEY: str = os.getenv("GROQ_API_KEY", "")
|
| 566 |
+
# GROQ_MODEL: str = os.getenv("GROQ_MODEL", "llama3-70b-8192")
|
| 567 |
+
|
| 568 |
+
# # ========================================================================
|
| 569 |
+
# # HUGGING FACE (Optional - for model downloads)
|
| 570 |
+
# # ========================================================================
|
| 571 |
+
# HF_TOKEN: str = os.getenv("HF_TOKEN", "")
|
| 572 |
+
|
| 573 |
+
# # ========================================================================
|
| 574 |
+
# # MODEL PATHS (for RL Policy Network and RAG models)
|
| 575 |
+
# # ========================================================================
|
| 576 |
+
# POLICY_MODEL_PATH: str = os.getenv("POLICY_MODEL_PATH", "models/best_policy_model.pth")
|
| 577 |
+
# RETRIEVER_MODEL_PATH: str = os.getenv("RETRIEVER_MODEL_PATH", "models/best_retriever_model.pth")
|
| 578 |
+
# FAISS_INDEX_PATH: str = os.getenv("FAISS_INDEX_PATH", "models/faiss_index.pkl")
|
| 579 |
+
# KB_PATH: str = os.getenv("KB_PATH", "data/final_knowledge_base.jsonl")
|
| 580 |
+
|
| 581 |
+
# # ========================================================================
|
| 582 |
+
# # LLM PARAMETERS
|
| 583 |
+
# # ========================================================================
|
| 584 |
+
# LLM_TEMPERATURE: float = float(os.getenv("LLM_TEMPERATURE", "0.7"))
|
| 585 |
+
# LLM_MAX_TOKENS: int = int(os.getenv("LLM_MAX_TOKENS", "512"))
|
| 586 |
+
|
| 587 |
+
# # ========================================================================
|
| 588 |
+
# # RAG PARAMETERS
|
| 589 |
+
# # ========================================================================
|
| 590 |
+
# TOP_K: int = int(os.getenv("TOP_K", "5"))
|
| 591 |
+
# SIMILARITY_THRESHOLD: float = float(os.getenv("SIMILARITY_THRESHOLD", "0.5"))
|
| 592 |
+
# MAX_CONTEXT_LENGTH: int = int(os.getenv("MAX_CONTEXT_LENGTH", "2000"))
|
| 593 |
+
|
| 594 |
+
# # ========================================================================
|
| 595 |
+
# # HELPER METHODS (Required by llm_manager.py)
|
| 596 |
+
# # ========================================================================
|
| 597 |
+
|
| 598 |
+
# def is_gemini_enabled(self) -> bool:
|
| 599 |
+
# """Check if Google Gemini API is configured"""
|
| 600 |
+
# return bool(self.GOOGLE_API_KEY and self.GOOGLE_API_KEY != "")
|
| 601 |
+
|
| 602 |
+
# def is_groq_enabled(self) -> bool:
|
| 603 |
+
# """Check if Groq API is configured"""
|
| 604 |
+
# return bool(self.GROQ_API_KEY and self.GROQ_API_KEY != "")
|
| 605 |
+
|
| 606 |
+
# def is_hf_enabled(self) -> bool:
|
| 607 |
+
# """Check if HuggingFace token is configured"""
|
| 608 |
+
# return bool(self.HF_TOKEN and self.HF_TOKEN != "")
|
| 609 |
+
|
| 610 |
+
# def get_allowed_origins(self) -> List[str]:
|
| 611 |
+
# """Parse allowed origins from comma-separated string"""
|
| 612 |
+
# if self.ALLOWED_ORIGINS == "*":
|
| 613 |
+
# return ["*"]
|
| 614 |
+
# return [origin.strip() for origin in self.ALLOWED_ORIGINS.split(",")]
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
# # ============================================================================
|
| 618 |
+
# # CREATE GLOBAL SETTINGS INSTANCE
|
| 619 |
+
# # ============================================================================
|
| 620 |
+
# settings = Settings()
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
# # ============================================================================
|
| 624 |
+
# # PRINT CONFIGURATION ON LOAD
|
| 625 |
+
# # ============================================================================
|
| 626 |
+
# print("=" * 80)
|
| 627 |
+
# print("✅ Configuration Loaded")
|
| 628 |
+
# print("=" * 80)
|
| 629 |
+
# print(f"Environment: {settings.ENVIRONMENT}")
|
| 630 |
+
# print(f"Debug Mode: {settings.DEBUG}")
|
| 631 |
+
# print(f"Database: {settings.DATABASE_NAME}")
|
| 632 |
+
# print(f"CORS Origins: {settings.ALLOWED_ORIGINS}")
|
| 633 |
+
# print()
|
| 634 |
+
# print("🔑 API Keys:")
|
| 635 |
+
# print(f" Google Gemini: {'✅ Configured' if settings.is_gemini_enabled() else '❌ Missing'}")
|
| 636 |
+
# print(f" Groq API: {'✅ Configured' if settings.is_groq_enabled() else '⚠️ Optional (not set)'}")
|
| 637 |
+
# print(f" HuggingFace: {'✅ Configured' if settings.is_hf_enabled() else '⚠️ Optional (not set)'}")
|
| 638 |
+
# print(f" MongoDB: {'✅ Configured' if settings.MONGODB_URI else '❌ Missing'}")
|
| 639 |
+
# print(f" JWT Secret: {'✅ Configured' if settings.SECRET_KEY != 'your-secret-key-change-in-production' else '⚠️ Using default (CHANGE THIS!)'}")
|
| 640 |
+
# print("=" * 80)
|
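Note: both iterations of this backup config warn when SECRET_KEY is left at its insecure default. A minimal sketch for generating a suitable value with Python's standard library; only the SECRET_KEY name comes from the config above, the rest is illustrative:

import secrets

# Run once and copy the output into your .env file
print(f"SECRET_KEY={secrets.token_hex(32)}")  # 64-character random hex key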
backups/backup_llm_manager.py
ADDED
@@ -0,0 +1,430 @@
# """
# Multi-LLM Manager for Google Gemini, Groq, and HuggingFace
# All three APIs co-exist for different purposes (no fallback logic)

# Architecture:
# - Google Gemini (Primary): User-facing chat responses (best quality)
# - Groq (Secondary): Fast inference for evaluation and specific tasks
# - HuggingFace: Model downloads and embeddings (always required)

# Each API has its designated purpose based on config settings.
# """

# import time
# import google.generativeai as genai
# from typing import List, Dict, Optional, Literal
# from langchain_groq import ChatGroq
# from langchain_core.messages import HumanMessage, SystemMessage, AIMessage

# from app.config import settings


# # ============================================================================
# # GOOGLE GEMINI MANAGER
# # ============================================================================

# class GeminiManager:
#     """
#     Google Gemini API Manager (Primary LLM)
#     Handles Google Pro account with gemini-2.0-flash-lite model
#     """

#     def __init__(self):
#         """Initialize Gemini API with your Google API key"""
#         self.api_key = settings.GOOGLE_API_KEY
#         self.model_name = settings.GEMINI_MODEL

#         # Configure Gemini
#         genai.configure(api_key=self.api_key)

#         # Create model instance with safety settings
#         self.model = genai.GenerativeModel(
#             model_name=self.model_name,
#             generation_config={
#                 "temperature": settings.LLM_TEMPERATURE,
#                 "max_output_tokens": settings.LLM_MAX_TOKENS,
#             }
#         )

#         # Rate limiting tracking
#         self.requests_this_minute = 0
#         self.tokens_this_minute = 0
#         self.last_reset = time.time()

#         print(f"✅ Gemini Manager initialized: {self.model_name}")

#     def _check_rate_limits(self):
#         """
#         Check and reset rate limit counters.
#         Gemini Pro: 60 requests/min, 60,000 tokens/min
#         """
#         current_time = time.time()

#         # Reset counters every minute
#         if current_time - self.last_reset > 60:
#             self.requests_this_minute = 0
#             self.tokens_this_minute = 0
#             self.last_reset = current_time

#         # Check if limits exceeded
#         if self.requests_this_minute >= settings.GEMINI_REQUESTS_PER_MINUTE:
#             wait_time = 60 - (current_time - self.last_reset)
#             print(f"⚠️ Gemini rate limit hit. Waiting {wait_time:.1f}s...")
#             time.sleep(wait_time)
#             self._check_rate_limits()  # Recursive check after waiting

#     async def generate(
#         self,
#         messages: List[Dict[str, str]],
#         system_prompt: Optional[str] = None
#     ) -> str:
#         """
#         Generate response using Gemini.

#         Args:
#             messages: List of conversation messages
#                 Format: [{'role': 'user'/'assistant', 'content': '...'}]
#             system_prompt: Optional system prompt (prepended to first message)

#         Returns:
#             str: Generated response text
#         """
#         self._check_rate_limits()

#         try:
#             # Format messages for Gemini
#             # Gemini uses 'user' and 'model' roles
#             formatted_messages = []

#             # Add system prompt as first user message if provided
#             if system_prompt:
#                 formatted_messages.append({
#                     'role': 'user',
#                     'parts': [system_prompt]
#                 })

#             # Convert messages
#             for msg in messages:
#                 role = 'model' if msg['role'] == 'assistant' else 'user'
#                 formatted_messages.append({
#                     'role': role,
#                     'parts': [msg['content']]
#                 })

#             # Generate response
#             chat = self.model.start_chat(history=formatted_messages[:-1])
#             response = chat.send_message(formatted_messages[-1]['parts'][0])

#             # Track rate limits
#             self.requests_this_minute += 1
#             # Note: Token counting would require additional API call
#             # For now, estimate ~4 chars per token
#             estimated_tokens = len(response.text) // 4
#             self.tokens_this_minute += estimated_tokens

#             return response.text

#         except Exception as e:
#             print(f"❌ Gemini API error: {e}")
#             raise


# # ============================================================================
# # GROQ MANAGER
# # ============================================================================

# class GroqManager:
#     """
#     Groq API Manager (Secondary LLM)
#     Handles fast inference with Llama-3-70B
#     """

#     def __init__(self):
#         """Initialize Groq API with single API key"""
#         self.api_key = settings.GROQ_API_KEY
#         self.model_name = settings.GROQ_MODEL

#         # Create ChatGroq instance
#         self.llm = ChatGroq(
#             api_key=self.api_key,
#             model_name=self.model_name,
#             temperature=settings.LLM_TEMPERATURE,
#             max_tokens=settings.LLM_MAX_TOKENS
#         )

#         # Rate limiting tracking
#         self.requests_this_minute = 0
#         self.tokens_this_minute = 0
#         self.last_reset = time.time()

#         print(f"✅ Groq Manager initialized: {self.model_name}")

#     def _check_rate_limits(self):
#         """
#         Check and reset rate limit counters.
#         Groq Free: 30 requests/min, 30,000 tokens/min
#         """
#         current_time = time.time()

#         # Reset counters every minute
#         if current_time - self.last_reset > 60:
#             self.requests_this_minute = 0
#             self.tokens_this_minute = 0
#             self.last_reset = current_time

#         # Check if limits exceeded
#         if self.requests_this_minute >= settings.GROQ_REQUESTS_PER_MINUTE:
#             wait_time = 60 - (current_time - self.last_reset)
#             print(f"⚠️ Groq rate limit hit. Waiting {wait_time:.1f}s...")
#             time.sleep(wait_time)
#             self._check_rate_limits()

#     async def generate(
#         self,
#         messages: List[Dict[str, str]],
#         system_prompt: Optional[str] = None
#     ) -> str:
#         """
#         Generate response using Groq.

#         Args:
#             messages: List of conversation messages
#                 Format: [{'role': 'user'/'assistant', 'content': '...'}]
#             system_prompt: Optional system prompt

#         Returns:
#             str: Generated response text
#         """
#         self._check_rate_limits()

#         try:
#             # Format messages for LangChain
#             formatted_messages = []

#             # Add system message if provided
#             if system_prompt:
#                 formatted_messages.append(SystemMessage(content=system_prompt))

#             # Convert conversation messages
#             for msg in messages:
#                 if msg['role'] == 'user':
#                     formatted_messages.append(HumanMessage(content=msg['content']))
#                 elif msg['role'] == 'assistant':
#                     formatted_messages.append(AIMessage(content=msg['content']))

#             # Generate response
#             response = await self.llm.ainvoke(formatted_messages)

#             # Track rate limits
#             self.requests_this_minute += 1
#             # Estimate tokens (rough approximation)
#             estimated_tokens = len(response.content) // 4
#             self.tokens_this_minute += estimated_tokens

#             return response.content

#         except Exception as e:
#             print(f"❌ Groq API error: {e}")
#             raise


# # ============================================================================
# # UNIFIED LLM MANAGER (Routes to appropriate LLM)
# # ============================================================================

# class LLMManager:
#     """
#     Unified LLM Manager that routes requests to appropriate LLM.

#     Routing strategy (from config):
#     - Chat responses → Gemini (best quality for users)
#     - Evaluation → Groq (fast, good enough for RL)
#     - Policy → Local BERT (no API call)
#     """

#     def __init__(self):
#         """Initialize all LLM managers"""
#         self.gemini = None
#         self.groq = None

#         # Initialize Gemini if configured
#         if settings.is_gemini_enabled():
#             try:
#                 self.gemini = GeminiManager()
#             except Exception as e:
#                 print(f"⚠️ Failed to initialize Gemini: {e}")

#         # Initialize Groq if configured
#         if settings.is_groq_enabled():
#             try:
#                 self.groq = GroqManager()
#             except Exception as e:
#                 print(f"⚠️ Failed to initialize Groq: {e}")

#         print("✅ LLM Manager initialized")

#     async def generate(
#         self,
#         messages: List[Dict[str, str]],
#         system_prompt: Optional[str] = None,
#         task: Literal["chat", "evaluation"] = "chat"
#     ) -> str:
#         """
#         Generate response using appropriate LLM based on task.

#         Args:
#             messages: Conversation messages
#             system_prompt: Optional system prompt
#             task: Task type - "chat" (user-facing) or "evaluation" (RL training)

#         Returns:
#             str: Generated response

#         Raises:
#             ValueError: If appropriate LLM is not configured
#         """
#         # Determine which LLM to use based on task
#         llm_choice = settings.get_llm_for_task(task)

#         if llm_choice == "gemini":
#             if self.gemini is None:
#                 raise ValueError("Gemini API not configured. Set GOOGLE_API_KEY in .env")
#             return await self.gemini.generate(messages, system_prompt)

#         elif llm_choice == "groq":
#             if self.groq is None:
#                 raise ValueError("Groq API not configured. Set GROQ_API_KEY in .env")
#             return await self.groq.generate(messages, system_prompt)

#         else:
#             raise ValueError(f"Unknown LLM choice: {llm_choice}")

#     # async def generate_chat_response(
#     #     self,
#     #     query: str,
#     #     context: str,
#     #     history: List[Dict[str, str]]
#     # ) -> str:
#     #     """
#     #     Generate chat response (uses Gemini by default).

#     #     Args:
#     #         query: User query
#     #         context: Retrieved context (from FAISS)
#     #         history: Conversation history

#     #     Returns:
#     #         str: Chat response
#     #     """
#     #     # Build system prompt
#     #     system_prompt = settings.SYSTEM_PROMPT
#     #     if context:
#     #         system_prompt += f"\n\nRelevant Information:\n{context}"

#     #     # Build messages
#     #     messages = history + [{'role': 'user', 'content': query}]

#     #     # Generate using chat LLM (Gemini)
#     #     return await self.generate(messages, system_prompt, task="chat")

#     async def generate_chat_response(
#         self,
#         query: str,
#         context: str,
#         history: List[Dict[str, str]]
#     ) -> str:
#         """Generate chat response (uses Gemini by default)."""

#         # Import the detailed prompt
#         from app.services.chat_service import BANKING_SYSTEM_PROMPT

#         # Build enhanced system prompt with context
#         system_prompt = BANKING_SYSTEM_PROMPT

#         if context:
#             system_prompt += f"\n\nRelevant Knowledge Base Context:\n{context}"
#         else:
#             system_prompt += "\n\nNo specific banking documents were retrieved for this query. Provide a helpful general response while acknowledging your banking specialization."

#         # Build messages
#         messages = history + [{'role': 'user', 'content': query}]

#         # Generate using chat LLM (Gemini)
#         return await self.generate(messages, system_prompt, task="chat")

#     async def evaluate_response(
#         self,
#         query: str,
#         response: str,
#         context: str = ""
#     ) -> Dict:
#         """
#         Evaluate response quality (uses Groq for speed).
#         Used during RL training.

#         Args:
#             query: User query
#             response: Generated response
#             context: Retrieved context (if any)

#         Returns:
#             dict: Evaluation results
#                 {'quality': 'Good'/'Bad', 'explanation': '...'}
#         """
#         eval_prompt = f"""Evaluate this response:
# Query: {query}
# Response: {response}
# Context used: {context if context else 'None'}

# Is this response Good or Bad? Respond with just "Good" or "Bad" and brief explanation."""

#         messages = [{'role': 'user', 'content': eval_prompt}]

#         # Generate using evaluation LLM (Groq)
#         result = await self.generate(messages, task="evaluation")

#         # Parse result
#         quality = "Good" if "Good" in result else "Bad"

#         return {
#             'quality': quality,
#             'explanation': result
#         }


# # ============================================================================
# # GLOBAL LLM MANAGER INSTANCE
# # ============================================================================
# llm_manager = LLMManager()


# # ============================================================================
# # USAGE EXAMPLE (for reference)
# # ============================================================================
# """
# # In your service file:

# from app.core.llm_manager import llm_manager

# # Generate chat response (uses Gemini)
# response = await llm_manager.generate_chat_response(
#     query="What is my account balance?",
#     context="Your balance is $1000",
#     history=[]
# )

# # Evaluate response (uses Groq)
# evaluation = await llm_manager.evaluate_response(
#     query="What is my balance?",
#     response="Your balance is $1000",
#     context="Balance: $1000"
# )
# """
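Note: this backup references Settings members that the backup config shown earlier does not define (GEMINI_REQUESTS_PER_MINUTE, GROQ_REQUESTS_PER_MINUTE, get_llm_for_task). A minimal sketch of what they would need to look like, inferred only from how they are used here and from the limits quoted in the docstrings (60 req/min for Gemini, 30 req/min for Groq), not confirmed against the deployed app/config.py:

# Assumed additions to the Settings class (names and defaults inferred):
GEMINI_REQUESTS_PER_MINUTE: int = int(os.getenv("GEMINI_REQUESTS_PER_MINUTE", "60"))
GROQ_REQUESTS_PER_MINUTE: int = int(os.getenv("GROQ_REQUESTS_PER_MINUTE", "30"))

def get_llm_for_task(self, task: str) -> str:
    """Assumed routing: 'chat' goes to Gemini, anything else to Groq."""
    return "gemini" if task == "chat" else "groq"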
backups/backup_main.py
ADDED
@@ -0,0 +1,275 @@
"""
FastAPI Main Application Entry Point
Banking RAG Chatbot API with JWT Authentication

This file:
1. Creates the FastAPI app
2. Configures CORS middleware
3. Connects to MongoDB on startup/shutdown
4. Includes API routers (auth + chat)
5. Provides health check endpoints
"""

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from contextlib import asynccontextmanager

from app.config import settings
from app.db.mongodb import connect_to_mongo, close_mongo_connection


# ============================================================================
# LIFESPAN MANAGER (Startup & Shutdown)
# ============================================================================

@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Manage application lifespan events.

    Startup:
    - Connect to MongoDB Atlas
    - ML models load lazily on first use

    Shutdown:
    - Close MongoDB connection
    - Cleanup resources
    """
    # ========================================================================
    # STARTUP
    # ========================================================================
    print("\n" + "=" * 80)
    print("🚀 STARTING BANKING RAG CHATBOT API")
    print("=" * 80)
    print(f"Environment: {settings.ENVIRONMENT}")
    print(f"Debug Mode: {settings.DEBUG}")
    print("=" * 80)

    # Connect to MongoDB
    await connect_to_mongo()

    print("\n💡 ML Models Info:")
    print("   Policy Network: Loads on first chat request (lazy loading)")
    print("   Retriever Model: Loads on first retrieval (lazy loading)")
    print("   LLM (Gemini): Connects on first generation")

    print("\n✅ Backend startup complete!")
    print("=" * 80)
    print(f"📖 API Docs: http://localhost:8000/docs")
    print(f"🏥 Health Check: http://localhost:8000/health")
    print(f"🔐 Register: POST http://localhost:8000/api/v1/auth/register")
    print(f"🔑 Login: POST http://localhost:8000/api/v1/auth/login")
    print("=" * 80 + "\n")

    yield  # Application runs here

    # ========================================================================
    # SHUTDOWN
    # ========================================================================
    print("\n" + "=" * 80)
    print("🛑 SHUTTING DOWN API")
    print("=" * 80)

    # Close MongoDB connection
    await close_mongo_connection()

    print("✅ Shutdown complete")
    print("=" * 80 + "\n")


# ============================================================================
# CREATE FASTAPI APPLICATION
# ============================================================================

app = FastAPI(
    title="Banking RAG Chatbot API",
    description="""
🤖 AI-powered Banking Assistant with:

**Features:**
- 🔐 JWT Authentication (Sign up, Login, Protected routes)
- 💬 RAG (Retrieval-Augmented Generation)
- 🧠 RL-based Policy Network (BERT)
- 🔍 Custom E5 Retriever
- ✨ Google Gemini LLM

**Capabilities:**
- Intelligent document retrieval
- Context-aware responses
- Conversation history
- Real-time chat
- User authentication & authorization
    """,
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
    lifespan=lifespan
)


# ============================================================================
# CORS MIDDLEWARE
# ============================================================================

allowed_origins = settings.get_allowed_origins()

print("\n🌐 CORS Configuration:")
print(f"   Allowed Origins: {allowed_origins}")

app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ============================================================================
# INCLUDE API ROUTERS
# ============================================================================

from app.api.v1 import chat, auth

# Auth router (public endpoints - register, login)
app.include_router(
    auth.router,
    prefix="/api/v1/auth",
    tags=["🔐 Authentication"]
)

# Chat router (protected endpoints - requires JWT token)
app.include_router(
    chat.router,
    prefix="/api/v1/chat",
    tags=["💬 Chat"]
)


# ============================================================================
# ROOT ENDPOINTS
# ============================================================================

@app.get("/", tags=["📍 Root"])
async def root():
    """
    Root endpoint - API information and available endpoints
    """
    return {
        "message": "Banking RAG Chatbot API with Authentication",
        "version": "1.0.0",
        "status": "online",
        "authentication": "JWT Bearer Token Required for chat endpoints",
        "documentation": {
            "swagger_ui": "/docs",
            "redoc": "/redoc"
        },
        "endpoints": {
            "auth": {
                "register": "POST /api/v1/auth/register",
                "login": "POST /api/v1/auth/login",
                "me": "GET /api/v1/auth/me (requires token)",
                "logout": "POST /api/v1/auth/logout (requires token)"
            },
            "chat": {
                "send_message": "POST /api/v1/chat/ (requires token)",
                "get_history": "GET /api/v1/chat/history/{conversation_id} (requires token)",
                "list_conversations": "GET /api/v1/chat/conversations (requires token)",
                "delete_conversation": "DELETE /api/v1/chat/conversation/{conversation_id} (requires token)"
            },
            "health": "GET /health"
        }
    }


@app.get("/health", tags=["🏥 Health"])
async def health_check():
    """
    Comprehensive health check endpoint

    Checks status of:
    - API service
    - MongoDB connection
    - ML models (lazy loaded)
    - Authentication system

    Returns:
        dict: Health status of all components
    """
    from app.db.mongodb import get_database

    # Check MongoDB
    mongodb_status = "connected" if get_database() is not None else "disconnected"

    # Check ML models (don't load them, just check readiness)
    ml_models_status = {
        "policy_network": "ready (lazy load)",
        "retriever": "ready (lazy load)",
        "llm": "ready (API-based)"
    }

    # Check authentication
    auth_status = {
        "jwt_enabled": bool(settings.SECRET_KEY and settings.SECRET_KEY != "your-secret-key-change-in-production"),
        "algorithm": settings.ALGORITHM,
        "token_expiry_minutes": settings.ACCESS_TOKEN_EXPIRE_MINUTES
    }

    # Overall health
    is_healthy = mongodb_status == "connected" and auth_status["jwt_enabled"]

    return {
        "status": "healthy" if is_healthy else "degraded",
        "api": "online",
        "mongodb": mongodb_status,
        "authentication": auth_status,
        "ml_models": ml_models_status,
        "environment": settings.ENVIRONMENT,
        "debug_mode": settings.DEBUG
    }


# ============================================================================
# GLOBAL EXCEPTION HANDLER
# ============================================================================

@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """
    Global exception handler for unhandled errors
    """
    print(f"\n❌ Unhandled Exception:")
    print(f"   Path: {request.url.path}")
    print(f"   Error: {str(exc)}")

    if settings.DEBUG:
        import traceback
        traceback.print_exc()

    return JSONResponse(
        status_code=500,
        content={
            "error": "Internal Server Error",
            "detail": str(exc) if settings.DEBUG else "An unexpected error occurred",
            "path": str(request.url.path)
        }
    )


# ============================================================================
# MAIN ENTRY POINT (for direct execution)
# ============================================================================

if __name__ == "__main__":
    import uvicorn

    print("\n🚀 Starting server directly...")
    print("   Note: For production, use: uvicorn app.main:app --host 0.0.0.0 --port 8000")

    uvicorn.run(
        "app.main:app",
        host="0.0.0.0",
        port=8000,
        reload=settings.DEBUG  # Auto-reload only in debug mode
    )
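For reference, a minimal client-side sketch of the auth flow this file wires up. The endpoint paths come from the routes above; the JSON field names (email, password, access_token) are assumptions, since the request/response schemas live in app/api/v1/auth.py and are not shown in this commit view:

import requests

BASE = "http://localhost:8000/api/v1"
creds = {"email": "user@example.com", "password": "s3cret-pass"}  # hypothetical schema

requests.post(f"{BASE}/auth/register", json=creds)  # create the account
login = requests.post(f"{BASE}/auth/login", json=creds).json()
token = login["access_token"]  # assumed response key

# Protected endpoints expect the JWT as a bearer token
me = requests.get(f"{BASE}/auth/me", headers={"Authorization": f"Bearer {token}"})
print(me.json())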
backups/backup_requirements.txt
ADDED
@@ -0,0 +1,182 @@
# # ================================================================================
# # BANKING RAG CHATBOT API - DEPENDENCIES
# # Python 3.10+ required
# # ================================================================================

# # ============================================================================
# # CORE WEB FRAMEWORK
# # ============================================================================
# # FastAPI - Modern async web framework
# fastapi==0.104.1

# # Uvicorn - ASGI server for FastAPI
# uvicorn[standard]==0.24.0

# # Python multipart for file uploads (if needed later)
# python-multipart==0.0.6

# # ============================================================================
# # CONFIGURATION & ENVIRONMENT
# # ============================================================================
# # Pydantic - Data validation and settings management
# pydantic==2.5.0
# pydantic-settings==2.1.0

# # Python-dotenv - Load environment variables from .env file
# python-dotenv==1.0.0

# # ============================================================================
# # DATABASE - MongoDB
# # ============================================================================
# # Motor - Async MongoDB driver for FastAPI
# motor==3.3.2

# # PyMongo - MongoDB Python driver (used by Motor)
# pymongo==4.6.0

# # ============================================================================
# # AUTHENTICATION & SECURITY
# # ============================================================================
# # Python-jose - JWT token handling
# python-jose[cryptography]==3.3.0

# # Passlib - Password hashing
# passlib[bcrypt]==1.7.4

# # ============================================================================
# # MACHINE LEARNING - PYTORCH & TRANSFORMERS
# # ============================================================================
# # PyTorch - Deep learning framework
# torch==2.1.0

# # Transformers - HuggingFace transformers library (BERT, e5-base-v2)
# transformers==4.35.0

# # Sentence-Transformers - Sentence embeddings
# sentence-transformers==2.2.2

# # ============================================================================
# # VECTOR SEARCH
# # ============================================================================
# # FAISS - Facebook AI Similarity Search (CPU version)
# faiss-cpu==1.7.4

# # ============================================================================
# # LLM INTEGRATIONS
# # ============================================================================
# # LangChain - LLM orchestration framework
# langchain==0.1.0

# # LangChain Groq integration
# langchain-groq==0.0.1

# # LangChain Google GenAI (for Gemini)
# langchain-google-genai==1.0.0

# # Google Generative AI - Direct Gemini API
# google-generativeai==0.3.2

# # ============================================================================
# # UTILITIES
# # ============================================================================
# # NumPy - Numerical computing
# numpy==1.24.3

# # Tiktoken - OpenAI tokenizer (for token counting)
# tiktoken==0.5.1

# # Rich - Beautiful terminal output (for logging)
# rich==13.7.0

# # Requests - HTTP library
# requests==2.31.0

# # ============================================================================
# # OPTIONAL: DEVELOPMENT TOOLS (comment out for production)
# # ============================================================================
# # Pytest - Testing framework
# # pytest==7.4.3

# # Black - Code formatter
# # black==23.12.0

# # Flake8 - Linter
# # flake8==6.1.0


# fastapi==0.104.1
# uvicorn[standard]==0.24.0
# pydantic==2.5.0
# pydantic-settings==2.1.0
# python-dotenv==1.0.0
# motor==3.3.2
# pymongo==4.6.0
# google-generativeai==0.3.1
# sentence-transformers==2.2.2
# faiss-cpu==1.7.4
# numpy==1.24.3
# torch==2.1.0
# transformers==4.35.2

# # AUTH DEPENDENCIES (NEW!)
# python-jose[cryptography]==3.3.0
# passlib[bcrypt]==1.7.4
# python-multipart==0.0.6
# bcrypt==4.1.1


# FastAPI & Server
fastapi==0.104.1
uvicorn[standard]==0.24.0

# Data Validation
pydantic==2.5.0
pydantic-settings==2.1.0
python-dotenv==1.0.0

# Database
motor==3.3.2
pymongo==4.6.0

# LLM & AI Libraries
langchain-groq==0.1.0
langchain-core==0.1.0
huggingface-hub==0.20.0

# Embeddings & Vector Search
sentence-transformers==2.2.2
faiss-cpu==1.7.4
numpy==1.24.3

# ML/Deep Learning
torch==2.1.0
transformers==4.35.2

# Authentication
python-jose[cryptography]==3.3.0
passlib[bcrypt]==1.7.4
python-multipart==0.0.6
bcrypt==4.1.1
build_faiss_index.py
ADDED
@@ -0,0 +1,339 @@
"""
Build FAISS Index from Scratch
Creates faiss_index.pkl from your knowledge base and trained retriever model

Run this ONCE before starting the backend:
    python build_faiss_index.py

Author: Banking RAG Chatbot
Date: October 2025
"""

# These lines go at the very top (after the docstring)
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Suppress TensorFlow info/warnings
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'  # Disable oneDNN messages

import warnings
warnings.filterwarnings('ignore')  # Suppress all warnings

import pickle
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import faiss
import numpy as np
from pathlib import Path
from transformers import AutoTokenizer, AutoModel
from typing import List


# ============================================================================
# CONFIGURATION - UPDATE THESE PATHS!
# ============================================================================

# Where is your knowledge base JSONL file?
KB_JSONL_FILE = "data/final_knowledge_base.jsonl"

# Where is your trained retriever model?
RETRIEVER_MODEL_PATH = "app/models/best_retriever_model.pth"

# Where to save the output FAISS pickle?
OUTPUT_PKL_FILE = "app/models/faiss_index.pkl"

# Device (auto-detect GPU/CPU)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Batch size for encoding (reduce if you get OOM errors)
BATCH_SIZE = 32


# ============================================================================
# CUSTOM SENTENCE TRANSFORMER (Same as retriever.py)
# ============================================================================

class CustomSentenceTransformer(nn.Module):
    """
    Custom SentenceTransformer - exact copy from retriever.py
    Uses e5-base-v2 with mean pooling and L2 normalization
    """

    def __init__(self, model_name: str = "intfloat/e5-base-v2"):
        super().__init__()
        print(f"   Loading base model: {model_name}...")
        self.encoder = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.config = self.encoder.config
        print(f"   ✅ Base model loaded")

    def forward(self, input_ids, attention_mask):
        """Forward pass through BERT encoder"""
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

        # Mean pooling over non-padding tokens
        token_embeddings = outputs.last_hidden_state
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

        # L2 normalize
        embeddings = F.normalize(embeddings, p=2, dim=1)
        return embeddings

    def encode(self, sentences: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode sentences - same as training"""
        self.eval()

        if isinstance(sentences, str):
            sentences = [sentences]

        # Add 'query: ' prefix for e5-base-v2
        processed_sentences = [f"query: {s.strip()}" for s in sentences]

        all_embeddings = []

        with torch.no_grad():
            for i in range(0, len(processed_sentences), batch_size):
                batch_sentences = processed_sentences[i:i + batch_size]

                # Tokenize
                tokens = self.tokenizer(
                    batch_sentences,
                    truncation=True,
                    padding=True,
                    max_length=128,
                    return_tensors='pt'
                ).to(self.encoder.device)

                # Get embeddings
                embeddings = self.forward(tokens['input_ids'], tokens['attention_mask'])
                all_embeddings.append(embeddings.cpu().numpy())

        return np.vstack(all_embeddings)
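# Note on the 'query: ' prefix in encode() above: the intfloat/e5 model cards
# recommend 'query: ' for search queries and 'passage: ' for documents. This
# script applies 'query: ' to documents as well; since the class is described
# as an exact copy of retriever.py, queries and documents are at least embedded
# consistently, which is what the cosine scores ultimately depend on.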
| 117 |
+
|
| 118 |
+
|
| 119 |
+
# ============================================================================
# RETRIEVER MODEL (Wrapper)
# ============================================================================

class RetrieverModel:
    """Wrapper for trained retriever model"""

    def __init__(self, model_path: str, device: str = "cpu"):
        print(f"\n🤖 Loading retriever model...")
        print(f"   Device: {device}")

        self.device = device
        self.model = CustomSentenceTransformer("intfloat/e5-base-v2").to(device)

        # Load trained weights
        print(f"   Loading weights from: {model_path}")
        try:
            state_dict = torch.load(model_path, map_location=device)
            self.model.load_state_dict(state_dict)
            print(f"   ✅ Trained weights loaded")
        except Exception as e:
            print(f"   ⚠️ Warning: Could not load trained weights: {e}")
            print(f"   Using base e5-base-v2 model instead")

        self.model.eval()

    def encode_documents(self, documents: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode documents"""
        return self.model.encode(documents, batch_size=batch_size)


# ============================================================================
# MAIN: BUILD FAISS INDEX
# ============================================================================

def build_faiss_index():
    """Main function to build FAISS index from scratch"""

    print("=" * 80)
    print("🏗️ BUILDING FAISS INDEX FROM SCRATCH")
    print("=" * 80)

    # ========================================================================
    # STEP 1: LOAD KNOWLEDGE BASE
    # ========================================================================
    print(f"\n📖 STEP 1: Loading knowledge base...")
    print(f"   File: {KB_JSONL_FILE}")

    if not os.path.exists(KB_JSONL_FILE):
        print(f"   ❌ ERROR: File not found!")
        print(f"   Please copy your knowledge base to: {KB_JSONL_FILE}")
        return False

    kb_data = []
    with open(KB_JSONL_FILE, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            try:
                kb_data.append(json.loads(line))
            except json.JSONDecodeError as e:
                print(f"   ⚠️ Warning: Skipping invalid JSON on line {line_num}: {e}")

    print(f"   ✅ Loaded {len(kb_data)} documents")

    if len(kb_data) == 0:
        print(f"   ❌ ERROR: Knowledge base is empty!")
        return False

    # ========================================================================
    # STEP 2: PREPARE DOCUMENTS FOR ENCODING
    # ========================================================================
    print(f"\n📝 STEP 2: Preparing documents for encoding...")

    documents = []
    for i, item in enumerate(kb_data):
        # Combine instruction + response for embedding (same as training)
        instruction = item.get('instruction', '')
        response = item.get('response', '')

        # Create combined text
        if instruction and response:
            text = f"{instruction} {response}"
        elif instruction:
            text = instruction
        elif response:
            text = response
        else:
            print(f"   ⚠️ Warning: Document {i} has no content, using placeholder")
            text = "empty document"

        documents.append(text)

    print(f"   ✅ Prepared {len(documents)} documents for encoding")
    print(f"   Average length: {sum(len(d) for d in documents) / len(documents):.1f} chars")

    # ========================================================================
    # STEP 3: LOAD RETRIEVER AND ENCODE DOCUMENTS
    # ========================================================================
    print(f"\n🔮 STEP 3: Encoding documents with trained retriever...")

    if not os.path.exists(RETRIEVER_MODEL_PATH):
        print(f"   ❌ ERROR: Retriever model not found!")
        print(f"   Please copy your trained model to: {RETRIEVER_MODEL_PATH}")
        return False

    # Load retriever
    retriever = RetrieverModel(RETRIEVER_MODEL_PATH, device=DEVICE)

    # Encode all documents
    print(f"   Encoding {len(documents)} documents...")
    print(f"   Batch size: {BATCH_SIZE}")
    print(f"   This may take a few minutes... ☕")

    try:
        embeddings = retriever.encode_documents(documents, batch_size=BATCH_SIZE)
        print(f"   ✅ Encoded {embeddings.shape[0]} documents")
        print(f"   Embedding dimension: {embeddings.shape[1]}")
    except Exception as e:
        print(f"   ❌ ERROR during encoding: {e}")
        import traceback
        traceback.print_exc()
        return False

    # ========================================================================
    # STEP 4: BUILD FAISS INDEX
    # ========================================================================
    print(f"\n🔍 STEP 4: Building FAISS index...")

    dimension = embeddings.shape[1]
    print(f"   Dimension: {dimension}")

    # Create FAISS index (Inner Product = Cosine similarity after normalization)
    index = faiss.IndexFlatIP(dimension)

    # Normalize embeddings for cosine similarity
    print(f"   Normalizing embeddings...")
    faiss.normalize_L2(embeddings)

    # Add to index
    print(f"   Adding {embeddings.shape[0]} vectors to FAISS index...")
    index.add(embeddings.astype('float32'))

    print(f"   ✅ FAISS index built successfully")
    print(f"   Total vectors: {index.ntotal}")

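    # NOTE: IndexFlatIP ranks by inner product; after faiss.normalize_L2 every
    # vector has unit length, so the inner product equals cosine similarity.
    # Tiny worked example: u = [1.0, 0.0] and v = [0.6, 0.8] are unit vectors,
    # and u . v = 1.0*0.6 + 0.0*0.8 = 0.6 = cos(angle between u and v).
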
    # ========================================================================
    # STEP 5: SAVE AS PICKLE FILE
    # ========================================================================
    print(f"\n💾 STEP 5: Saving as pickle file...")

    # Create models directory if it doesn't exist
    os.makedirs(os.path.dirname(OUTPUT_PKL_FILE), exist_ok=True)

    # Save tuple of (index, kb_data)
    print(f"   Pickling (index, kb_data) tuple...")
    try:
        with open(OUTPUT_PKL_FILE, 'wb') as f:
            pickle.dump((index, kb_data), f)

        file_size_mb = Path(OUTPUT_PKL_FILE).stat().st_size / (1024 * 1024)
        print(f"   ✅ Saved: {OUTPUT_PKL_FILE}")
        print(f"   File size: {file_size_mb:.2f} MB")
    except Exception as e:
        print(f"   ❌ ERROR saving pickle: {e}")
        return False

    # ========================================================================
    # STEP 6: VERIFY SAVED FILE
    # ========================================================================
    print(f"\n✅ STEP 6: Verifying saved file...")

    try:
        with open(OUTPUT_PKL_FILE, 'rb') as f:
            loaded_index, loaded_kb = pickle.load(f)

        print(f"   ✅ Verification successful")
        print(f"   Index vectors: {loaded_index.ntotal}")
        print(f"   KB documents: {len(loaded_kb)}")

        if loaded_index.ntotal != len(loaded_kb):
            print(f"   ⚠️ WARNING: Size mismatch detected!")

    except Exception as e:
        print(f"   ❌ ERROR verifying file: {e}")
        return False

    # ========================================================================
    # SUCCESS!
    # ========================================================================
    print("\n" + "=" * 80)
    print("🎉 SUCCESS! FAISS INDEX BUILT AND SAVED")
    print("=" * 80)
    print(f"\n📊 Summary:")
    print(f"   Documents: {len(kb_data)}")
    print(f"   Vectors: {index.ntotal}")
    print(f"   Dimension: {dimension}")
    print(f"   File: {OUTPUT_PKL_FILE} ({file_size_mb:.2f} MB)")
    print(f"\n🚀 You can now start the backend:")
    print(f"   cd backend")
    print(f"   uvicorn app.main:app --reload")
    print("=" * 80 + "\n")

    return True


# ============================================================================
# RUN SCRIPT
# ============================================================================

if __name__ == "__main__":
    success = build_faiss_index()

    if not success:
        print("\n" + "=" * 80)
        print("❌ FAILED TO BUILD FAISS INDEX")
        print("=" * 80)
        print("\nPlease check:")
        print("1. Knowledge base file exists: data/final_knowledge_base.jsonl")
        print("2. Retriever model exists: models/best_retriever_model.pth")
        print("3. You have enough RAM (loading the encoder plus the embeddings can take ~1 GB)")
        print("=" * 80 + "\n")
        exit(1)
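
A minimal query-time sketch of how a consumer (e.g. the backend's retriever service) can use the artifact this script produces: unpickle the (index, kb_data) tuple, encode the incoming question with the same retriever, and L2-normalize it so inner product equals cosine similarity. The "models/faiss_index.pkl" path is an assumption standing in for OUTPUT_PKL_FILE, and search() is an illustrative helper, not a function from the repo:

import pickle
import faiss

# Load the (index, kb_data) tuple written by build_faiss_index()
with open("models/faiss_index.pkl", "rb") as f:  # assumed OUTPUT_PKL_FILE path
    index, kb_data = pickle.load(f)

# Reuse the same retriever so query vectors live in the same space as the index
retriever = RetrieverModel("models/best_retriever_model.pth", device="cpu")

def search(question: str, top_k: int = 3):
    # encode() adds the same "query: " prefix used at index-build time
    query_vec = retriever.model.encode([question]).astype("float32")
    faiss.normalize_L2(query_vec)  # unit-normalize so IP == cosine, matching the index
    scores, ids = index.search(query_vec, top_k)
    return [(float(s), kb_data[i]) for s, i in zip(scores[0], ids[0]) if i != -1]

print(search("How do I block a lost debit card?"))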
folder_structure.txt
ADDED
@@ -0,0 +1,40 @@
AML-IA-IMP/
├── backend/
│   ├── app/
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── main.py
│   │   ├── api/
│   │   │   ├── __init__.py
│   │   │   └── v1/
│   │   │       ├── __init__.py
│   │   │       ├── chat.py
│   │   │       ├── auth.py
│   │   │       └── admin.py
│   │   ├── services/
│   │   │   ├── __init__.py
│   │   │   └── chat_service.py
│   │   ├── ml/
│   │   │   ├── __init__.py
│   │   │   ├── retriever.py
│   │   │   └── policy_network.py
│   │   ├── db/
│   │   │   ├── __init__.py
│   │   │   ├── mongodb.py
│   │   │   └── repositories/
│   │   │       ├── __init__.py
│   │   │       └── conversation_repository.py
│   │   └── core/
│   │       ├── __init__.py
│   │       ├── security.py
│   │       └── api_key_manager.py
│   ├── models/
│   │   └── (your .pth, .pt, .pkl files go here)
│   ├── data/
│   │   └── final_knowledge_base.jsonl
│   ├── .env
│   ├── .env.example
│   ├── requirements.txt
│   ├── Dockerfile
│   ├── .gitignore
│   └── README.md
requirements.txt
ADDED
@@ -0,0 +1,38 @@
# FastAPI & Server
fastapi==0.104.1
uvicorn[standard]==0.24.0


# Data Validation
pydantic==2.5.0
pydantic-settings==2.1.0
python-dotenv==1.0.0


# Database
motor==3.3.2
pymongo==4.6.0


# LLM & AI Libraries
langchain-groq==0.1.9
langchain-core==0.2.38
huggingface-hub==0.24.6


# Embeddings & Vector Search
sentence-transformers==2.2.2
faiss-cpu==1.7.4
numpy==1.24.3


# ML/Deep Learning
torch==2.1.0
transformers==4.35.2


# Authentication
python-jose[cryptography]==3.3.0
passlib[bcrypt]==1.7.4
python-multipart==0.0.6
bcrypt==4.1.1
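
A quick import smoke test (a sketch; run after installing the pinned requirements) to confirm the ML stack resolves together before building the index:

import faiss
import numpy
import torch
import transformers

print("torch", torch.__version__)
print("transformers", transformers.__version__)
print("numpy", numpy.__version__)
print("faiss IndexFlatIP dim:", faiss.IndexFlatIP(768).d)  # 768 = e5-base-v2 embedding size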