Spaces:
Sleeping
Sleeping
Update counselor.py
Browse files- counselor.py +965 -708
counselor.py
CHANGED
|
@@ -1,18 +1,14 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
-
UltraAdvancedHybridCounselor -
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
-
|
| 8 |
-
-
|
| 9 |
-
-
|
| 10 |
-
- Machine-readable XML is produced only when explicitly requested.
|
| 11 |
-
|
| 12 |
-
Caveats:
|
| 13 |
-
- This keeps the fine-tune persistence and background worker.
|
| 14 |
-
- Behavior is conservative: when web context or "with sources" is requested we prefer detailed markdown.
|
| 15 |
"""
|
|
|
|
| 16 |
import asyncio
|
| 17 |
import hashlib
|
| 18 |
import inspect
|
|
@@ -41,12 +37,14 @@ try:
|
|
| 41 |
import joblib
|
| 42 |
except Exception:
|
| 43 |
joblib = None
|
|
|
|
| 44 |
try:
|
| 45 |
import torch
|
| 46 |
from torch.utils.data import TensorDataset, DataLoader, RandomSampler
|
| 47 |
from torch.optim import AdamW
|
| 48 |
except Exception:
|
| 49 |
torch = None
|
|
|
|
| 50 |
try:
|
| 51 |
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
|
| 52 |
except Exception:
|
|
@@ -64,24 +62,28 @@ except Exception:
|
|
| 64 |
LLMChain = None
|
| 65 |
|
| 66 |
try:
|
| 67 |
-
from langchain_community.retrievers import TavilySearchAPIRetriever
|
| 68 |
_TAVILY_CLASS = TavilySearchAPIRetriever
|
| 69 |
except Exception:
|
| 70 |
_TAVILY_CLASS = None
|
|
|
|
| 71 |
try:
|
| 72 |
from rag import RAGComponent
|
| 73 |
except Exception:
|
| 74 |
RAGComponent = None
|
|
|
|
| 75 |
try:
|
| 76 |
from db import SessionDB
|
| 77 |
except Exception:
|
| 78 |
SessionDB = None
|
|
|
|
| 79 |
try:
|
| 80 |
from cache import RedisCache
|
| 81 |
except Exception:
|
| 82 |
RedisCache = None
|
|
|
|
| 83 |
try:
|
| 84 |
-
from tavily import UsageLimitExceededError
|
| 85 |
except Exception:
|
| 86 |
class UsageLimitExceededError(Exception):
|
| 87 |
pass
|
|
@@ -95,35 +97,281 @@ logging.basicConfig(
|
|
| 95 |
format='%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s',
|
| 96 |
handlers=[logging.FileHandler('logs/counselor.log', encoding='utf-8'), logging.StreamHandler()]
|
| 97 |
)
|
|
|
|
| 98 |
logger = logging.getLogger(__name__)
|
| 99 |
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
class SimpleDoc:
|
| 102 |
def __init__(self, source: str, content: str, title: str = "", score: float = None):
|
| 103 |
self.metadata = {"source": source, "title": title, "score": score}
|
| 104 |
self.page_content = content
|
| 105 |
|
|
|
|
| 106 |
def create_tavily_retriever_safe(k: int = 10, logger: logging.Logger = logger, **kwargs):
|
| 107 |
global _TAVILY_CLASS
|
| 108 |
if _TAVILY_CLASS is None:
|
| 109 |
try:
|
| 110 |
-
from langchain_community.retrievers import TavilySearchAPIRetriever
|
| 111 |
_TAVILY_CLASS = TavilySearchAPIRetriever
|
| 112 |
except Exception as e:
|
| 113 |
logger.error(f"TavilySearchAPIRetriever not importable: {e}")
|
| 114 |
raise ImportError("TavilySearchAPIRetriever unavailable") from e
|
| 115 |
-
|
| 116 |
cls = _TAVILY_CLASS
|
| 117 |
try:
|
| 118 |
sig = inspect.signature(cls.__init__)
|
| 119 |
except Exception:
|
| 120 |
sig = None
|
| 121 |
-
|
| 122 |
allowed = {}
|
| 123 |
for name, val in {"k": k, **kwargs}.items():
|
| 124 |
if sig is None or (name in sig.parameters and name != "self"):
|
| 125 |
allowed[name] = val
|
| 126 |
-
|
| 127 |
try:
|
| 128 |
return cls(**allowed)
|
| 129 |
except TypeError as te:
|
|
@@ -134,14 +382,15 @@ def create_tavily_retriever_safe(k: int = 10, logger: logging.Logger = logger, *
|
|
| 134 |
logger.error(f"Tavily no-arg constructor failed: {e}")
|
| 135 |
raise
|
| 136 |
|
|
|
|
| 137 |
async def tavily_search_safe(retriever, query: str, logger: logging.Logger = logger, *args, **kwargs) -> List[Any]:
|
| 138 |
if retriever is None:
|
| 139 |
logger.debug("tavily_search_safe: retriever is None")
|
| 140 |
return []
|
| 141 |
-
|
| 142 |
async_methods = ["ainvoke", "aget_relevant_documents", "aretrieve", "asearch"]
|
| 143 |
sync_methods = ["invoke", "get_relevant_documents", "retrieve", "search"]
|
| 144 |
-
|
| 145 |
for name in async_methods:
|
| 146 |
fn = getattr(retriever, name, None)
|
| 147 |
if callable(fn):
|
|
@@ -154,7 +403,7 @@ async def tavily_search_safe(retriever, query: str, logger: logging.Logger = log
|
|
| 154 |
continue
|
| 155 |
except Exception:
|
| 156 |
continue
|
| 157 |
-
|
| 158 |
loop = asyncio.get_event_loop()
|
| 159 |
for name in sync_methods:
|
| 160 |
fn = getattr(retriever, name, None)
|
|
@@ -163,24 +412,26 @@ async def tavily_search_safe(retriever, query: str, logger: logging.Logger = log
|
|
| 163 |
return await loop.run_in_executor(None, lambda: fn(query))
|
| 164 |
except Exception:
|
| 165 |
continue
|
| 166 |
-
|
| 167 |
logger.warning("tavily_search_safe: no usable methods on retriever")
|
| 168 |
return []
|
| 169 |
|
|
|
|
| 170 |
async def tavily_rest_search(api_key: str, query: str, timeout: int = 15, logger: logging.Logger = logger) -> List[Dict[str, Any]]:
|
| 171 |
if requests is None:
|
| 172 |
logger.error("requests library not available; cannot use REST fallback for Tavily.")
|
| 173 |
return []
|
|
|
|
| 174 |
url = "https://api.tavily.com/search"
|
| 175 |
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
|
| 176 |
payload = {"query": query}
|
| 177 |
loop = asyncio.get_event_loop()
|
| 178 |
-
|
| 179 |
def do_post():
|
| 180 |
r = requests.post(url, json=payload, headers=headers, timeout=timeout)
|
| 181 |
r.raise_for_status()
|
| 182 |
return r.json()
|
| 183 |
-
|
| 184 |
try:
|
| 185 |
resp = await loop.run_in_executor(None, do_post)
|
| 186 |
results = resp.get("results", [])
|
|
@@ -190,115 +441,51 @@ async def tavily_rest_search(api_key: str, query: str, timeout: int = 15, logger
|
|
| 190 |
logger.exception("tavily_rest_search failed")
|
| 191 |
return []
|
| 192 |
|
|
|
|
| 193 |
def format_sources_block(docs: List[SimpleDoc]) -> str:
|
| 194 |
if not docs:
|
| 195 |
return ""
|
| 196 |
-
lines = ["\n---\n
|
| 197 |
-
for d in docs:
|
| 198 |
meta = getattr(d, "metadata", {}) or {}
|
| 199 |
url = meta.get("source") or ""
|
| 200 |
-
title = meta.get("title") or
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
if title:
|
| 211 |
-
lines.append(f" **Title:** {title}")
|
| 212 |
-
if score_display:
|
| 213 |
-
lines.append(f" **Relevance:** {score_display}")
|
| 214 |
-
lines.append("")
|
| 215 |
-
return "\n".join(lines).strip()
|
| 216 |
-
|
| 217 |
-
def _parse_llm_tagged_output(text: str) -> Dict[str, str]:
|
| 218 |
-
data = {}
|
| 219 |
-
tags = {
|
| 220 |
-
r"<SUMMARY>(.*?)</SUMMARY>": "summary",
|
| 221 |
-
r"<COMPREHENSIVE_EXPLANATION>(.*?)</COMPREHENSIVE_EXPLANATION>": "explanation",
|
| 222 |
-
r"<RELEVANT_INSIGHTS>(.*?)</RELEVANT_INSIGHTS>": "insights",
|
| 223 |
-
}
|
| 224 |
-
for pattern, key in tags.items():
|
| 225 |
-
match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
|
| 226 |
-
if match:
|
| 227 |
-
content = match.group(1).strip()
|
| 228 |
-
if content:
|
| 229 |
-
data[key] = content
|
| 230 |
-
if "explanation" not in data or not data["explanation"]:
|
| 231 |
-
clean_text = re.sub(r"<[^>]+>", "", text).strip()
|
| 232 |
-
if "summary" in data and data["summary"] in clean_text:
|
| 233 |
-
clean_text = clean_text.replace(data["summary"], "").strip()
|
| 234 |
-
data["explanation"] = clean_text or "No detailed answer was generated."
|
| 235 |
-
return data
|
| 236 |
-
|
| 237 |
-
# --- New: parsing helpers to support multiple output formats ---
|
| 238 |
-
def _parse_markdown_structured(text: str) -> Dict[str, str]:
|
| 239 |
-
# Look for headings like "Summary", "Detailed Explanation", "Relevant Insights"
|
| 240 |
-
data = {"summary": "", "explanation": "", "insights": ""}
|
| 241 |
-
t = text.replace("\r\n", "\n")
|
| 242 |
-
summary_match = re.search(r"(^|\n)#+\s*Summary\s*\n(.*?)(\n#|\n\Z)", t, re.IGNORECASE | re.DOTALL)
|
| 243 |
-
if summary_match:
|
| 244 |
-
data["summary"] = summary_match.group(2).strip()
|
| 245 |
-
else:
|
| 246 |
-
m = re.search(r"(^|\n)Summary\s*[:\-]\s*(.*?)(\n\n|\n#|\n\Z)", t, re.IGNORECASE | re.DOTALL)
|
| 247 |
-
if m:
|
| 248 |
-
data["summary"] = m.group(2).strip()
|
| 249 |
-
|
| 250 |
-
expl_match = re.search(r"(^|\n)#+\s*(Detailed Explanation|Explanation|Answer)\s*\n(.*?)(\n#|\n\Z)", t, re.IGNORECASE | re.DOTALL)
|
| 251 |
-
if expl_match:
|
| 252 |
-
data["explanation"] = expl_match.group(2).strip()
|
| 253 |
-
else:
|
| 254 |
-
if data["summary"]:
|
| 255 |
-
parts = t.split(data["summary"], 1)
|
| 256 |
-
if len(parts) > 1:
|
| 257 |
-
data["explanation"] = parts[1].strip()
|
| 258 |
-
else:
|
| 259 |
-
data["explanation"] = t.strip()
|
| 260 |
-
|
| 261 |
-
insights_match = re.search(r"(^|\n)#+\s*Relevant Insights\s*\n(.*?)(\n#|\n\Z)", t, re.IGNORECASE | re.DOTALL)
|
| 262 |
-
if insights_match:
|
| 263 |
-
data["insights"] = insights_match.group(2).strip()
|
| 264 |
-
else:
|
| 265 |
-
m2 = re.search(r"(^|\n)Relevant Insights\s*[:\-]\s*(.*?)(\n#|\n\Z)", t, re.IGNORECASE | re.DOTALL)
|
| 266 |
-
if m2:
|
| 267 |
-
data["insights"] = m2.group(2).strip()
|
| 268 |
-
for k in data:
|
| 269 |
-
if not data[k]:
|
| 270 |
-
data[k] = ""
|
| 271 |
-
return data
|
| 272 |
-
|
| 273 |
-
def _parse_plain_text(text: str) -> Dict[str, str]:
|
| 274 |
-
s = text.strip()
|
| 275 |
-
if not s:
|
| 276 |
-
return {"summary": "", "explanation": "", "insights": ""}
|
| 277 |
-
sentences = re.split(r'(?<=[.!?])\s+', s)
|
| 278 |
-
summary = " ".join(sentences[:2]).strip()
|
| 279 |
-
explanation = " ".join(sentences[2:]).strip() if len(sentences) > 2 else ""
|
| 280 |
-
if not explanation:
|
| 281 |
-
explanation = summary
|
| 282 |
-
return {"summary": summary, "explanation": explanation, "insights": ""}
|
| 283 |
-
|
| 284 |
-
# --- Keyword Lists ---
|
| 285 |
_COUNTRY_KEYWORDS = {
|
| 286 |
-
"india", "usa", "united states", "canada", "uk", "united kingdom", "germany",
|
| 287 |
-
"
|
|
|
|
| 288 |
}
|
|
|
|
| 289 |
_LANGUAGE_KEYWORDS = {
|
| 290 |
-
"english", "german", "french", "spanish", "mandarin", "chinese", "japanese",
|
|
|
|
| 291 |
}
|
|
|
|
| 292 |
_ILLEGAL_TRIGGERS = [
|
| 293 |
-
r"how to make a bomb", r"detonate", r"how to assassinat", r"kill someone",
|
| 294 |
-
r"
|
| 295 |
-
r"
|
|
|
|
| 296 |
]
|
| 297 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
class UltraAdvancedHybridCounselor:
|
| 299 |
def __init__(self):
|
| 300 |
logger.info(f"🐍 Python version: {sys.version}")
|
| 301 |
-
|
| 302 |
# --- Paths and Model State ---
|
| 303 |
self.model_path = "Sachin21112004/carrerflow-ai"
|
| 304 |
self.label_encoder_path = "Sachin21112004/carrerflow-ai/label_encoder.pkl"
|
|
@@ -316,30 +503,67 @@ class UltraAdvancedHybridCounselor:
|
|
| 316 |
"gemini-1.5-pro-002", "gemini-2.5-flash-lite-preview", "gemini-1.5-flash-8b-latest",
|
| 317 |
"gemini-1.5-flash-latest", "gemini-1.5-pro-latest", "gemini-1.0-pro", "gemini-pro"
|
| 318 |
]
|
| 319 |
-
|
| 320 |
# --- Persistent user-adaptation files ---
|
| 321 |
self.user_corpus_path = Path("user_corpus.txt")
|
| 322 |
self.user_keywords_path = Path("user_keywords.json")
|
| 323 |
self.user_greetings_path = Path("user_greetings.json")
|
| 324 |
-
|
| 325 |
# Fine-tuning dataset and config
|
| 326 |
self.finetune_examples_path = Path("fine_tune_examples.jsonl")
|
| 327 |
self.finetune_label_map_path = Path("fine_tune_label_map.json")
|
| 328 |
-
|
| 329 |
-
# ---
|
| 330 |
self.dataset_repo_id = os.getenv("HF_DATASET_REPO_ID", "Sachin21112004/DreamFlow-AI-Data")
|
| 331 |
self.examples_filename_in_repo = "fine_tune_examples.jsonl"
|
| 332 |
-
self.local_examples_path = Path(f"./{self.examples_filename_in_repo}")
|
| 333 |
-
|
| 334 |
self.fine_tune_interval = int(os.getenv("FINE_TUNE_INTERVAL_SECS", "300"))
|
| 335 |
self.min_examples_to_train = int(os.getenv("MIN_EXAMPLES_TO_TRAIN", "32"))
|
| 336 |
self.fine_tune_batch_size = int(os.getenv("FINE_TUNE_BATCH", "8"))
|
| 337 |
self.fine_tune_epochs = int(os.getenv("FINE_TUNE_EPOCHS", "1"))
|
| 338 |
-
|
| 339 |
# Default greetings
|
| 340 |
-
self._default_greetings = {
|
| 341 |
-
|
|
|
|
|
|
|
|
|
|
| 342 |
# Load persisted greetings and user keywords
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 343 |
try:
|
| 344 |
if self.user_greetings_path.exists():
|
| 345 |
with open(self.user_greetings_path, "r", encoding="utf-8") as f:
|
|
@@ -347,7 +571,7 @@ class UltraAdvancedHybridCounselor:
|
|
| 347 |
self.greetings = set(stored.get("greetings", [])) | self._default_greetings
|
| 348 |
else:
|
| 349 |
self.greetings = set(self._default_greetings)
|
| 350 |
-
|
| 351 |
if self.user_keywords_path.exists():
|
| 352 |
with open(self.user_keywords_path, "r", encoding="utf-8") as f:
|
| 353 |
self.user_keywords = json.load(f)
|
|
@@ -358,7 +582,8 @@ class UltraAdvancedHybridCounselor:
|
|
| 358 |
self.greetings = set(self._default_greetings)
|
| 359 |
self.user_keywords = {}
|
| 360 |
|
| 361 |
-
|
|
|
|
| 362 |
try:
|
| 363 |
if DistilBertTokenizer and DistilBertForSequenceClassification:
|
| 364 |
try:
|
|
@@ -370,29 +595,35 @@ class UltraAdvancedHybridCounselor:
|
|
| 370 |
self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
|
| 371 |
except Exception:
|
| 372 |
self.tokenizer = None
|
| 373 |
-
|
|
|
|
| 374 |
if joblib and Path(self.label_encoder_path).exists():
|
| 375 |
self.label_encoder = joblib.load(self.label_encoder_path)
|
| 376 |
logger.info("✅ Label encoder loaded")
|
| 377 |
except Exception as e:
|
| 378 |
logger.error(f"Error loading local ML models: {e}")
|
| 379 |
|
| 380 |
-
|
|
|
|
| 381 |
try:
|
| 382 |
self.rag = RAGComponent() if RAGComponent else None
|
| 383 |
self.db = SessionDB() if SessionDB else None
|
| 384 |
-
if self.rag:
|
| 385 |
-
|
|
|
|
|
|
|
| 386 |
except Exception as e:
|
| 387 |
logger.error(f"Error initializing RAG/DB: {e}")
|
| 388 |
self.rag = None
|
| 389 |
self.db = None
|
| 390 |
|
| 391 |
-
|
|
|
|
| 392 |
self.tavily = None
|
| 393 |
self.tavily_keys_list = []
|
| 394 |
self.tavily_key_pool = None
|
| 395 |
self.current_tavily_key = None
|
|
|
|
| 396 |
try:
|
| 397 |
tavily_keys_str = os.getenv("TAVILY_API_KEY", "")
|
| 398 |
if tavily_keys_str:
|
|
@@ -407,7 +638,8 @@ class UltraAdvancedHybridCounselor:
|
|
| 407 |
logger.error(f"Error during Tavily init: {e}")
|
| 408 |
self.tavily = None
|
| 409 |
|
| 410 |
-
|
|
|
|
| 411 |
self.use_redis = os.getenv("USE_REDIS", "false").lower() == "true"
|
| 412 |
self.cache = None
|
| 413 |
if self.use_redis and RedisCache:
|
|
@@ -416,34 +648,8 @@ class UltraAdvancedHybridCounselor:
|
|
| 416 |
except Exception:
|
| 417 |
pass
|
| 418 |
|
| 419 |
-
# --- Initialize LLM ---
|
| 420 |
-
try:
|
| 421 |
-
self.llm = self._initialize_llm()
|
| 422 |
-
if self.llm: logger.info(f"✅ LLM initialized: {self.current_model}")
|
| 423 |
-
else: logger.info("LLM not initialized; operating in degraded mode.")
|
| 424 |
-
except Exception as e:
|
| 425 |
-
logger.error(f"LLM initialization error: {e}")
|
| 426 |
-
self.llm = None
|
| 427 |
-
|
| 428 |
-
# Setup prompts/chains
|
| 429 |
-
self._setup_prompts()
|
| 430 |
-
|
| 431 |
-
# --- Start background fine-tune worker if local training is possible ---
|
| 432 |
-
self._fine_tune_lock = threading.Lock()
|
| 433 |
-
self._stop_fine_tune_worker = False
|
| 434 |
-
self._fine_tune_thread = None
|
| 435 |
-
if torch and self.model is not None and self.tokenizer is not None:
|
| 436 |
-
try:
|
| 437 |
-
self._fine_tune_thread = threading.Thread(target=self._fine_tune_loop_sync, daemon=True)
|
| 438 |
-
self._fine_tune_thread.start()
|
| 439 |
-
logger.info("✅ Background fine-tune worker started.")
|
| 440 |
-
except Exception as e:
|
| 441 |
-
logger.error(f"Failed to start fine-tune background worker: {e}")
|
| 442 |
-
|
| 443 |
-
logger.info("UltraAdvancedHybridCounselor ready.")
|
| 444 |
-
|
| 445 |
-
# --- LLM and Tavily Utility Methods ---
|
| 446 |
def _get_model_priority_score(self, model_name: str) -> int:
|
|
|
|
| 447 |
priority_map = {
|
| 448 |
"gemini-2.5-flash-lite": 100, "gemini-2.5-flash": 95, "gemini-2.0-flash-lite": 90,
|
| 449 |
"gemini-2.0-flash": 85, "gemini-2.5-pro": 80, "gemini-1.5-flash": 75, "gemini-1.5-pro": 60
|
|
@@ -451,44 +657,60 @@ class UltraAdvancedHybridCounselor:
|
|
| 451 |
return priority_map.get(model_name, 10)
|
| 452 |
|
| 453 |
def _initialize_llm(self):
|
|
|
|
| 454 |
google_api_key = os.getenv("GOOGLE_API_KEY")
|
| 455 |
if not google_api_key or ChatGoogleGenerativeAI is None:
|
| 456 |
return None
|
| 457 |
-
|
| 458 |
sorted_models = sorted(self.available_models, key=self._get_model_priority_score, reverse=True)
|
| 459 |
for model_name in sorted_models:
|
| 460 |
try:
|
| 461 |
llm = ChatGoogleGenerativeAI(
|
| 462 |
-
model=model_name, temperature=0.
|
| 463 |
google_api_key=google_api_key, timeout=30, max_retries=1
|
| 464 |
)
|
| 465 |
if hasattr(llm, "invoke"):
|
| 466 |
_ = llm.invoke("ping")
|
| 467 |
elif hasattr(llm, "generate"):
|
| 468 |
_ = llm.generate("ping")
|
|
|
|
| 469 |
self.current_model = model_name
|
| 470 |
-
self.model_performance_stats[model_name] = {
|
|
|
|
|
|
|
|
|
|
| 471 |
return llm
|
| 472 |
except Exception:
|
| 473 |
continue
|
|
|
|
| 474 |
logger.error("No LLM models could be initialized.")
|
| 475 |
return None
|
| 476 |
|
| 477 |
def _fallback_to_next_model(self) -> bool:
|
|
|
|
| 478 |
if ChatGoogleGenerativeAI is None:
|
| 479 |
return False
|
|
|
|
| 480 |
try:
|
| 481 |
current_index = self.available_models.index(self.current_model) if self.current_model in self.available_models else -1
|
| 482 |
remaining = self.available_models[current_index + 1:] if current_index >= 0 else self.available_models
|
| 483 |
except Exception:
|
| 484 |
remaining = self.available_models
|
|
|
|
| 485 |
remaining = sorted(remaining, key=self._get_model_priority_score, reverse=True)
|
| 486 |
for model in remaining:
|
| 487 |
try:
|
| 488 |
-
llm = ChatGoogleGenerativeAI(
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 492 |
self.llm = llm
|
| 493 |
self.current_model = model
|
| 494 |
logger.info(f"Fell back to {model}")
|
|
@@ -498,50 +720,156 @@ class UltraAdvancedHybridCounselor:
|
|
| 498 |
return False
|
| 499 |
|
| 500 |
def _update_model_stats(self, model_name: str, success: bool, response_time: float = None, error: str = None):
|
|
|
|
| 501 |
if model_name not in self.model_performance_stats:
|
| 502 |
-
self.model_performance_stats[model_name] = {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 503 |
stats = self.model_performance_stats[model_name]
|
| 504 |
stats["total_requests"] = stats.get("total_requests", 0) + 1
|
|
|
|
| 505 |
if success:
|
| 506 |
stats["successful_requests"] = stats.get("successful_requests", 0) + 1
|
| 507 |
stats["response_time"] = response_time
|
| 508 |
stats["last_used"] = time.time()
|
| 509 |
else:
|
| 510 |
-
if error:
|
|
|
|
|
|
|
| 511 |
total = stats["total_requests"]
|
| 512 |
stats["success_rate"] = stats.get("successful_requests", 0) / total if total > 0 else 0.0
|
| 513 |
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
|
|
|
| 518 |
try:
|
| 519 |
-
|
| 520 |
-
if hasattr(chain, "ainvoke"):
|
| 521 |
-
res = await chain.ainvoke(params)
|
| 522 |
-
else:
|
| 523 |
-
loop = asyncio.get_event_loop()
|
| 524 |
-
res = await loop.run_in_executor(None, lambda: chain.invoke(params))
|
| 525 |
-
self._update_model_stats(self.current_model, True, time.time() - start)
|
| 526 |
-
return res
|
| 527 |
except Exception as e:
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 540 |
|
| 541 |
async def _call_direct_llm(self, prompt: str, max_retries: int = 2) -> str:
|
|
|
|
| 542 |
if self.llm is None:
|
| 543 |
return "LLM not available. Enable GOOGLE_API_KEY and ensure dependencies are installed."
|
| 544 |
-
|
| 545 |
for attempt in range(max_retries):
|
| 546 |
try:
|
| 547 |
start = time.time()
|
|
@@ -554,10 +882,10 @@ class UltraAdvancedHybridCounselor:
|
|
| 554 |
full_response_text = res.content if hasattr(res, 'content') else str(res)
|
| 555 |
else:
|
| 556 |
return "LLM present but has no recognized call method."
|
| 557 |
-
|
| 558 |
self._update_model_stats(self.current_model, True, time.time() - start)
|
| 559 |
return full_response_text
|
| 560 |
-
|
| 561 |
except Exception as e:
|
| 562 |
self._update_model_stats(self.current_model, False, error=str(e))
|
| 563 |
msg = str(e).lower()
|
|
@@ -567,27 +895,36 @@ class UltraAdvancedHybridCounselor:
|
|
| 567 |
continue
|
| 568 |
else:
|
| 569 |
raise RuntimeError("All models failed.")
|
|
|
|
| 570 |
if attempt < max_retries - 1:
|
| 571 |
await asyncio.sleep(2 ** attempt)
|
| 572 |
else:
|
| 573 |
logger.error(f"Direct LLM call failed after {max_retries} attempts: {e}")
|
| 574 |
raise
|
|
|
|
| 575 |
return "I encountered an error while generating the response after multiple retries."
|
| 576 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 577 |
async def _rotate_tavily_key(self, query: str, max_retries: int = None) -> list:
|
|
|
|
| 578 |
if not getattr(self, "tavily_key_pool", None) or not getattr(self, "tavily_keys_list", None):
|
| 579 |
return []
|
|
|
|
| 580 |
if max_retries is None:
|
| 581 |
max_retries = min(3, len(self.tavily_keys_list))
|
|
|
|
| 582 |
for attempt in range(max_retries):
|
| 583 |
try:
|
| 584 |
if self.current_tavily_key:
|
| 585 |
os.environ["TAVILY_API_KEY"] = self.current_tavily_key
|
| 586 |
-
|
| 587 |
-
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
|
| 591 |
search_docs = await tavily_search_safe(self.tavily, query, logger=logger)
|
| 592 |
if search_docs:
|
| 593 |
normalized = []
|
|
@@ -599,13 +936,22 @@ class UltraAdvancedHybridCounselor:
|
|
| 599 |
content = getattr(doc, "page_content", None) or (doc.get("content") if isinstance(doc, dict) else str(doc))
|
| 600 |
normalized.append(SimpleDoc(source or "", content or "", title=title or "", score=score))
|
| 601 |
return normalized
|
| 602 |
-
|
|
|
|
| 603 |
if self.current_tavily_key:
|
| 604 |
rest_results = await tavily_rest_search(self.current_tavily_key, query)
|
| 605 |
if rest_results:
|
| 606 |
-
normalized = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 607 |
return normalized
|
| 608 |
-
|
|
|
|
| 609 |
if attempt < max_retries - 1:
|
| 610 |
try:
|
| 611 |
self.current_tavily_key = next(self.tavily_key_pool)
|
|
@@ -614,6 +960,7 @@ class UltraAdvancedHybridCounselor:
|
|
| 614 |
continue
|
| 615 |
else:
|
| 616 |
break
|
|
|
|
| 617 |
except UsageLimitExceededError:
|
| 618 |
if attempt < max_retries - 1:
|
| 619 |
try:
|
|
@@ -633,243 +980,371 @@ class UltraAdvancedHybridCounselor:
|
|
| 633 |
continue
|
| 634 |
else:
|
| 635 |
break
|
|
|
|
| 636 |
logger.error("🚫 All Tavily attempts failed. Falling back to no web context.")
|
| 637 |
return []
|
| 638 |
|
| 639 |
-
#
|
| 640 |
-
|
| 641 |
-
|
| 642 |
-
return False
|
| 643 |
-
q = query.lower()
|
| 644 |
-
force_triggers = ["with sources", "with source", "show sources", "cite", "sources", "verify", "search web", "web search", "please search", "please look up", "look up", "confirm from", "confirm that"]
|
| 645 |
-
if any(t in q for t in force_triggers):
|
| 646 |
-
return True
|
| 647 |
-
if intent == "salary_info":
|
| 648 |
-
return True
|
| 649 |
-
web_triggers = ["latest", "current", "202", "trend", "trends", "salary", "average", "median", "top", "emerging", "statistics", "how much", "pay", "ctc", "package", "percent", "percentile", "growth", "outlook"]
|
| 650 |
-
if any(w in q for w in web_triggers):
|
| 651 |
-
return True
|
| 652 |
-
greetings = self.greetings
|
| 653 |
-
if q.strip() in greetings or len(q.split()) <= 4:
|
| 654 |
-
return False
|
| 655 |
-
if intent == "career_recommendation":
|
| 656 |
-
exploratory = ["explore", "help me explore", "recommend", "what should i study", "what should i do", "interests", "skills", "i like", "i enjoy", "prefer"]
|
| 657 |
-
if any(kw in q for kw in exploratory):
|
| 658 |
-
return False
|
| 659 |
-
return False
|
| 660 |
|
| 661 |
-
def
|
| 662 |
-
|
| 663 |
-
|
| 664 |
-
|
| 665 |
-
|
| 666 |
-
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
|
| 684 |
-
|
| 685 |
-
|
| 686 |
-
|
| 687 |
-
|
| 688 |
-
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
|
| 695 |
-
|
| 696 |
-
|
| 697 |
-
|
| 698 |
-
|
| 699 |
-
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
|
| 703 |
-
|
| 704 |
-
|
| 705 |
-
|
| 706 |
-
|
| 707 |
-
|
| 708 |
-
|
| 709 |
-
|
| 710 |
-
|
| 711 |
-
|
| 712 |
-
|
| 713 |
-
|
| 714 |
-
|
| 715 |
-
|
| 716 |
-
|
| 717 |
-
|
| 718 |
-
|
| 719 |
-
|
| 720 |
-
|
| 721 |
-
|
| 722 |
-
|
| 723 |
-
|
| 724 |
-
|
| 725 |
-
|
| 726 |
-
|
| 727 |
-
|
| 728 |
-
|
| 729 |
-
|
| 730 |
-
|
| 731 |
-
|
| 732 |
-
|
| 733 |
-
|
| 734 |
else:
|
| 735 |
-
|
| 736 |
-
|
| 737 |
-
|
| 738 |
-
|
| 739 |
-
|
| 740 |
-
|
| 741 |
-
|
| 742 |
-
|
| 743 |
-
|
| 744 |
-
|
| 745 |
-
|
| 746 |
-
|
| 747 |
-
|
| 748 |
-
|
| 749 |
-
|
| 750 |
-
|
| 751 |
-
|
| 752 |
-
|
| 753 |
-
|
| 754 |
-
|
| 755 |
-
|
| 756 |
-
|
| 757 |
-
|
| 758 |
-
|
| 759 |
-
|
| 760 |
-
|
| 761 |
-
|
| 762 |
-
if intent in ("general_guidance", "off_topic"):
|
| 763 |
-
role_desc = "GENERAL CHAT ASSISTANT (not limited to education/career)"
|
| 764 |
-
|
| 765 |
-
if fmt == "structured_xml":
|
| 766 |
-
output_instructions = """
|
| 767 |
-
**MANDATORY OUTPUT STRUCTURE**:
|
| 768 |
-
Your final output **MUST** use the following XML-like tags. **DO NOT** include any text outside of these tags, and **DO NOT** add your own meta-commentary, headers, or intros.
|
| 769 |
-
|
| 770 |
-
<SUMMARY>A concise (2-3 sentence) summary that directly answers the user's main question. This is the "executive summary".</SUMMARY>
|
| 771 |
-
|
| 772 |
-
<COMPREHENSIVE_EXPLANATION>
|
| 773 |
-
This is the main, detailed answer. Provide a full explanation, address all parts of the user's query, and integrate facts from the web context. Use markdown (like lists, bolding) for clarity. This section should contain the "proof" and detailed breakdown.
|
| 774 |
-
</COMPREHENSIVE_EXPLANATION>
|
| 775 |
-
|
| 776 |
-
<RELEVANT_INSIGHTS>
|
| 777 |
-
**This section is optional.** Only include this if you have 1-2 *highly specific, relevant* suggestions or next steps that directly follow from your explanation.
|
| 778 |
-
</RELEVANT_INSIGHTS>
|
| 779 |
-
"""
|
| 780 |
-
elif fmt == "markdown_detailed":
|
| 781 |
-
output_instructions = """
|
| 782 |
-
Produce a clear, well-structured markdown answer with these top-level sections:
|
| 783 |
-
|
| 784 |
-
## Summary
|
| 785 |
-
A concise (2-3 sentence) direct answer.
|
| 786 |
-
|
| 787 |
-
## Detailed Explanation
|
| 788 |
-
A full answer with clear subsections, lists, bolding where appropriate, and step-by-step guidance. Cite or reference external facts when web context is provided.
|
| 789 |
-
|
| 790 |
-
## Relevant Insights
|
| 791 |
-
(If you have up to 1-2 highly specific next steps or suggestions, include them here.)
|
| 792 |
-
|
| 793 |
-
At the end, include a "Sources" section or bullet list when web/contextual sources are available.
|
| 794 |
-
"""
|
| 795 |
-
else: # brief_plain
|
| 796 |
-
output_instructions = """
|
| 797 |
-
Produce a short, conversational reply (1-4 sentences). If helpful, include 1-2 concise actionable steps. Avoid long-form exposition.
|
| 798 |
-
"""
|
| 799 |
-
|
| 800 |
-
prompt = f"""
|
| 801 |
-
**INSTRUCTIONS TO AI COUNSOLER**
|
| 802 |
-
|
| 803 |
-
1. **ROLE**: You are an Ultra-Advanced Hybrid Counselor and {role_desc}, specializing in {intent.replace('_', ' ').upper()} guidance.
|
| 804 |
-
2. **USER CONTEXT**: Persona: **{persona}**. {tone_instruction}
|
| 805 |
-
3. **LANGUAGE**: Respond *only* in the language of the user's query.
|
| 806 |
-
4. **CONTEXTS**:
|
| 807 |
-
{rag_section}
|
| 808 |
-
### EXTERNAL WEB SEARCH CONTEXT (Tavily):
|
| 809 |
-
{web_context}
|
| 810 |
-
### HISTORY:
|
| 811 |
-
{history_str}
|
| 812 |
|
| 813 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 814 |
|
| 815 |
-
|
| 816 |
-
|
| 817 |
-
|
| 818 |
|
| 819 |
-
def
|
| 820 |
-
|
| 821 |
-
|
| 822 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 823 |
try:
|
| 824 |
-
|
| 825 |
-
|
| 826 |
-
|
| 827 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 828 |
|
| 829 |
-
|
| 830 |
-
|
| 831 |
-
q = user_query.lower()
|
| 832 |
-
if any(k in q for k in ["use xml", "use tags", "provide xml", "<summary>", "with tags", "machine readable", "parseable"]):
|
| 833 |
-
return "structured_xml"
|
| 834 |
-
if "with sources" in q or "cite" in q or web_used:
|
| 835 |
-
return "markdown_detailed"
|
| 836 |
-
if intent in ("salary_info",):
|
| 837 |
-
return "markdown_detailed"
|
| 838 |
-
if len(q.split()) <= 6 and any(g in q for g in self.greetings):
|
| 839 |
-
return "brief_plain"
|
| 840 |
-
if intent in ("career_recommendation", "educational_guidance", "resume_advice", "interview_prep"):
|
| 841 |
-
return "markdown_detailed"
|
| 842 |
-
if any(k in q for k in ["how to", "step by step", "steps to", "give steps", "do this"]) or any(k in q for k in ["how do i", "how can i"]):
|
| 843 |
-
return "markdown_detailed"
|
| 844 |
-
return "markdown_detailed"
|
| 845 |
-
|
| 846 |
-
# --- New: unified parser that handles multiple formats ---
|
| 847 |
-
def _parse_response_by_format(self, fmt: str, text: str) -> Dict[str, str]:
|
| 848 |
if not text:
|
| 849 |
-
return
|
| 850 |
-
|
| 851 |
-
|
| 852 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 853 |
|
| 854 |
-
|
| 855 |
-
|
| 856 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 857 |
|
| 858 |
-
|
| 859 |
-
|
| 860 |
-
|
| 861 |
-
|
| 862 |
-
|
|
|
|
|
|
|
| 863 |
|
| 864 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 865 |
|
| 866 |
-
# --- New: fine-tune persistence helpers (unchanged) ---
|
| 867 |
def _persist_fine_tune_example(self, text: str, label: str) -> None:
|
|
|
|
| 868 |
try:
|
| 869 |
line = json.dumps({"text": text, "label": label}, ensure_ascii=False)
|
| 870 |
with open(self.local_examples_path, "a", encoding="utf-8") as f:
|
| 871 |
f.write(line + "\n")
|
| 872 |
-
|
| 873 |
api = HfApi()
|
| 874 |
api.upload_file(
|
| 875 |
path_or_fileobj=self.local_examples_path,
|
|
@@ -882,6 +1357,7 @@ Produce a short, conversational reply (1-4 sentences). If helpful, include 1-2 c
|
|
| 882 |
logger.debug(f"Failed to persist fine-tune example to Hub: {e}")
|
| 883 |
|
| 884 |
def _load_fine_tune_examples(self) -> List[Dict[str, str]]:
|
|
|
|
| 885 |
try:
|
| 886 |
hf_hub_download(
|
| 887 |
repo_id=self.dataset_repo_id,
|
|
@@ -891,11 +1367,11 @@ Produce a short, conversational reply (1-4 sentences). If helpful, include 1-2 c
|
|
| 891 |
token=os.environ.get("HF_WRITE_TOKEN"),
|
| 892 |
force_filename=self.examples_filename_in_repo
|
| 893 |
)
|
| 894 |
-
|
| 895 |
if not self.local_examples_path.exists():
|
| 896 |
logger.info("No examples file found in dataset repo.")
|
| 897 |
return []
|
| 898 |
-
|
| 899 |
with open(self.local_examples_path, "r", encoding="utf-8") as f:
|
| 900 |
lines = [json.loads(l) for l in f if l.strip()]
|
| 901 |
return lines
|
|
@@ -904,6 +1380,7 @@ Produce a short, conversational reply (1-4 sentences). If helpful, include 1-2 c
|
|
| 904 |
return []
|
| 905 |
|
| 906 |
def _clear_fine_tune_examples(self, archive: bool = True):
|
|
|
|
| 907 |
api = HfApi()
|
| 908 |
try:
|
| 909 |
if archive:
|
|
@@ -923,19 +1400,23 @@ Produce a short, conversational reply (1-4 sentences). If helpful, include 1-2 c
|
|
| 923 |
repo_type="dataset",
|
| 924 |
token=os.environ.get("HF_WRITE_TOKEN")
|
| 925 |
)
|
| 926 |
-
|
| 927 |
for f in glob.glob(f"./{self.examples_filename_in_repo}*"):
|
| 928 |
try:
|
| 929 |
os.remove(f)
|
| 930 |
except Exception:
|
| 931 |
pass
|
|
|
|
| 932 |
logger.info("Archived examples file in dataset repo.")
|
| 933 |
-
|
| 934 |
except Exception as e:
|
| 935 |
logger.debug(f"Failed to clear/archive examples in Hub (non-fatal): {e}")
|
| 936 |
|
| 937 |
-
#
|
|
|
|
|
|
|
|
|
|
| 938 |
def _fine_tune_loop_sync(self):
|
|
|
|
| 939 |
logger.info("Fine-tune loop running.")
|
| 940 |
while not getattr(self, "_stop_fine_tune_worker", False):
|
| 941 |
try:
|
|
@@ -945,19 +1426,22 @@ Produce a short, conversational reply (1-4 sentences). If helpful, include 1-2 c
|
|
| 945 |
time.sleep(max(10, self.fine_tune_interval))
|
| 946 |
|
| 947 |
def _maybe_fine_tune_once(self):
|
|
|
|
| 948 |
if not self._fine_tune_lock.acquire(blocking=False):
|
| 949 |
logger.debug("Fine-tune run already in progress; skipping this iteration.")
|
| 950 |
return
|
|
|
|
| 951 |
try:
|
| 952 |
examples = self._load_fine_tune_examples()
|
| 953 |
if len(examples) < self.min_examples_to_train:
|
| 954 |
logger.debug(f"Not enough examples for fine-tune (have {len(examples)}, need {self.min_examples_to_train}).")
|
| 955 |
return
|
| 956 |
-
|
| 957 |
if not (torch and self.model is not None and self.tokenizer is not None):
|
| 958 |
logger.warning("Fine-tune prerequisites missing (torch/model/tokenizer). Skipping training.")
|
| 959 |
return
|
| 960 |
-
|
|
|
|
| 961 |
label_to_id = {}
|
| 962 |
if self.label_encoder is not None and hasattr(self.label_encoder, "classes_"):
|
| 963 |
for idx, lab in enumerate(getattr(self.label_encoder, "classes_", [])):
|
|
@@ -969,40 +1453,40 @@ Produce a short, conversational reply (1-4 sentences). If helpful, include 1-2 c
|
|
| 969 |
label_to_id = json.load(f)
|
| 970 |
except Exception:
|
| 971 |
label_to_id = {}
|
| 972 |
-
|
| 973 |
next_id = max(label_to_id.values()) + 1 if label_to_id else 0
|
| 974 |
for ex in examples:
|
| 975 |
lab = ex.get("label", "general_guidance")
|
| 976 |
if lab not in label_to_id:
|
| 977 |
label_to_id[lab] = next_id
|
| 978 |
next_id += 1
|
| 979 |
-
|
| 980 |
try:
|
| 981 |
with open(self.finetune_label_map_path, "w", encoding="utf-8") as f:
|
| 982 |
json.dump(label_to_id, f, ensure_ascii=False, indent=2)
|
| 983 |
except Exception:
|
| 984 |
pass
|
| 985 |
-
|
|
|
|
| 986 |
texts = [ex["text"] for ex in examples]
|
| 987 |
labels = [label_to_id.get(ex.get("label", "general_guidance"), 0) for ex in examples]
|
| 988 |
-
|
| 989 |
enc = self.tokenizer(texts, padding=True, truncation=True, max_length=128, return_tensors="pt")
|
| 990 |
input_ids = enc["input_ids"]
|
| 991 |
attention_mask = enc["attention_mask"]
|
| 992 |
labels_tensor = torch.tensor(labels, dtype=torch.long)
|
| 993 |
-
|
| 994 |
dataset = TensorDataset(input_ids, attention_mask, labels_tensor)
|
| 995 |
sampler = RandomSampler(dataset)
|
| 996 |
loader = DataLoader(dataset, sampler=sampler, batch_size=self.fine_tune_batch_size)
|
| 997 |
-
|
| 998 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 999 |
self.model.to(device)
|
| 1000 |
self.model.train()
|
| 1001 |
optimizer = AdamW(self.model.parameters(), lr=1e-5)
|
| 1002 |
-
|
| 1003 |
-
total_steps = len(loader) * max(1, self.fine_tune_epochs)
|
| 1004 |
logger.info(f"Starting fine-tune: {len(examples)} examples, {len(loader)} batches, epochs={self.fine_tune_epochs}")
|
| 1005 |
-
|
| 1006 |
for epoch in range(self.fine_tune_epochs):
|
| 1007 |
epoch_loss = 0.0
|
| 1008 |
for batch in loader:
|
|
@@ -1016,7 +1500,8 @@ Produce a short, conversational reply (1-4 sentences). If helpful, include 1-2 c
|
|
| 1016 |
optimizer.step()
|
| 1017 |
epoch_loss += loss.item() if loss is not None else 0.0
|
| 1018 |
logger.info(f"Fine-tune epoch {epoch+1}/{self.fine_tune_epochs} loss: {epoch_loss:.4f}")
|
| 1019 |
-
|
|
|
|
| 1020 |
try:
|
| 1021 |
self.model.save_pretrained(self.model_path)
|
| 1022 |
try:
|
|
@@ -1026,266 +1511,30 @@ Produce a short, conversational reply (1-4 sentences). If helpful, include 1-2 c
|
|
| 1026 |
logger.info(f"✅ Fine-tuned model saved to {self.model_path}")
|
| 1027 |
except Exception as e:
|
| 1028 |
logger.error(f"Failed to save fine-tuned model: {e}")
|
| 1029 |
-
|
| 1030 |
self._clear_fine_tune_examples(archive=True)
|
| 1031 |
-
|
| 1032 |
finally:
|
| 1033 |
try:
|
| 1034 |
self._fine_tune_lock.release()
|
| 1035 |
except Exception:
|
| 1036 |
pass
|
| 1037 |
|
| 1038 |
-
#
|
| 1039 |
-
|
| 1040 |
-
|
| 1041 |
-
with open(self.user_keywords_path, "w", encoding="utf-8") as f:
|
| 1042 |
-
json.dump(self.user_keywords, f, ensure_ascii=False, indent=2)
|
| 1043 |
-
except Exception as e:
|
| 1044 |
-
logger.debug(f"Failed to persist user keywords: {e}")
|
| 1045 |
-
|
| 1046 |
-
def _persist_greetings(self):
|
| 1047 |
-
try:
|
| 1048 |
-
with open(self.user_greetings_path, "w", encoding="utf-8") as f:
|
| 1049 |
-
json.dump({"greetings": sorted(list(self.greetings))}, f, ensure_ascii=False, indent=2)
|
| 1050 |
-
except Exception as e:
|
| 1051 |
-
logger.debug(f"Failed to persist user greetings: {e}")
|
| 1052 |
-
|
| 1053 |
-
def _learn_from_interaction(self, query: str, intent: str):
|
| 1054 |
-
if not query:
|
| 1055 |
-
return
|
| 1056 |
-
q = query.strip()
|
| 1057 |
-
try:
|
| 1058 |
-
with open(self.user_corpus_path, "a", encoding="utf-8") as f:
|
| 1059 |
-
f.write(q + "\n")
|
| 1060 |
-
except Exception:
|
| 1061 |
-
pass
|
| 1062 |
-
|
| 1063 |
-
tokens = [t for t in re.findall(r"\b[a-zA-Z]{2,}\b", q.lower()) if len(t) > 1]
|
| 1064 |
-
if len(tokens) <= 2 and q.lower() not in {"", "ok", "thanks", "thank you"}:
|
| 1065 |
-
if any(g in q.lower() for g in ["hi", "hello", "hey", "hlo", "hiii", "hii"]):
|
| 1066 |
-
self.greetings.add(q.lower())
|
| 1067 |
-
self._persist_greetings()
|
| 1068 |
-
|
| 1069 |
-
if intent not in self.user_keywords:
|
| 1070 |
-
self.user_keywords[intent] = {}
|
| 1071 |
-
token_counts = self.user_keywords.get(intent, {})
|
| 1072 |
-
for t in tokens:
|
| 1073 |
-
token_counts[t] = token_counts.get(t, 0) + 1
|
| 1074 |
-
self.user_keywords[intent] = token_counts
|
| 1075 |
-
self._persist_user_keywords()
|
| 1076 |
-
|
| 1077 |
-
try:
|
| 1078 |
-
self._persist_fine_tune_example(q, intent)
|
| 1079 |
-
except Exception:
|
| 1080 |
-
logger.debug("Failed to persist fine-tune example (non-fatal).")
|
| 1081 |
-
|
| 1082 |
-
# --- Core API Methods ---
|
| 1083 |
-
async def get_comprehensive_answer(self, user_query: str, session_id: str) -> AsyncGenerator[str, None]:
|
| 1084 |
-
history = []
|
| 1085 |
-
try:
|
| 1086 |
-
if self.db:
|
| 1087 |
-
history = self.db.get_history(session_id)
|
| 1088 |
-
except Exception:
|
| 1089 |
-
logger.debug("History fetch failed.")
|
| 1090 |
-
|
| 1091 |
-
history_str = "\n".join([f"User: {h[0]}\nBot: {h[1]}" for h in history]) if history else "No history yet."
|
| 1092 |
-
web_context = "No web search required or performed."
|
| 1093 |
-
sources_text = ""
|
| 1094 |
-
|
| 1095 |
-
normalized_query = user_query.strip()
|
| 1096 |
-
normalized_lower = normalized_query.lower().rstrip(".!?")
|
| 1097 |
-
greetings = set(self.greetings)
|
| 1098 |
-
|
| 1099 |
-
if normalized_lower in greetings:
|
| 1100 |
-
greeting_response = "Hello! I'm your AI education & career counselor. How can I assist you with your education or career goals today?"
|
| 1101 |
-
yield greeting_response
|
| 1102 |
-
try:
|
| 1103 |
-
if self.db:
|
| 1104 |
-
self.db.save_history(session_id, history + [[user_query, greeting_response]])
|
| 1105 |
-
except Exception:
|
| 1106 |
-
pass
|
| 1107 |
-
try:
|
| 1108 |
-
self._learn_from_interaction(user_query, "general_guidance")
|
| 1109 |
-
except Exception:
|
| 1110 |
-
pass
|
| 1111 |
-
return
|
| 1112 |
-
|
| 1113 |
-
if self._is_illegal_request(user_query):
|
| 1114 |
-
refusal = "I cannot assist with requests that enable illegal or harmful activities. I can help with lawful education, career guidance, coding practice, and study resources — please rephrase your question."
|
| 1115 |
-
yield refusal
|
| 1116 |
-
return
|
| 1117 |
-
|
| 1118 |
-
try:
|
| 1119 |
-
persona = self._detect_persona(user_query)
|
| 1120 |
-
geo_lang = self._detect_country_language(user_query)
|
| 1121 |
-
country = geo_lang.get("country")
|
| 1122 |
-
language = geo_lang.get("language")
|
| 1123 |
-
|
| 1124 |
-
intent = await self.classify_intent(user_query)
|
| 1125 |
-
logger.info(f"Intent detected: {intent}")
|
| 1126 |
-
|
| 1127 |
-
use_web = self._should_use_web_search(intent, user_query, history) or (country or language)
|
| 1128 |
-
combined_docs: List[SimpleDoc] = []
|
| 1129 |
-
|
| 1130 |
-
if use_web:
|
| 1131 |
-
search_queries: List[str] = []
|
| 1132 |
-
if country or language:
|
| 1133 |
-
search_queries.extend(self._generate_mandatory_search_queries(user_query, country, language))
|
| 1134 |
-
if intent == "salary_info" or self._should_use_web_search(intent, user_query):
|
| 1135 |
-
search_queries.append(f"Latest verified information for: {user_query}")
|
| 1136 |
-
|
| 1137 |
-
search_queries = list(set(search_queries))
|
| 1138 |
-
|
| 1139 |
-
for sq in search_queries:
|
| 1140 |
-
docs = await self._rotate_tavily_key(sq)
|
| 1141 |
-
combined_docs.extend(docs)
|
| 1142 |
-
|
| 1143 |
-
if combined_docs:
|
| 1144 |
-
unique_docs = {d.page_content: d for d in combined_docs}
|
| 1145 |
-
final_docs = list(unique_docs.values())[:10]
|
| 1146 |
-
|
| 1147 |
-
web_context = "\n\n".join([f"Source: {getattr(doc, 'metadata', {}).get('source','N/A')}\nTitle: {getattr(doc, 'metadata', {}).get('title','')}\nContent: {getattr(doc, 'page_content','')}" for doc in final_docs])
|
| 1148 |
-
sources_text = format_sources_block(final_docs)
|
| 1149 |
-
else:
|
| 1150 |
-
web_context = "Web search performed but returned no highly relevant results."
|
| 1151 |
-
sources_text = "No reliable external sources were found for this query."
|
| 1152 |
-
|
| 1153 |
-
rag_context = "No RAG content"
|
| 1154 |
-
if self.rag:
|
| 1155 |
-
rag_context = "Local knowledge base accessed and utilized."
|
| 1156 |
-
|
| 1157 |
-
fmt = self._select_output_format(intent, user_query, use_web)
|
| 1158 |
-
prompt = self._get_base_prompt_template(intent, persona, web_context, rag_context, history_str, user_query, fmt=fmt)
|
| 1159 |
-
|
| 1160 |
-
generated_answer_text = await self._call_direct_llm(prompt)
|
| 1161 |
-
|
| 1162 |
-
if not generated_answer_text:
|
| 1163 |
-
raise RuntimeError("LLM returned an empty response.")
|
| 1164 |
-
|
| 1165 |
-
parsed_data = self._parse_response_by_format(fmt, generated_answer_text)
|
| 1166 |
-
structured_response = self._format_structured_response(parsed_data, sources_text, fmt=fmt)
|
| 1167 |
-
yield structured_response
|
| 1168 |
-
|
| 1169 |
-
if self.db:
|
| 1170 |
-
try:
|
| 1171 |
-
self.db.save_history(session_id, history + [[user_query, structured_response]])
|
| 1172 |
-
except Exception:
|
| 1173 |
-
pass
|
| 1174 |
-
|
| 1175 |
-
try:
|
| 1176 |
-
self._learn_from_interaction(user_query, intent)
|
| 1177 |
-
except Exception:
|
| 1178 |
-
logger.debug("Learning step failed (non-fatal).")
|
| 1179 |
-
|
| 1180 |
-
except Exception as e:
|
| 1181 |
-
logger.error(f"❌ Error in get_comprehensive_answer: {e}", exc_info=True)
|
| 1182 |
-
error_msg = f"I'm sorry, I'm encountering an unexpected error while processing your request. Current model: {self.current_model}. Please try again."
|
| 1183 |
-
if self.db:
|
| 1184 |
-
try:
|
| 1185 |
-
self.db.save_history(session_id, history + [[user_query, error_msg]])
|
| 1186 |
-
except Exception:
|
| 1187 |
-
pass
|
| 1188 |
-
yield error_msg
|
| 1189 |
-
|
| 1190 |
-
async def classify_intent(self, query: str) -> str:
|
| 1191 |
-
if self.cache:
|
| 1192 |
-
key = f"intent_{hashlib.sha256(query.encode()).hexdigest()}"
|
| 1193 |
-
cached = self.cache.get(key)
|
| 1194 |
-
if cached:
|
| 1195 |
-
return cached
|
| 1196 |
-
|
| 1197 |
-
tokens = [t for t in re.findall(r"\b[a-zA-Z]{2,}\b", query.lower())]
|
| 1198 |
-
intent_scores = Counter()
|
| 1199 |
-
for intent_name, token_map in (self.user_keywords or {}).items():
|
| 1200 |
-
for t in tokens:
|
| 1201 |
-
intent_scores[intent_name] += token_map.get(t, 0)
|
| 1202 |
-
if intent_scores:
|
| 1203 |
-
top_intent, top_score = intent_scores.most_common(1)[0]
|
| 1204 |
-
if top_score >= 2 or (len(intent_scores) == 1 and top_score >= 1):
|
| 1205 |
-
if self.cache:
|
| 1206 |
-
try:
|
| 1207 |
-
self.cache.set(key, top_intent, ttl=3600)
|
| 1208 |
-
except Exception:
|
| 1209 |
-
pass
|
| 1210 |
-
return top_intent
|
| 1211 |
-
|
| 1212 |
-
if self.model is not None and self.tokenizer is not None and torch is not None:
|
| 1213 |
-
try:
|
| 1214 |
-
self.model.eval()
|
| 1215 |
-
with torch.no_grad():
|
| 1216 |
-
enc = self.tokenizer(query, return_tensors="pt", truncation=True, padding=True, max_length=128)
|
| 1217 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 1218 |
-
for k, v in enc.items():
|
| 1219 |
-
enc[k] = v.to(device)
|
| 1220 |
-
self.model.to(device)
|
| 1221 |
-
outputs = self.model(**enc)
|
| 1222 |
-
logits = outputs.logits.cpu().numpy().tolist()[0]
|
| 1223 |
-
label_map = {}
|
| 1224 |
-
if self.finetune_label_map_path.exists():
|
| 1225 |
-
try:
|
| 1226 |
-
with open(self.finetune_label_map_path, "r", encoding="utf-8") as f:
|
| 1227 |
-
label_map = json.load(f)
|
| 1228 |
-
except Exception:
|
| 1229 |
-
label_map = {}
|
| 1230 |
-
if label_map:
|
| 1231 |
-
id_to_label = {int(v): k for k, v in label_map.items()}
|
| 1232 |
-
pred_idx = int(max(range(len(logits)), key=lambda i: logits[i]))
|
| 1233 |
-
intent = id_to_label.get(pred_idx, "general_guidance")
|
| 1234 |
-
if self.cache:
|
| 1235 |
-
try:
|
| 1236 |
-
self.cache.set(key, intent, ttl=3600)
|
| 1237 |
-
except Exception:
|
| 1238 |
-
pass
|
| 1239 |
-
return intent
|
| 1240 |
-
except Exception:
|
| 1241 |
-
logger.debug("Local classifier prediction failed; falling back to heuristics.")
|
| 1242 |
-
|
| 1243 |
-
if self.intent_chain is None:
|
| 1244 |
-
q = query.lower()
|
| 1245 |
-
if any(w in q for w in ["resume", "cv", "cover letter"]):
|
| 1246 |
-
intent = "resume_advice"
|
| 1247 |
-
elif any(w in q for w in ["interview", "star method", "technical interview", "hr round"]):
|
| 1248 |
-
intent = "interview_prep"
|
| 1249 |
-
elif any(w in q for w in ["salary", "ctc", "package", "pay"]):
|
| 1250 |
-
intent = "salary_info"
|
| 1251 |
-
elif any(w in q for w in ["which course", "which college", "what should i study", "career", "i like"]):
|
| 1252 |
-
intent = "career_recommendation"
|
| 1253 |
-
elif any(w in q for w in ["school", "exam", "jee", "neet", "admission"]):
|
| 1254 |
-
intent = "educational_guidance"
|
| 1255 |
-
else:
|
| 1256 |
-
intent = "general_guidance"
|
| 1257 |
-
if self.cache:
|
| 1258 |
-
try:
|
| 1259 |
-
self.cache.set(key, intent, ttl=3600)
|
| 1260 |
-
except Exception:
|
| 1261 |
-
pass
|
| 1262 |
-
return intent
|
| 1263 |
-
|
| 1264 |
-
try:
|
| 1265 |
-
response = await self._safe_llm_invoke(self.intent_chain, {"query": query})
|
| 1266 |
-
intent_text = response.get("text", "") if isinstance(response, dict) else str(response)
|
| 1267 |
-
intent = intent_text.strip().lower().replace(".", "")
|
| 1268 |
-
valid = ["educational_guidance", "career_recommendation", "resume_advice", "interview_prep", "salary_info", "general_guidance", "off_topic"]
|
| 1269 |
-
if intent not in valid:
|
| 1270 |
-
intent = "general_guidance"
|
| 1271 |
-
if self.cache:
|
| 1272 |
-
try:
|
| 1273 |
-
self.cache.set(key, intent, ttl=3600)
|
| 1274 |
-
except Exception:
|
| 1275 |
-
pass
|
| 1276 |
-
return intent
|
| 1277 |
-
except Exception as e:
|
| 1278 |
-
logger.error(f"Intent classification failed: {e}")
|
| 1279 |
-
return "general_guidance"
|
| 1280 |
|
| 1281 |
async def predict_career(self, query: str) -> Dict[str, Any]:
|
|
|
|
| 1282 |
if self.cache:
|
| 1283 |
key = f"predict_{hashlib.sha256(query.encode()).hexdigest()}"
|
| 1284 |
cached = self.cache.get(key)
|
| 1285 |
if cached:
|
| 1286 |
return cached
|
|
|
|
| 1287 |
if not (self.model and self.tokenizer and torch and self.label_encoder is not None):
|
| 1288 |
return {"recommendation": None, "confidence": 0.0, "error": "Local prediction unavailable"}
|
|
|
|
| 1289 |
try:
|
| 1290 |
inputs = self.tokenizer(query.lower(), return_tensors="pt", padding=True, truncation=True, max_length=128)
|
| 1291 |
with torch.no_grad():
|
|
@@ -1302,8 +1551,12 @@ Produce a short, conversational reply (1-4 sentences). If helpful, include 1-2 c
|
|
| 1302 |
logger.error(f"Prediction failed: {e}")
|
| 1303 |
return {"recommendation": None, "confidence": 0.0, "error": str(e)}
|
| 1304 |
|
| 1305 |
-
#
|
|
|
|
|
|
|
|
|
|
| 1306 |
def get_current_model_info(self) -> Dict[str, Any]:
|
|
|
|
| 1307 |
return {
|
| 1308 |
"current_model": self.current_model,
|
| 1309 |
"available_models": self.available_models,
|
|
@@ -1315,6 +1568,7 @@ Produce a short, conversational reply (1-4 sentences). If helpful, include 1-2 c
|
|
| 1315 |
}
|
| 1316 |
|
| 1317 |
def get_health_status(self) -> Dict[str, Any]:
|
|
|
|
| 1318 |
try:
|
| 1319 |
total_models = len(self.available_models)
|
| 1320 |
working = sum(1 for s in self.model_performance_stats.values() if s.get("success_rate", 0) > 0)
|
|
@@ -1332,31 +1586,34 @@ Produce a short, conversational reply (1-4 sentences). If helpful, include 1-2 c
|
|
| 1332 |
except Exception as e:
|
| 1333 |
return {"status": "error", "error": str(e), "last_updated": time.time()}
|
| 1334 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1335 |
if __name__ == "__main__":
|
| 1336 |
async def demo():
|
| 1337 |
c = UltraAdvancedHybridCounselor()
|
| 1338 |
-
|
| 1339 |
-
|
| 1340 |
-
|
| 1341 |
-
|
| 1342 |
-
|
| 1343 |
-
|
| 1344 |
-
|
| 1345 |
-
|
| 1346 |
-
|
| 1347 |
-
|
| 1348 |
-
|
| 1349 |
-
print(
|
| 1350 |
-
|
| 1351 |
-
|
| 1352 |
-
|
| 1353 |
-
|
| 1354 |
-
|
| 1355 |
-
|
| 1356 |
-
|
| 1357 |
-
print(out)
|
| 1358 |
-
|
| 1359 |
try:
|
| 1360 |
asyncio.run(demo())
|
| 1361 |
except Exception as e:
|
| 1362 |
-
logger.error(f"Demo failed: {e}")
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
+
UltraAdvancedHybridCounselor - PREMIUM Edition with Intelligent Adaptive Formatting
|
| 4 |
+
|
| 5 |
+
Key improvements:
|
| 6 |
+
- Intelligent query type detection (quick, definition, list, howto, comparison, etc.)
|
| 7 |
+
- Adaptive output formatting based on query type
|
| 8 |
+
- No forced structure - responses feel natural and premium
|
| 9 |
+
- Proper markdown with line breaks preserved
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
"""
|
| 11 |
+
|
| 12 |
import asyncio
|
| 13 |
import hashlib
|
| 14 |
import inspect
|
|
|
|
| 37 |
import joblib
|
| 38 |
except Exception:
|
| 39 |
joblib = None
|
| 40 |
+
|
| 41 |
try:
|
| 42 |
import torch
|
| 43 |
from torch.utils.data import TensorDataset, DataLoader, RandomSampler
|
| 44 |
from torch.optim import AdamW
|
| 45 |
except Exception:
|
| 46 |
torch = None
|
| 47 |
+
|
| 48 |
try:
|
| 49 |
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
|
| 50 |
except Exception:
|
|
|
|
| 62 |
LLMChain = None
|
| 63 |
|
| 64 |
try:
|
| 65 |
+
from langchain_community.retrievers import TavilySearchAPIRetriever
|
| 66 |
_TAVILY_CLASS = TavilySearchAPIRetriever
|
| 67 |
except Exception:
|
| 68 |
_TAVILY_CLASS = None
|
| 69 |
+
|
| 70 |
try:
|
| 71 |
from rag import RAGComponent
|
| 72 |
except Exception:
|
| 73 |
RAGComponent = None
|
| 74 |
+
|
| 75 |
try:
|
| 76 |
from db import SessionDB
|
| 77 |
except Exception:
|
| 78 |
SessionDB = None
|
| 79 |
+
|
| 80 |
try:
|
| 81 |
from cache import RedisCache
|
| 82 |
except Exception:
|
| 83 |
RedisCache = None
|
| 84 |
+
|
| 85 |
try:
|
| 86 |
+
from tavily import UsageLimitExceededError
|
| 87 |
except Exception:
|
| 88 |
class UsageLimitExceededError(Exception):
|
| 89 |
pass
|
|
|
|
| 97 |
format='%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s',
|
| 98 |
handlers=[logging.FileHandler('logs/counselor.log', encoding='utf-8'), logging.StreamHandler()]
|
| 99 |
)
|
| 100 |
+
|
| 101 |
logger = logging.getLogger(__name__)
|
| 102 |
|
| 103 |
+
|
| 104 |
+
# ============================================
|
| 105 |
+
# QUERY TYPE DETECTION - THE KEY INNOVATION
|
| 106 |
+
# ============================================
|
| 107 |
+
|
| 108 |
+
class QueryType:
    """String constants naming the response-format categories a query can fall into."""
    QUICK = "quick"                    # short factual question -> 1-2 sentence answer
    DEFINITION = "definition"          # "define X" / "explain X" -> short paragraph
    LIST = "list"                      # "give me list of..." / "top 10..." -> bullet points
    HOWTO = "howto"                    # "how to..." / "steps to..." -> numbered steps
    COMPARISON = "comparison"          # "X vs Y" / "difference between..." -> table/comparison
    ROADMAP = "roadmap"                # "roadmap for..." / "path to become..." -> timeline
    SYLLABUS = "syllabus"              # "syllabus for..." / "curriculum..." -> structured list
    DETAILED = "detailed"              # complex question -> full explanation
    CONVERSATIONAL = "conversational"  # casual chat, follow-ups


def detect_query_type(query: str) -> str:
    """Classify *query* into one of the QueryType categories.

    Checks run in priority order: quick factual patterns first, then
    definition prefixes, then syllabus / roadmap / list / how-to trigger
    phrases, then comparison markers, then short casual queries.
    Anything that matches nothing falls through to DETAILED.
    """
    normalized = query.lower().strip()
    n_words = len(normalized.split())

    # Very short factual questions ("what is X?", "who is ...") -> QUICK.
    quick_regexes = (
        r"^what is [a-z\s]{1,30}\??$",
        r"^who is ",
        r"^when (is|was|did) ",
        r"^where (is|was) ",
        r"^is [a-z\s]+ (a|an) ",
    )
    if n_words <= 6 and any(re.match(rx, normalized) for rx in quick_regexes):
        return QueryType.QUICK

    # Explicit definition/explanation requests.
    if normalized.startswith(("define ", "explain ", "what does ", "meaning of ")):
        return QueryType.DEFINITION

    # Phrase-trigger buckets, checked in priority order (syllabus before
    # roadmap before list before how-to, matching the original precedence).
    trigger_buckets = (
        (QueryType.SYLLABUS, ("syllabus", "curriculum", "course content",
                              "course outline", "topics covered", "what to study")),
        (QueryType.ROADMAP, ("roadmap", "path to become", "how to become",
                             "career path", "learning path", "journey to",
                             "steps to become", "guide to become")),
        (QueryType.LIST, ("list of", "give me list", "top 10", "top 5", "best ",
                          "names of", "examples of", "types of", "kinds of",
                          "options for")),
        (QueryType.HOWTO, ("how to ", "how do i ", "how can i ", "steps to ",
                           "process of ", "guide for ", "tutorial", "way to ")),
    )
    for qtype, phrases in trigger_buckets:
        if any(phrase in normalized for phrase in phrases):
            return qtype

    # Comparison markers anywhere in the query.
    if any(marker in normalized
           for marker in (" vs ", " versus ", "difference between", "compare ")):
        return QueryType.COMPARISON

    # Short casual queries read as conversational chat.
    if n_words <= 4:
        return QueryType.CONVERSATIONAL

    # Default: complex question needing a full, structured answer.
    return QueryType.DETAILED
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
# ============================================
|
| 178 |
+
# ADAPTIVE PROMPT TEMPLATES
|
| 179 |
+
# ============================================
|
| 180 |
+
|
| 181 |
+
def get_adaptive_prompt(query_type: str, persona: str, intent: str, web_context: str,
                        rag_context: str, history_str: str, user_query: str) -> str:
    """Generate query-type-specific prompts that produce natural responses.

    Args:
        query_type: One of the QueryType constants (see detect_query_type).
        persona: Detected user persona ("Student", "Teacher", "Parent", "Other");
            unknown personas fall back to a neutral tone.
        intent: Classified intent label. NOTE(review): not referenced in this
            body — confirm whether it should be woven into the prompt.
        web_context: Web-search context text, or the sentinel string
            "No web search required or performed." when no search ran.
        rag_context: Local knowledge-base context. NOTE(review): not referenced
            in this body — confirm intended use.
        history_str: Rendered conversation history.
        user_query: The raw user question.

    Returns:
        A complete LLM prompt combining tone, optional web context, history,
        a query-type-specific formatting instruction, and global rules.
    """

    # Persona -> tone guidance; unknown personas get the neutral default.
    tone_instruction = {
        "Student": "Use a friendly, encouraging tone with practical examples.",
        "Teacher": "Use a professional, resourceful tone with academic references.",
        "Parent": "Use an empathetic, clear tone with actionable guidance.",
        "Other": "Use a helpful, informative tone."
    }.get(persona, "Use a helpful, informative tone.")

    # Base context section — only injected when a real web search was made
    # (the sentinel string means "nothing useful to add").
    context_section = ""
    if web_context and web_context != "No web search required or performed.":
        context_section = f"""
**Available Context:**
{web_context}
"""

    # Query-type specific instructions: each branch supplies a formatting
    # directive tailored to the detected query shape.
    if query_type == QueryType.QUICK:
        format_instruction = """
**Response Format:** Give a direct, concise answer in 1-2 sentences. No headers, no bullet points, no lengthy explanations. Just answer the question naturally like a knowledgeable friend would.

Example: "Machine learning is a type of artificial intelligence that enables computers to learn from data and improve their performance without being explicitly programmed."
"""

    elif query_type == QueryType.DEFINITION:
        format_instruction = """
**Response Format:** Provide a clear definition in 2-3 sentences, followed by a brief practical example or application. Keep it conversational. No headers needed.

Example format:
[Definition in 2-3 sentences]

[One practical example or real-world application]
"""

    elif query_type == QueryType.SYLLABUS:
        format_instruction = """
**Response Format:** Present the syllabus as a clean, numbered list. Each topic should be on its own line. Group related topics under clear section headers if needed.

**IMPORTANT:** Each numbered item MUST be on a NEW LINE. Use this exact format:

**[Subject Name] Syllabus:**

**1. [First Topic/Module]**
- Subtopic A
- Subtopic B

**2. [Second Topic/Module]**
- Subtopic A
- Subtopic B

Continue this pattern. Keep descriptions brief (1 line each).
"""

    elif query_type == QueryType.ROADMAP:
        format_instruction = """
**Response Format:** Present as a clear timeline/path. Use phases or stages with timeframes where appropriate.

**Format like this:**

### Phase 1: Foundation (Month 1-2)
- Skill/topic to learn
- Skill/topic to learn

### Phase 2: Intermediate (Month 3-4)
- Skill/topic to learn
- Skill/topic to learn

### Phase 3: Advanced (Month 5-6)
- Skill/topic to learn
- Skill/topic to learn

**Resources:** List 2-3 recommended resources at the end.
"""

    elif query_type == QueryType.LIST:
        format_instruction = """
**Response Format:** Present as a clean bullet list. Each item on its own line.

**Format:**
Here are the [requested items]:

- **Item 1:** Brief description
- **Item 2:** Brief description
- **Item 3:** Brief description

Keep descriptions to 1 line each. No lengthy paragraphs.
"""

    elif query_type == QueryType.HOWTO:
        format_instruction = """
**Response Format:** Present as numbered steps. Each step should be actionable and clear.

**Format:**
Here's how to [do the thing]:

**Step 1: [Action verb + what to do]**
Brief explanation (1-2 sentences max).

**Step 2: [Action verb + what to do]**
Brief explanation (1-2 sentences max).

Continue this pattern. Keep it practical and actionable.
"""

    elif query_type == QueryType.COMPARISON:
        format_instruction = """
**Response Format:** Present a clear comparison. You can use a simple table format or side-by-side comparison.

**Format:**
Here's how [X] and [Y] compare:

| Aspect | X | Y |
|--------|---|---|
| [Aspect 1] | [X's characteristic] | [Y's characteristic] |
| [Aspect 2] | [X's characteristic] | [Y's characteristic] |

**Bottom Line:** 1-2 sentences on when to choose each option.
"""

    elif query_type == QueryType.CONVERSATIONAL:
        format_instruction = """
**Response Format:** Keep it brief and conversational. 1-3 sentences max. No headers, no bullet points unless specifically asked. Respond like a helpful friend would in a chat.
"""

    else:  # DETAILED
        format_instruction = """
**Response Format:** Provide a comprehensive answer with clear structure:

1. Start with a direct 2-3 sentence answer to the main question
2. Use headers (##) only if covering multiple distinct aspects
3. Use bullet points for lists
4. Keep paragraphs short (3-4 sentences max)
5. End with practical advice or next steps if relevant

Keep the response focused and avoid unnecessary filler.
"""

    # Assemble the final prompt: tone, question, optional context, history,
    # per-type format directive, then global anti-XML / formatting rules.
    prompt = f"""You are a premium AI career and education counselor. {tone_instruction}

**User's Question:** {user_query}

{context_section}

**Conversation History:**
{history_str}

{format_instruction}

**Critical Rules:**
1. NEVER use XML tags like <summary>, <explanation>, <insights>
2. NEVER force a rigid structure if it doesn't fit the question
3. Preserve proper line breaks - each list item/step MUST be on its own line
4. Match your response length to the complexity of the question
5. Be helpful, accurate, and natural

Now respond to the user's question:"""

    return prompt
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
# ============================================
|
| 345 |
+
# SIMPLE DOC AND TAVILY HELPERS (Unchanged)
|
| 346 |
+
# ============================================
|
| 347 |
+
|
| 348 |
class SimpleDoc:
    """Minimal document container mirroring LangChain's Document interface.

    Exposes ``metadata`` (source URL, title, optional relevance score) and
    ``page_content`` so downstream formatting code can treat web-search
    results and RAG documents uniformly.
    """

    def __init__(self, source: str, content: str, title: str = "", score: Optional[float] = None):
        self.metadata = {"source": source, "title": title, "score": score}
        self.page_content = content

    def __repr__(self) -> str:
        # Debug-friendly representation; content is omitted because it can be large.
        return f"SimpleDoc(source={self.metadata['source']!r}, title={self.metadata['title']!r})"
|
| 352 |
|
| 353 |
+
|
| 354 |
def create_tavily_retriever_safe(k: int = 10, logger: logging.Logger = logger, **kwargs):
|
| 355 |
global _TAVILY_CLASS
|
| 356 |
if _TAVILY_CLASS is None:
|
| 357 |
try:
|
| 358 |
+
from langchain_community.retrievers import TavilySearchAPIRetriever
|
| 359 |
_TAVILY_CLASS = TavilySearchAPIRetriever
|
| 360 |
except Exception as e:
|
| 361 |
logger.error(f"TavilySearchAPIRetriever not importable: {e}")
|
| 362 |
raise ImportError("TavilySearchAPIRetriever unavailable") from e
|
| 363 |
+
|
| 364 |
cls = _TAVILY_CLASS
|
| 365 |
try:
|
| 366 |
sig = inspect.signature(cls.__init__)
|
| 367 |
except Exception:
|
| 368 |
sig = None
|
| 369 |
+
|
| 370 |
allowed = {}
|
| 371 |
for name, val in {"k": k, **kwargs}.items():
|
| 372 |
if sig is None or (name in sig.parameters and name != "self"):
|
| 373 |
allowed[name] = val
|
| 374 |
+
|
| 375 |
try:
|
| 376 |
return cls(**allowed)
|
| 377 |
except TypeError as te:
|
|
|
|
| 382 |
logger.error(f"Tavily no-arg constructor failed: {e}")
|
| 383 |
raise
|
| 384 |
|
| 385 |
+
|
| 386 |
async def tavily_search_safe(retriever, query: str, logger: logging.Logger = logger, *args, **kwargs) -> List[Any]:
|
| 387 |
if retriever is None:
|
| 388 |
logger.debug("tavily_search_safe: retriever is None")
|
| 389 |
return []
|
| 390 |
+
|
| 391 |
async_methods = ["ainvoke", "aget_relevant_documents", "aretrieve", "asearch"]
|
| 392 |
sync_methods = ["invoke", "get_relevant_documents", "retrieve", "search"]
|
| 393 |
+
|
| 394 |
for name in async_methods:
|
| 395 |
fn = getattr(retriever, name, None)
|
| 396 |
if callable(fn):
|
|
|
|
| 403 |
continue
|
| 404 |
except Exception:
|
| 405 |
continue
|
| 406 |
+
|
| 407 |
loop = asyncio.get_event_loop()
|
| 408 |
for name in sync_methods:
|
| 409 |
fn = getattr(retriever, name, None)
|
|
|
|
| 412 |
return await loop.run_in_executor(None, lambda: fn(query))
|
| 413 |
except Exception:
|
| 414 |
continue
|
| 415 |
+
|
| 416 |
logger.warning("tavily_search_safe: no usable methods on retriever")
|
| 417 |
return []
|
| 418 |
|
| 419 |
+
|
| 420 |
async def tavily_rest_search(api_key: str, query: str, timeout: int = 15, logger: logging.Logger = logger) -> List[Dict[str, Any]]:
|
| 421 |
if requests is None:
|
| 422 |
logger.error("requests library not available; cannot use REST fallback for Tavily.")
|
| 423 |
return []
|
| 424 |
+
|
| 425 |
url = "https://api.tavily.com/search"
|
| 426 |
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
|
| 427 |
payload = {"query": query}
|
| 428 |
loop = asyncio.get_event_loop()
|
| 429 |
+
|
| 430 |
def do_post():
|
| 431 |
r = requests.post(url, json=payload, headers=headers, timeout=timeout)
|
| 432 |
r.raise_for_status()
|
| 433 |
return r.json()
|
| 434 |
+
|
| 435 |
try:
|
| 436 |
resp = await loop.run_in_executor(None, do_post)
|
| 437 |
results = resp.get("results", [])
|
|
|
|
| 441 |
logger.exception("tavily_rest_search failed")
|
| 442 |
return []
|
| 443 |
|
| 444 |
+
|
| 445 |
def format_sources_block(docs: "List[SimpleDoc]") -> str:
    """Render up to five retrieved documents as a numbered markdown source list.

    Docs without a source URL are skipped, and numbering stays consecutive
    (previously url-less docs consumed a number, producing gaps like "1., 3.").
    Returns "" when there is nothing to cite.
    """
    if not docs:
        return ""

    # Collect only docs that actually carry a URL, pairing each with a
    # display title (falling back to the URL itself).
    linked = []
    for d in docs:
        meta = getattr(d, "metadata", {}) or {}
        url = meta.get("source") or ""
        if url:
            linked.append((meta.get("title") or url, url))

    # No citable docs: emit nothing rather than a header-only block.
    if not linked:
        return ""

    lines = ["\n---\n**📚 Sources:**"]
    for i, (title, url) in enumerate(linked[:5], 1):  # Limit to 5 sources
        lines.append(f"{i}. [{title}]({url})")
    return "\n".join(lines)
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
# ============================================
|
| 459 |
+
# KEYWORD LISTS (Unchanged)
|
| 460 |
+
# ============================================
|
| 461 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 462 |
# Country names recognized for geo-specific search-query generation.
# Matched as whole words against the lowercased query (see
# _detect_country_language / _generate_mandatory_search_queries).
_COUNTRY_KEYWORDS = {
    "india", "usa", "united states", "canada", "uk", "united kingdom", "germany",
    "france", "japan", "china", "brazil", "australia", "singapore", "netherlands",
    "italy", "spain"
}

# Language names recognized the same way for language-specific guidance queries.
_LANGUAGE_KEYWORDS = {
    "english", "german", "french", "spanish", "mandarin", "chinese", "japanese",
    "korean", "hindi", "arabic", "portuguese"
}

# Regex fragments marking a request as illegal/harmful; any match causes an
# immediate refusal (see _is_illegal_request). Partial stems like
# "how to assassinat" intentionally catch multiple word endings.
_ILLEGAL_TRIGGERS = [
    r"how to make a bomb", r"detonate", r"how to assassinat", r"kill someone",
    r"poison", r"how to hack into", r"bypass security", r"carding", r"credit card fraud",
    r"explosive", r"illicit drug", r"how to sell drugs", r"manufacture illegal",
    r"produce illegal", r"evade law", r"how to avoid taxes illegally"
]
|
| 479 |
|
| 480 |
+
|
| 481 |
+
# ============================================
|
| 482 |
+
# MAIN COUNSELOR CLASS
|
| 483 |
+
# ============================================
|
| 484 |
+
|
| 485 |
class UltraAdvancedHybridCounselor:
|
| 486 |
def __init__(self):
|
| 487 |
logger.info(f"🐍 Python version: {sys.version}")
|
| 488 |
+
|
| 489 |
# --- Paths and Model State ---
|
| 490 |
self.model_path = "Sachin21112004/carrerflow-ai"
|
| 491 |
self.label_encoder_path = "Sachin21112004/carrerflow-ai/label_encoder.pkl"
|
|
|
|
| 503 |
"gemini-1.5-pro-002", "gemini-2.5-flash-lite-preview", "gemini-1.5-flash-8b-latest",
|
| 504 |
"gemini-1.5-flash-latest", "gemini-1.5-pro-latest", "gemini-1.0-pro", "gemini-pro"
|
| 505 |
]
|
| 506 |
+
|
| 507 |
# --- Persistent user-adaptation files ---
|
| 508 |
self.user_corpus_path = Path("user_corpus.txt")
|
| 509 |
self.user_keywords_path = Path("user_keywords.json")
|
| 510 |
self.user_greetings_path = Path("user_greetings.json")
|
| 511 |
+
|
| 512 |
# Fine-tuning dataset and config
|
| 513 |
self.finetune_examples_path = Path("fine_tune_examples.jsonl")
|
| 514 |
self.finetune_label_map_path = Path("fine_tune_label_map.json")
|
| 515 |
+
|
| 516 |
+
# --- HF Dataset Config ---
|
| 517 |
self.dataset_repo_id = os.getenv("HF_DATASET_REPO_ID", "Sachin21112004/DreamFlow-AI-Data")
|
| 518 |
self.examples_filename_in_repo = "fine_tune_examples.jsonl"
|
| 519 |
+
self.local_examples_path = Path(f"./{self.examples_filename_in_repo}")
|
|
|
|
| 520 |
self.fine_tune_interval = int(os.getenv("FINE_TUNE_INTERVAL_SECS", "300"))
|
| 521 |
self.min_examples_to_train = int(os.getenv("MIN_EXAMPLES_TO_TRAIN", "32"))
|
| 522 |
self.fine_tune_batch_size = int(os.getenv("FINE_TUNE_BATCH", "8"))
|
| 523 |
self.fine_tune_epochs = int(os.getenv("FINE_TUNE_EPOCHS", "1"))
|
| 524 |
+
|
| 525 |
# Default greetings
|
| 526 |
+
self._default_greetings = {
|
| 527 |
+
"hi", "hiii", "hii", "hello", "hey", "hlo", "how are you",
|
| 528 |
+
"good morning", "good afternoon", "good evening"
|
| 529 |
+
}
|
| 530 |
+
|
| 531 |
# Load persisted greetings and user keywords
|
| 532 |
+
self._load_user_data()
|
| 533 |
+
|
| 534 |
+
# --- Load local ML model (if available) ---
|
| 535 |
+
self._load_local_models()
|
| 536 |
+
|
| 537 |
+
# --- Initialize RAG and DB ---
|
| 538 |
+
self._initialize_rag_db()
|
| 539 |
+
|
| 540 |
+
# --- Tavily key rotation setup ---
|
| 541 |
+
self._initialize_tavily()
|
| 542 |
+
|
| 543 |
+
# --- Redis caching ---
|
| 544 |
+
self._initialize_cache()
|
| 545 |
+
|
| 546 |
+
# --- Initialize LLM ---
|
| 547 |
+
try:
|
| 548 |
+
self.llm = self._initialize_llm()
|
| 549 |
+
if self.llm:
|
| 550 |
+
logger.info(f"✅ LLM initialized: {self.current_model}")
|
| 551 |
+
else:
|
| 552 |
+
logger.info("LLM not initialized; operating in degraded mode.")
|
| 553 |
+
except Exception as e:
|
| 554 |
+
logger.error(f"LLM initialization error: {e}")
|
| 555 |
+
self.llm = None
|
| 556 |
+
|
| 557 |
+
# Setup intent chain
|
| 558 |
+
self._setup_intent_chain()
|
| 559 |
+
|
| 560 |
+
# --- Start background fine-tune worker ---
|
| 561 |
+
self._start_fine_tune_worker()
|
| 562 |
+
|
| 563 |
+
logger.info("✅ UltraAdvancedHybridCounselor PREMIUM Edition ready.")
|
| 564 |
+
|
| 565 |
+
def _load_user_data(self):
|
| 566 |
+
"""Load persisted user greetings and keywords."""
|
| 567 |
try:
|
| 568 |
if self.user_greetings_path.exists():
|
| 569 |
with open(self.user_greetings_path, "r", encoding="utf-8") as f:
|
|
|
|
| 571 |
self.greetings = set(stored.get("greetings", [])) | self._default_greetings
|
| 572 |
else:
|
| 573 |
self.greetings = set(self._default_greetings)
|
| 574 |
+
|
| 575 |
if self.user_keywords_path.exists():
|
| 576 |
with open(self.user_keywords_path, "r", encoding="utf-8") as f:
|
| 577 |
self.user_keywords = json.load(f)
|
|
|
|
| 582 |
self.greetings = set(self._default_greetings)
|
| 583 |
self.user_keywords = {}
|
| 584 |
|
| 585 |
+
def _load_local_models(self):
|
| 586 |
+
"""Load local ML models for intent classification."""
|
| 587 |
try:
|
| 588 |
if DistilBertTokenizer and DistilBertForSequenceClassification:
|
| 589 |
try:
|
|
|
|
| 595 |
self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
|
| 596 |
except Exception:
|
| 597 |
self.tokenizer = None
|
| 598 |
+
logger.debug("No tokenizer available locally.")
|
| 599 |
+
|
| 600 |
if joblib and Path(self.label_encoder_path).exists():
|
| 601 |
self.label_encoder = joblib.load(self.label_encoder_path)
|
| 602 |
logger.info("✅ Label encoder loaded")
|
| 603 |
except Exception as e:
|
| 604 |
logger.error(f"Error loading local ML models: {e}")
|
| 605 |
|
| 606 |
+
def _initialize_rag_db(self):
|
| 607 |
+
"""Initialize RAG and SessionDB components."""
|
| 608 |
try:
|
| 609 |
self.rag = RAGComponent() if RAGComponent else None
|
| 610 |
self.db = SessionDB() if SessionDB else None
|
| 611 |
+
if self.rag:
|
| 612 |
+
logger.info("✅ RAG initialized")
|
| 613 |
+
if self.db:
|
| 614 |
+
logger.info("✅ SessionDB initialized")
|
| 615 |
except Exception as e:
|
| 616 |
logger.error(f"Error initializing RAG/DB: {e}")
|
| 617 |
self.rag = None
|
| 618 |
self.db = None
|
| 619 |
|
| 620 |
+
def _initialize_tavily(self):
|
| 621 |
+
"""Initialize Tavily search with key rotation."""
|
| 622 |
self.tavily = None
|
| 623 |
self.tavily_keys_list = []
|
| 624 |
self.tavily_key_pool = None
|
| 625 |
self.current_tavily_key = None
|
| 626 |
+
|
| 627 |
try:
|
| 628 |
tavily_keys_str = os.getenv("TAVILY_API_KEY", "")
|
| 629 |
if tavily_keys_str:
|
|
|
|
| 638 |
logger.error(f"Error during Tavily init: {e}")
|
| 639 |
self.tavily = None
|
| 640 |
|
| 641 |
+
def _initialize_cache(self):
|
| 642 |
+
"""Initialize Redis cache if available."""
|
| 643 |
self.use_redis = os.getenv("USE_REDIS", "false").lower() == "true"
|
| 644 |
self.cache = None
|
| 645 |
if self.use_redis and RedisCache:
|
|
|
|
| 648 |
except Exception:
|
| 649 |
pass
|
| 650 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 651 |
def _get_model_priority_score(self, model_name: str) -> int:
|
| 652 |
+
"""Get priority score for model selection."""
|
| 653 |
priority_map = {
|
| 654 |
"gemini-2.5-flash-lite": 100, "gemini-2.5-flash": 95, "gemini-2.0-flash-lite": 90,
|
| 655 |
"gemini-2.0-flash": 85, "gemini-2.5-pro": 80, "gemini-1.5-flash": 75, "gemini-1.5-pro": 60
|
|
|
|
| 657 |
return priority_map.get(model_name, 10)
|
| 658 |
|
| 659 |
def _initialize_llm(self):
|
| 660 |
+
"""Initialize LLM with fallback support."""
|
| 661 |
google_api_key = os.getenv("GOOGLE_API_KEY")
|
| 662 |
if not google_api_key or ChatGoogleGenerativeAI is None:
|
| 663 |
return None
|
| 664 |
+
|
| 665 |
sorted_models = sorted(self.available_models, key=self._get_model_priority_score, reverse=True)
|
| 666 |
for model_name in sorted_models:
|
| 667 |
try:
|
| 668 |
llm = ChatGoogleGenerativeAI(
|
| 669 |
+
model=model_name, temperature=0.3, max_tokens=4096,
|
| 670 |
google_api_key=google_api_key, timeout=30, max_retries=1
|
| 671 |
)
|
| 672 |
if hasattr(llm, "invoke"):
|
| 673 |
_ = llm.invoke("ping")
|
| 674 |
elif hasattr(llm, "generate"):
|
| 675 |
_ = llm.generate("ping")
|
| 676 |
+
|
| 677 |
self.current_model = model_name
|
| 678 |
+
self.model_performance_stats[model_name] = {
|
| 679 |
+
"response_time": 0.0, "success_rate": 1.0,
|
| 680 |
+
"last_used": time.time(), "total_requests": 0, "successful_requests": 0
|
| 681 |
+
}
|
| 682 |
return llm
|
| 683 |
except Exception:
|
| 684 |
continue
|
| 685 |
+
|
| 686 |
logger.error("No LLM models could be initialized.")
|
| 687 |
return None
|
| 688 |
|
| 689 |
def _fallback_to_next_model(self) -> bool:
|
| 690 |
+
"""Attempt to fallback to next available model."""
|
| 691 |
if ChatGoogleGenerativeAI is None:
|
| 692 |
return False
|
| 693 |
+
|
| 694 |
try:
|
| 695 |
current_index = self.available_models.index(self.current_model) if self.current_model in self.available_models else -1
|
| 696 |
remaining = self.available_models[current_index + 1:] if current_index >= 0 else self.available_models
|
| 697 |
except Exception:
|
| 698 |
remaining = self.available_models
|
| 699 |
+
|
| 700 |
remaining = sorted(remaining, key=self._get_model_priority_score, reverse=True)
|
| 701 |
for model in remaining:
|
| 702 |
try:
|
| 703 |
+
llm = ChatGoogleGenerativeAI(
|
| 704 |
+
model=model, temperature=0.3, max_tokens=4096,
|
| 705 |
+
google_api_key=os.getenv("GOOGLE_API_KEY"), timeout=30, max_retries=1
|
| 706 |
+
)
|
| 707 |
+
if hasattr(llm, "invoke"):
|
| 708 |
+
_ = llm.invoke("ping")
|
| 709 |
+
elif hasattr(llm, "generate"):
|
| 710 |
+
_ = llm.generate("ping")
|
| 711 |
+
else:
|
| 712 |
+
continue
|
| 713 |
+
|
| 714 |
self.llm = llm
|
| 715 |
self.current_model = model
|
| 716 |
logger.info(f"Fell back to {model}")
|
|
|
|
| 720 |
return False
|
| 721 |
|
| 722 |
def _update_model_stats(self, model_name: str, success: bool, response_time: float = None, error: str = None):
|
| 723 |
+
"""Update model performance statistics."""
|
| 724 |
if model_name not in self.model_performance_stats:
|
| 725 |
+
self.model_performance_stats[model_name] = {
|
| 726 |
+
"total_requests": 0, "successful_requests": 0,
|
| 727 |
+
"response_time": None, "success_rate": 0.0
|
| 728 |
+
}
|
| 729 |
+
|
| 730 |
stats = self.model_performance_stats[model_name]
|
| 731 |
stats["total_requests"] = stats.get("total_requests", 0) + 1
|
| 732 |
+
|
| 733 |
if success:
|
| 734 |
stats["successful_requests"] = stats.get("successful_requests", 0) + 1
|
| 735 |
stats["response_time"] = response_time
|
| 736 |
stats["last_used"] = time.time()
|
| 737 |
else:
|
| 738 |
+
if error:
|
| 739 |
+
stats["last_error"] = error
|
| 740 |
+
|
| 741 |
total = stats["total_requests"]
|
| 742 |
stats["success_rate"] = stats.get("successful_requests", 0) / total if total > 0 else 0.0
|
| 743 |
|
| 744 |
+
def _setup_intent_chain(self):
|
| 745 |
+
"""Setup intent classification chain."""
|
| 746 |
+
self.intent_chain = None
|
| 747 |
+
if self.llm and LLMChain and PromptTemplate:
|
| 748 |
+
self.intent_template = """You are an intent classifier. Respond only with one of: educational_guidance, career_recommendation, resume_advice, interview_prep, salary_info, general_guidance, off_topic. Query: {query}"""
|
| 749 |
try:
|
| 750 |
+
self.intent_chain = LLMChain(llm=self.llm, prompt=PromptTemplate.from_template(self.intent_template))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 751 |
except Exception as e:
|
| 752 |
+
logger.error(f"Failed to create intent chain: {e}")
|
| 753 |
+
self.intent_chain = None
|
| 754 |
+
|
| 755 |
+
def _start_fine_tune_worker(self):
|
| 756 |
+
"""Start background fine-tuning worker thread."""
|
| 757 |
+
self._fine_tune_lock = threading.Lock()
|
| 758 |
+
self._stop_fine_tune_worker = False
|
| 759 |
+
self._fine_tune_thread = None
|
| 760 |
+
|
| 761 |
+
if torch and self.model is not None and self.tokenizer is not None:
|
| 762 |
+
try:
|
| 763 |
+
self._fine_tune_thread = threading.Thread(target=self._fine_tune_loop_sync, daemon=True)
|
| 764 |
+
self._fine_tune_thread.start()
|
| 765 |
+
logger.info("✅ Background fine-tune worker started.")
|
| 766 |
+
except Exception as e:
|
| 767 |
+
logger.error(f"Failed to start fine-tune background worker: {e}")
|
| 768 |
+
|
| 769 |
+
# ============================================
|
| 770 |
+
# CORE HEURISTICS
|
| 771 |
+
# ============================================
|
| 772 |
+
|
| 773 |
+
def _should_use_web_search(self, intent: str, query: str, history: List[Any] = None) -> bool:
|
| 774 |
+
"""Determine if web search should be used for this query."""
|
| 775 |
+
if not query:
|
| 776 |
+
return False
|
| 777 |
+
|
| 778 |
+
q = query.lower()
|
| 779 |
+
|
| 780 |
+
# Force triggers
|
| 781 |
+
force_triggers = [
|
| 782 |
+
"with sources", "with source", "show sources", "cite", "sources",
|
| 783 |
+
"verify", "search web", "web search", "please search", "please look up",
|
| 784 |
+
"look up", "confirm from", "confirm that"
|
| 785 |
+
]
|
| 786 |
+
if any(t in q for t in force_triggers):
|
| 787 |
+
return True
|
| 788 |
+
|
| 789 |
+
if intent == "salary_info":
|
| 790 |
+
return True
|
| 791 |
+
|
| 792 |
+
# Web triggers
|
| 793 |
+
web_triggers = [
|
| 794 |
+
"latest", "current", "202", "trend", "trends", "salary", "average",
|
| 795 |
+
"median", "top", "emerging", "statistics", "how much", "pay", "ctc",
|
| 796 |
+
"package", "percent", "percentile", "growth", "outlook"
|
| 797 |
+
]
|
| 798 |
+
if any(w in q for w in web_triggers):
|
| 799 |
+
return True
|
| 800 |
+
|
| 801 |
+
# Skip web search for greetings and short queries
|
| 802 |
+
if q.strip() in self.greetings or len(q.split()) <= 4:
|
| 803 |
+
return False
|
| 804 |
+
|
| 805 |
+
return False
|
| 806 |
+
|
| 807 |
+
def _is_illegal_request(self, query: str) -> bool:
|
| 808 |
+
"""Check if the request is for illegal content."""
|
| 809 |
+
if not query:
|
| 810 |
+
return False
|
| 811 |
+
|
| 812 |
+
q = query.lower()
|
| 813 |
+
for pattern in _ILLEGAL_TRIGGERS:
|
| 814 |
+
if re.search(pattern, q):
|
| 815 |
+
return True
|
| 816 |
+
|
| 817 |
+
risky = ["how to make", "how to build a", "how to bypass", "how to hack", "evade", "explosive", "make poison", "sell drugs"]
|
| 818 |
+
if any(r in q for r in risky) and any(word in q for word in ["bomb", "poison", "explode", "assassin", "hack", "illicit", "illegal", "fraud"]):
|
| 819 |
+
return True
|
| 820 |
+
|
| 821 |
+
return False
|
| 822 |
+
|
| 823 |
+
def _detect_country_language(self, query: str) -> Dict[str, Optional[str]]:
|
| 824 |
+
"""Detect country and language mentions in query."""
|
| 825 |
+
if not query:
|
| 826 |
+
return {"country": None, "language": None}
|
| 827 |
+
|
| 828 |
+
q = query.lower()
|
| 829 |
+
country_found = next((c for c in _COUNTRY_KEYWORDS if re.search(r"\b" + re.escape(c) + r"\b", q)), None)
|
| 830 |
+
language_found = next((l for l in _LANGUAGE_KEYWORDS if re.search(r"\b" + re.escape(l) + r"\b", q)), None)
|
| 831 |
+
return {"country": country_found, "language": language_found}
|
| 832 |
+
|
| 833 |
+
def _detect_persona(self, query: str) -> str:
|
| 834 |
+
"""Detect user persona from query."""
|
| 835 |
+
if not query:
|
| 836 |
+
return "Other"
|
| 837 |
+
|
| 838 |
+
q = query.lower()
|
| 839 |
+
if any(k in q for k in ["i am a student", "student", "grade", "class", "college", "undergraduate", "btech", "mba", "high school"]):
|
| 840 |
+
return "Student"
|
| 841 |
+
if any(k in q for k in ["i am a teacher", "teacher", "instructor", "professor", "lecturer"]):
|
| 842 |
+
return "Teacher"
|
| 843 |
+
if any(k in q for k in ["my child", "parent", "mother", "father", "guardian"]):
|
| 844 |
+
return "Parent"
|
| 845 |
+
return "Other"
|
| 846 |
+
|
| 847 |
+
def _generate_mandatory_search_queries(self, user_query: str, country: Optional[str], language: Optional[str]) -> List[str]:
|
| 848 |
+
"""Generate search queries based on geo/language context."""
|
| 849 |
+
searches = []
|
| 850 |
+
base = user_query.strip()
|
| 851 |
+
|
| 852 |
+
if country:
|
| 853 |
+
searches.append(f"{base} {country} official requirements site:gov OR site:.edu")
|
| 854 |
+
searches.append(f"{base} {country} curriculum requirements OR regulations")
|
| 855 |
+
elif language:
|
| 856 |
+
searches.append(f"{base} {language} language learning resources official exams")
|
| 857 |
+
searches.append(f"{base} {language} proficiency exam requirements OR recognized certifications")
|
| 858 |
+
else:
|
| 859 |
+
searches.append(f"{base} official guidance")
|
| 860 |
+
searches.append(f"{base} statistics OR latest data")
|
| 861 |
+
|
| 862 |
+
return list(set(searches))
|
| 863 |
+
|
| 864 |
+
# ============================================
|
| 865 |
+
# LLM INVOCATION
|
| 866 |
+
# ============================================
|
| 867 |
|
| 868 |
async def _call_direct_llm(self, prompt: str, max_retries: int = 2) -> str:
|
| 869 |
+
"""Call LLM directly with retry logic."""
|
| 870 |
if self.llm is None:
|
| 871 |
return "LLM not available. Enable GOOGLE_API_KEY and ensure dependencies are installed."
|
| 872 |
+
|
| 873 |
for attempt in range(max_retries):
|
| 874 |
try:
|
| 875 |
start = time.time()
|
|
|
|
| 882 |
full_response_text = res.content if hasattr(res, 'content') else str(res)
|
| 883 |
else:
|
| 884 |
return "LLM present but has no recognized call method."
|
| 885 |
+
|
| 886 |
self._update_model_stats(self.current_model, True, time.time() - start)
|
| 887 |
return full_response_text
|
| 888 |
+
|
| 889 |
except Exception as e:
|
| 890 |
self._update_model_stats(self.current_model, False, error=str(e))
|
| 891 |
msg = str(e).lower()
|
|
|
|
| 895 |
continue
|
| 896 |
else:
|
| 897 |
raise RuntimeError("All models failed.")
|
| 898 |
+
|
| 899 |
if attempt < max_retries - 1:
|
| 900 |
await asyncio.sleep(2 ** attempt)
|
| 901 |
else:
|
| 902 |
logger.error(f"Direct LLM call failed after {max_retries} attempts: {e}")
|
| 903 |
raise
|
| 904 |
+
|
| 905 |
return "I encountered an error while generating the response after multiple retries."
|
| 906 |
|
| 907 |
+
# ============================================
|
| 908 |
+
# TAVILY SEARCH WITH ROTATION
|
| 909 |
+
# ============================================
|
| 910 |
+
|
| 911 |
async def _rotate_tavily_key(self, query: str, max_retries: int = None) -> list:
|
| 912 |
+
"""Perform Tavily search with key rotation on failure."""
|
| 913 |
if not getattr(self, "tavily_key_pool", None) or not getattr(self, "tavily_keys_list", None):
|
| 914 |
return []
|
| 915 |
+
|
| 916 |
if max_retries is None:
|
| 917 |
max_retries = min(3, len(self.tavily_keys_list))
|
| 918 |
+
|
| 919 |
for attempt in range(max_retries):
|
| 920 |
try:
|
| 921 |
if self.current_tavily_key:
|
| 922 |
os.environ["TAVILY_API_KEY"] = self.current_tavily_key
|
| 923 |
+
try:
|
| 924 |
+
self.tavily = create_tavily_retriever_safe(k=10, logger=logger)
|
| 925 |
+
except Exception:
|
| 926 |
+
pass
|
| 927 |
+
|
| 928 |
search_docs = await tavily_search_safe(self.tavily, query, logger=logger)
|
| 929 |
if search_docs:
|
| 930 |
normalized = []
|
|
|
|
| 936 |
content = getattr(doc, "page_content", None) or (doc.get("content") if isinstance(doc, dict) else str(doc))
|
| 937 |
normalized.append(SimpleDoc(source or "", content or "", title=title or "", score=score))
|
| 938 |
return normalized
|
| 939 |
+
|
| 940 |
+
# Try REST fallback
|
| 941 |
if self.current_tavily_key:
|
| 942 |
rest_results = await tavily_rest_search(self.current_tavily_key, query)
|
| 943 |
if rest_results:
|
| 944 |
+
normalized = [
|
| 945 |
+
SimpleDoc(
|
| 946 |
+
r.get("url") or r.get("source") or "",
|
| 947 |
+
r.get("content") or r.get("title") or r.get("snippet") or str(r),
|
| 948 |
+
title=r.get("title") or "",
|
| 949 |
+
score=r.get("score")
|
| 950 |
+
) for r in rest_results
|
| 951 |
+
]
|
| 952 |
return normalized
|
| 953 |
+
|
| 954 |
+
# Rotate key
|
| 955 |
if attempt < max_retries - 1:
|
| 956 |
try:
|
| 957 |
self.current_tavily_key = next(self.tavily_key_pool)
|
|
|
|
| 960 |
continue
|
| 961 |
else:
|
| 962 |
break
|
| 963 |
+
|
| 964 |
except UsageLimitExceededError:
|
| 965 |
if attempt < max_retries - 1:
|
| 966 |
try:
|
|
|
|
| 980 |
continue
|
| 981 |
else:
|
| 982 |
break
|
| 983 |
+
|
| 984 |
logger.error("🚫 All Tavily attempts failed. Falling back to no web context.")
|
| 985 |
return []
|
| 986 |
|
| 987 |
+
# ============================================
|
| 988 |
+
# INTENT CLASSIFICATION
|
| 989 |
+
# ============================================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 990 |
|
| 991 |
+
    async def classify_intent(self, query: str) -> str:
        """Classify the user query into one of the counselor intents.

        Resolution order (cheapest first):
          1. Redis cache lookup keyed by a SHA-256 of the query.
          2. Learned per-user keyword scores (self.user_keywords).
          3. Local DistilBert classifier, when model/tokenizer/label map exist.
          4. Keyword heuristics when no LLM intent chain is configured.
          5. The LLM intent chain as the last resort.

        Returns one of: educational_guidance, career_recommendation,
        resume_advice, interview_prep, salary_info, general_guidance, off_topic.
        """
        # 1) Cache lookup. NOTE: `key` is only defined when self.cache is
        # truthy; every later cache write is likewise guarded by `if self.cache`,
        # so the name is never referenced unbound.
        if self.cache:
            key = f"intent_{hashlib.sha256(query.encode()).hexdigest()}"
            cached = self.cache.get(key)
            if cached:
                return cached

        # 2) Score intents by summed weights of learned keywords in the query.
        tokens = [t for t in re.findall(r"\b[a-zA-Z]{2,}\b", query.lower())]
        intent_scores = Counter()
        for intent_name, token_map in (self.user_keywords or {}).items():
            for t in tokens:
                intent_scores[intent_name] += token_map.get(t, 0)

        if intent_scores:
            top_intent, top_score = intent_scores.most_common(1)[0]
            # Accept the keyword vote only when reasonably confident: a score
            # of 2+, or a lone candidate intent with any positive score.
            if top_score >= 2 or (len(intent_scores) == 1 and top_score >= 1):
                if self.cache:
                    try:
                        self.cache.set(key, top_intent, ttl=3600)
                    except Exception:
                        pass
                return top_intent

        # 3) Local fine-tuned classifier (skipped unless torch + model +
        # tokenizer are all loaded).
        if self.model is not None and self.tokenizer is not None and torch is not None:
            try:
                self.model.eval()
                with torch.no_grad():
                    enc = self.tokenizer(query, return_tensors="pt", truncation=True, padding=True, max_length=128)
                    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                    for k, v in enc.items():
                        enc[k] = v.to(device)
                    self.model.to(device)
                    outputs = self.model(**enc)
                    logits = outputs.logits.cpu().numpy().tolist()[0]

                # Label map (intent name -> class index) is written by the
                # fine-tune worker; without it the logits cannot be decoded.
                label_map = {}
                if self.finetune_label_map_path.exists():
                    try:
                        with open(self.finetune_label_map_path, "r", encoding="utf-8") as f:
                            label_map = json.load(f)
                    except Exception:
                        label_map = {}

                if label_map:
                    id_to_label = {int(v): k for k, v in label_map.items()}
                    pred_idx = int(max(range(len(logits)), key=lambda i: logits[i]))  # argmax over logits
                    intent = id_to_label.get(pred_idx, "general_guidance")
                    if self.cache:
                        try:
                            self.cache.set(key, intent, ttl=3600)
                        except Exception:
                            pass
                    return intent
            except Exception:
                logger.debug("Local classifier prediction failed; falling back to heuristics.")

        # 4) Heuristic keyword fallback when no LLM chain is available.
        if self.intent_chain is None:
            q = query.lower()
            if any(w in q for w in ["resume", "cv", "cover letter"]):
                intent = "resume_advice"
            elif any(w in q for w in ["interview", "star method", "technical interview", "hr round"]):
                intent = "interview_prep"
            elif any(w in q for w in ["salary", "ctc", "package", "pay"]):
                intent = "salary_info"
            elif any(w in q for w in ["which course", "which college", "what should i study", "career", "i like"]):
                intent = "career_recommendation"
            elif any(w in q for w in ["school", "exam", "jee", "neet", "admission"]):
                intent = "educational_guidance"
            else:
                intent = "general_guidance"

            if self.cache:
                try:
                    self.cache.set(key, intent, ttl=3600)
                except Exception:
                    pass
            return intent

        # 5) Ask the LLM chain and validate its answer against the known set.
        try:
            response = await self._safe_llm_invoke(self.intent_chain, {"query": query})
            intent_text = response.get("text", "") if isinstance(response, dict) else str(response)
            intent = intent_text.strip().lower().replace(".", "")
            valid = ["educational_guidance", "career_recommendation", "resume_advice", "interview_prep", "salary_info", "general_guidance", "off_topic"]
            if intent not in valid:
                intent = "general_guidance"

            if self.cache:
                try:
                    self.cache.set(key, intent, ttl=3600)
                except Exception:
                    pass
            return intent
        except Exception as e:
            logger.error(f"Intent classification failed: {e}")
            return "general_guidance"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1092 |
|
| 1093 |
+
    async def _safe_llm_invoke(self, chain, params: Dict[str, Any], max_retries: int = 2) -> Any:
        """Invoke an LLM chain with retries, backoff, and model fallback.

        Args:
            chain: A LangChain chain exposing ``ainvoke`` or ``invoke``.
            params: Template variables passed through to the chain.
            max_retries: Attempts before the last error is re-raised.

        Raises:
            RuntimeError: When no chain/LLM is configured, or when every
                fallback model has been exhausted.
        """
        if chain is None or self.llm is None:
            raise RuntimeError("LLM chain or LLM not available.")

        for attempt in range(max_retries):
            try:
                start = time.time()
                if hasattr(chain, "ainvoke"):
                    res = await chain.ainvoke(params)
                else:
                    # Synchronous chains run on the default executor so the
                    # event loop is not blocked.
                    loop = asyncio.get_event_loop()
                    res = await loop.run_in_executor(None, lambda: chain.invoke(params))
                self._update_model_stats(self.current_model, True, time.time() - start)
                return res
            except Exception as e:
                self._update_model_stats(self.current_model, False, error=str(e))
                msg = str(e).lower()
                # "Model missing" style errors trigger an immediate fallback to
                # the next configured model instead of a plain retry.
                if any(k in msg for k in ["not found", "404", "not supported", "invalid model", "model does not exist"]):
                    if self._fallback_to_next_model():
                        logger.info("Retrying after fallback model selection.")
                        continue
                    else:
                        raise RuntimeError("All models failed.")
                if attempt < max_retries - 1:
                    await asyncio.sleep(2 ** attempt)  # exponential backoff: 1s, 2s, ...
                else:
                    raise
|
| 1121 |
|
| 1122 |
+
# ============================================
|
| 1123 |
+
# MAIN API - GET COMPREHENSIVE ANSWER
|
| 1124 |
+
# ============================================
|
| 1125 |
|
| 1126 |
+
    async def get_comprehensive_answer(self, user_query: str, session_id: str) -> AsyncGenerator[str, None]:
        """Generate a comprehensive, adaptively-formatted answer.

        Pipeline: load chat history -> short-circuit greetings and illegal
        requests -> detect persona/geo/intent/query-type -> optionally run a
        Tavily web search -> build an adaptive prompt -> call the LLM ->
        clean the response and append sources -> persist history and learn.

        Args:
            user_query: Raw user question.
            session_id: Conversation key used for history load/save.

        Yields:
            Exactly one string: the final answer (or a refusal/error message).
        """

        # Load prior turns for this session; a DB failure is non-fatal and
        # simply leaves the history empty.
        history = []
        try:
            if self.db:
                history = self.db.get_history(session_id)
        except Exception:
            logger.debug("History fetch failed.")

        # NOTE(review): assumes each history item is a [user, bot] pair —
        # matches the save_history calls below; confirm against SessionDB.
        history_str = "\n".join([f"User: {h[0]}\nBot: {h[1]}" for h in history]) if history else "No history yet."

        web_context = "No web search required or performed."
        sources_text = ""

        normalized_query = user_query.strip()
        # Lowercase and drop trailing punctuation so "Hi!" matches "hi".
        normalized_lower = normalized_query.lower().rstrip(".!?")

        # Handle greetings: canned reply, save it to history, record the
        # interaction for learning, then stop (no LLM call).
        if normalized_lower in self.greetings:
            greeting_response = "Hello! I'm your AI education & career counselor. How can I assist you with your education or career goals today?"
            yield greeting_response
            try:
                if self.db:
                    self.db.save_history(session_id, history + [[user_query, greeting_response]])
            except Exception:
                pass
            try:
                self._learn_from_interaction(user_query, "general_guidance")
            except Exception:
                pass
            return

        # Check for illegal requests: refuse without persisting or learning.
        if self._is_illegal_request(user_query):
            refusal = "I cannot assist with requests that enable illegal or harmful activities. I can help with lawful education, career guidance, coding practice, and study resources — please rephrase your question."
            yield refusal
            return

        try:
            # Detect conversational context (persona, country/language, intent).
            persona = self._detect_persona(user_query)
            geo_lang = self._detect_country_language(user_query)
            country = geo_lang.get("country")
            language = geo_lang.get("language")
            intent = await self.classify_intent(user_query)

            # ★ KEY INNOVATION: Detect query type for adaptive formatting
            query_type = detect_query_type(user_query)
            logger.info(f"Intent: {intent}, Query Type: {query_type}")

            # Determine if web search needed. NOTE: the `or (country or
            # language)` may make use_web a string, which is fine — it is only
            # used for truthiness below.
            use_web = self._should_use_web_search(intent, user_query, history) or (country or language)

            # Perform web search if needed
            combined_docs: List[SimpleDoc] = []
            if use_web:
                search_queries: List[str] = []
                if country or language:
                    search_queries.extend(self._generate_mandatory_search_queries(user_query, country, language))
                if intent == "salary_info" or self._should_use_web_search(intent, user_query):
                    search_queries.append(f"Latest verified information for: {user_query}")

                # De-duplicate queries (set order is arbitrary, which is
                # acceptable here since all queries are executed).
                search_queries = list(set(search_queries))
                for sq in search_queries:
                    docs = await self._rotate_tavily_key(sq)
                    combined_docs.extend(docs)

                if combined_docs:
                    # De-duplicate documents by content, cap at 10 results.
                    unique_docs = {d.page_content: d for d in combined_docs}
                    final_docs = list(unique_docs.values())[:10]
                    web_context = "\n\n".join([
                        f"Source: {getattr(doc, 'metadata', {}).get('source', 'N/A')}\n"
                        f"Title: {getattr(doc, 'metadata', {}).get('title', '')}\n"
                        f"Content: {getattr(doc, 'page_content', '')}"
                        for doc in final_docs
                    ])
                    sources_text = format_sources_block(final_docs)
                else:
                    web_context = "Web search performed but returned no highly relevant results."
                    sources_text = ""

            # Get RAG context. NOTE(review): this only flags that a knowledge
            # base exists — no actual retrieval happens here; confirm whether
            # self.rag should be queried.
            rag_context = "No RAG content"
            if self.rag:
                rag_context = "Local knowledge base accessed and utilized."

            # ★ Generate adaptive prompt based on query type
            prompt = get_adaptive_prompt(
                query_type=query_type,
                persona=persona,
                intent=intent,
                web_context=web_context,
                rag_context=rag_context,
                history_str=history_str,
                user_query=user_query
            )

            # Get LLM response
            generated_answer_text = await self._call_direct_llm(prompt)

            if not generated_answer_text:
                raise RuntimeError("LLM returned an empty response.")

            # ★ Clean up response (remove any residual XML tags)
            final_response = self._clean_response(generated_answer_text)

            # Add sources if available — but not for quick/conversational
            # answers, where a sources block would be noise.
            if sources_text and query_type not in [QueryType.QUICK, QueryType.CONVERSATIONAL]:
                final_response = final_response.strip() + "\n\n" + sources_text

            yield final_response

            # Save to history (best-effort).
            if self.db:
                try:
                    self.db.save_history(session_id, history + [[user_query, final_response]])
                except Exception:
                    pass

            # Learn from interaction (best-effort).
            try:
                self._learn_from_interaction(user_query, intent)
            except Exception:
                logger.debug("Learning step failed (non-fatal).")

        except Exception as e:
            # Any failure in the pipeline degrades to a generic apology that
            # is still persisted to history so the conversation stays coherent.
            logger.error(f"❌ Error in get_comprehensive_answer: {e}", exc_info=True)
            error_msg = f"I'm sorry, I'm encountering an unexpected error while processing your request. Please try again."
            if self.db:
                try:
                    self.db.save_history(session_id, history + [[user_query, error_msg]])
                except Exception:
                    pass
            yield error_msg
|
| 1262 |
|
| 1263 |
+
def _clean_response(self, text: str) -> str:
|
| 1264 |
+
"""Clean up LLM response by removing XML tags and fixing formatting."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1265 |
if not text:
|
| 1266 |
+
return ""
|
| 1267 |
+
|
| 1268 |
+
# Remove XML tags
|
| 1269 |
+
text = re.sub(r'</?summary>', '', text, flags=re.IGNORECASE)
|
| 1270 |
+
text = re.sub(r'</?explanation>', '', text, flags=re.IGNORECASE)
|
| 1271 |
+
text = re.sub(r'</?insights>', '', text, flags=re.IGNORECASE)
|
| 1272 |
+
|
| 1273 |
+
# Remove rigid section headers if they don't add value
|
| 1274 |
+
text = re.sub(r'^## Summary\s*\n', '', text, flags=re.MULTILINE)
|
| 1275 |
+
text = re.sub(r'^## Detailed Explanation\s*\n', '', text, flags=re.MULTILINE)
|
| 1276 |
+
text = re.sub(r'^## Relevant Insights\s*\n', '', text, flags=re.MULTILINE)
|
| 1277 |
+
text = re.sub(r'^Detailed Explanation\s*\n', '', text, flags=re.MULTILINE)
|
| 1278 |
+
|
| 1279 |
+
# Clean up extra whitespace while preserving intentional line breaks
|
| 1280 |
+
text = re.sub(r'\n{3,}', '\n\n', text)
|
| 1281 |
+
|
| 1282 |
+
return text.strip()
|
| 1283 |
+
|
| 1284 |
+
# ============================================
|
| 1285 |
+
# LEARNING AND PERSISTENCE
|
| 1286 |
+
# ============================================
|
| 1287 |
|
| 1288 |
+
def _persist_user_keywords(self):
|
| 1289 |
+
"""Persist user keywords to file."""
|
| 1290 |
+
try:
|
| 1291 |
+
with open(self.user_keywords_path, "w", encoding="utf-8") as f:
|
| 1292 |
+
json.dump(self.user_keywords, f, ensure_ascii=False, indent=2)
|
| 1293 |
+
except Exception as e:
|
| 1294 |
+
logger.debug(f"Failed to persist user keywords: {e}")
|
| 1295 |
|
| 1296 |
+
def _persist_greetings(self):
|
| 1297 |
+
"""Persist greetings to file."""
|
| 1298 |
+
try:
|
| 1299 |
+
with open(self.user_greetings_path, "w", encoding="utf-8") as f:
|
| 1300 |
+
json.dump({"greetings": sorted(list(self.greetings))}, f, ensure_ascii=False, indent=2)
|
| 1301 |
+
except Exception as e:
|
| 1302 |
+
logger.debug(f"Failed to persist user greetings: {e}")
|
| 1303 |
|
| 1304 |
+
def _learn_from_interaction(self, query: str, intent: str):
|
| 1305 |
+
"""Learn from user interaction for continuous improvement."""
|
| 1306 |
+
if not query:
|
| 1307 |
+
return
|
| 1308 |
+
|
| 1309 |
+
q = query.strip()
|
| 1310 |
+
|
| 1311 |
+
# Save to corpus
|
| 1312 |
+
try:
|
| 1313 |
+
with open(self.user_corpus_path, "a", encoding="utf-8") as f:
|
| 1314 |
+
f.write(q + "\n")
|
| 1315 |
+
except Exception:
|
| 1316 |
+
pass
|
| 1317 |
+
|
| 1318 |
+
# Extract and save keywords
|
| 1319 |
+
tokens = [t for t in re.findall(r"\b[a-zA-Z]{2,}\b", q.lower()) if len(t) > 1]
|
| 1320 |
+
|
| 1321 |
+
if len(tokens) <= 2 and q.lower() not in {"", "ok", "thanks", "thank you"}:
|
| 1322 |
+
if any(g in q.lower() for g in ["hi", "hello", "hey", "hlo", "hiii", "hii"]):
|
| 1323 |
+
self.greetings.add(q.lower())
|
| 1324 |
+
self._persist_greetings()
|
| 1325 |
+
|
| 1326 |
+
if intent not in self.user_keywords:
|
| 1327 |
+
self.user_keywords[intent] = {}
|
| 1328 |
+
|
| 1329 |
+
token_counts = self.user_keywords.get(intent, {})
|
| 1330 |
+
for t in tokens:
|
| 1331 |
+
token_counts[t] = token_counts.get(t, 0) + 1
|
| 1332 |
+
self.user_keywords[intent] = token_counts
|
| 1333 |
+
self._persist_user_keywords()
|
| 1334 |
+
|
| 1335 |
+
# Persist fine-tune example
|
| 1336 |
+
try:
|
| 1337 |
+
self._persist_fine_tune_example(q, intent)
|
| 1338 |
+
except Exception:
|
| 1339 |
+
logger.debug("Failed to persist fine-tune example (non-fatal).")
|
| 1340 |
|
|
|
|
| 1341 |
def _persist_fine_tune_example(self, text: str, label: str) -> None:
|
| 1342 |
+
"""Persist fine-tune example to Hugging Face dataset."""
|
| 1343 |
try:
|
| 1344 |
line = json.dumps({"text": text, "label": label}, ensure_ascii=False)
|
| 1345 |
with open(self.local_examples_path, "a", encoding="utf-8") as f:
|
| 1346 |
f.write(line + "\n")
|
| 1347 |
+
|
| 1348 |
api = HfApi()
|
| 1349 |
api.upload_file(
|
| 1350 |
path_or_fileobj=self.local_examples_path,
|
|
|
|
| 1357 |
logger.debug(f"Failed to persist fine-tune example to Hub: {e}")
|
| 1358 |
|
| 1359 |
def _load_fine_tune_examples(self) -> List[Dict[str, str]]:
|
| 1360 |
+
"""Load fine-tune examples from Hugging Face."""
|
| 1361 |
try:
|
| 1362 |
hf_hub_download(
|
| 1363 |
repo_id=self.dataset_repo_id,
|
|
|
|
| 1367 |
token=os.environ.get("HF_WRITE_TOKEN"),
|
| 1368 |
force_filename=self.examples_filename_in_repo
|
| 1369 |
)
|
| 1370 |
+
|
| 1371 |
if not self.local_examples_path.exists():
|
| 1372 |
logger.info("No examples file found in dataset repo.")
|
| 1373 |
return []
|
| 1374 |
+
|
| 1375 |
with open(self.local_examples_path, "r", encoding="utf-8") as f:
|
| 1376 |
lines = [json.loads(l) for l in f if l.strip()]
|
| 1377 |
return lines
|
|
|
|
| 1380 |
return []
|
| 1381 |
|
| 1382 |
def _clear_fine_tune_examples(self, archive: bool = True):
|
| 1383 |
+
"""Clear fine-tune examples (optionally archive first)."""
|
| 1384 |
api = HfApi()
|
| 1385 |
try:
|
| 1386 |
if archive:
|
|
|
|
| 1400 |
repo_type="dataset",
|
| 1401 |
token=os.environ.get("HF_WRITE_TOKEN")
|
| 1402 |
)
|
| 1403 |
+
|
| 1404 |
for f in glob.glob(f"./{self.examples_filename_in_repo}*"):
|
| 1405 |
try:
|
| 1406 |
os.remove(f)
|
| 1407 |
except Exception:
|
| 1408 |
pass
|
| 1409 |
+
|
| 1410 |
logger.info("Archived examples file in dataset repo.")
|
|
|
|
| 1411 |
except Exception as e:
|
| 1412 |
logger.debug(f"Failed to clear/archive examples in Hub (non-fatal): {e}")
|
| 1413 |
|
| 1414 |
+
# ============================================
|
| 1415 |
+
# FINE-TUNE WORKER
|
| 1416 |
+
# ============================================
|
| 1417 |
+
|
| 1418 |
def _fine_tune_loop_sync(self):
|
| 1419 |
+
"""Background fine-tuning loop."""
|
| 1420 |
logger.info("Fine-tune loop running.")
|
| 1421 |
while not getattr(self, "_stop_fine_tune_worker", False):
|
| 1422 |
try:
|
|
|
|
| 1426 |
time.sleep(max(10, self.fine_tune_interval))
|
| 1427 |
|
| 1428 |
def _maybe_fine_tune_once(self):
|
| 1429 |
+
"""Attempt a fine-tuning iteration if conditions are met."""
|
| 1430 |
if not self._fine_tune_lock.acquire(blocking=False):
|
| 1431 |
logger.debug("Fine-tune run already in progress; skipping this iteration.")
|
| 1432 |
return
|
| 1433 |
+
|
| 1434 |
try:
|
| 1435 |
examples = self._load_fine_tune_examples()
|
| 1436 |
if len(examples) < self.min_examples_to_train:
|
| 1437 |
logger.debug(f"Not enough examples for fine-tune (have {len(examples)}, need {self.min_examples_to_train}).")
|
| 1438 |
return
|
| 1439 |
+
|
| 1440 |
if not (torch and self.model is not None and self.tokenizer is not None):
|
| 1441 |
logger.warning("Fine-tune prerequisites missing (torch/model/tokenizer). Skipping training.")
|
| 1442 |
return
|
| 1443 |
+
|
| 1444 |
+
# Build label map
|
| 1445 |
label_to_id = {}
|
| 1446 |
if self.label_encoder is not None and hasattr(self.label_encoder, "classes_"):
|
| 1447 |
for idx, lab in enumerate(getattr(self.label_encoder, "classes_", [])):
|
|
|
|
| 1453 |
label_to_id = json.load(f)
|
| 1454 |
except Exception:
|
| 1455 |
label_to_id = {}
|
| 1456 |
+
|
| 1457 |
next_id = max(label_to_id.values()) + 1 if label_to_id else 0
|
| 1458 |
for ex in examples:
|
| 1459 |
lab = ex.get("label", "general_guidance")
|
| 1460 |
if lab not in label_to_id:
|
| 1461 |
label_to_id[lab] = next_id
|
| 1462 |
next_id += 1
|
| 1463 |
+
|
| 1464 |
try:
|
| 1465 |
with open(self.finetune_label_map_path, "w", encoding="utf-8") as f:
|
| 1466 |
json.dump(label_to_id, f, ensure_ascii=False, indent=2)
|
| 1467 |
except Exception:
|
| 1468 |
pass
|
| 1469 |
+
|
| 1470 |
+
# Prepare data
|
| 1471 |
texts = [ex["text"] for ex in examples]
|
| 1472 |
labels = [label_to_id.get(ex.get("label", "general_guidance"), 0) for ex in examples]
|
| 1473 |
+
|
| 1474 |
enc = self.tokenizer(texts, padding=True, truncation=True, max_length=128, return_tensors="pt")
|
| 1475 |
input_ids = enc["input_ids"]
|
| 1476 |
attention_mask = enc["attention_mask"]
|
| 1477 |
labels_tensor = torch.tensor(labels, dtype=torch.long)
|
| 1478 |
+
|
| 1479 |
dataset = TensorDataset(input_ids, attention_mask, labels_tensor)
|
| 1480 |
sampler = RandomSampler(dataset)
|
| 1481 |
loader = DataLoader(dataset, sampler=sampler, batch_size=self.fine_tune_batch_size)
|
| 1482 |
+
|
| 1483 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 1484 |
self.model.to(device)
|
| 1485 |
self.model.train()
|
| 1486 |
optimizer = AdamW(self.model.parameters(), lr=1e-5)
|
| 1487 |
+
|
|
|
|
| 1488 |
logger.info(f"Starting fine-tune: {len(examples)} examples, {len(loader)} batches, epochs={self.fine_tune_epochs}")
|
| 1489 |
+
|
| 1490 |
for epoch in range(self.fine_tune_epochs):
|
| 1491 |
epoch_loss = 0.0
|
| 1492 |
for batch in loader:
|
|
|
|
| 1500 |
optimizer.step()
|
| 1501 |
epoch_loss += loss.item() if loss is not None else 0.0
|
| 1502 |
logger.info(f"Fine-tune epoch {epoch+1}/{self.fine_tune_epochs} loss: {epoch_loss:.4f}")
|
| 1503 |
+
|
| 1504 |
+
# Save model
|
| 1505 |
try:
|
| 1506 |
self.model.save_pretrained(self.model_path)
|
| 1507 |
try:
|
|
|
|
| 1511 |
logger.info(f"✅ Fine-tuned model saved to {self.model_path}")
|
| 1512 |
except Exception as e:
|
| 1513 |
logger.error(f"Failed to save fine-tuned model: {e}")
|
| 1514 |
+
|
| 1515 |
self._clear_fine_tune_examples(archive=True)
|
| 1516 |
+
|
| 1517 |
finally:
|
| 1518 |
try:
|
| 1519 |
self._fine_tune_lock.release()
|
| 1520 |
except Exception:
|
| 1521 |
pass
|
| 1522 |
|
| 1523 |
+
# ============================================
|
| 1524 |
+
# CAREER PREDICTION
|
| 1525 |
+
# ============================================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1526 |
|
| 1527 |
async def predict_career(self, query: str) -> Dict[str, Any]:
|
| 1528 |
+
"""Predict career recommendation based on query."""
|
| 1529 |
if self.cache:
|
| 1530 |
key = f"predict_{hashlib.sha256(query.encode()).hexdigest()}"
|
| 1531 |
cached = self.cache.get(key)
|
| 1532 |
if cached:
|
| 1533 |
return cached
|
| 1534 |
+
|
| 1535 |
if not (self.model and self.tokenizer and torch and self.label_encoder is not None):
|
| 1536 |
return {"recommendation": None, "confidence": 0.0, "error": "Local prediction unavailable"}
|
| 1537 |
+
|
| 1538 |
try:
|
| 1539 |
inputs = self.tokenizer(query.lower(), return_tensors="pt", padding=True, truncation=True, max_length=128)
|
| 1540 |
with torch.no_grad():
|
|
|
|
| 1551 |
logger.error(f"Prediction failed: {e}")
|
| 1552 |
return {"recommendation": None, "confidence": 0.0, "error": str(e)}
|
| 1553 |
|
| 1554 |
+
# ============================================
|
| 1555 |
+
# HEALTH AND INFO
|
| 1556 |
+
# ============================================
|
| 1557 |
+
|
| 1558 |
def get_current_model_info(self) -> Dict[str, Any]:
|
| 1559 |
+
"""Get current model information."""
|
| 1560 |
return {
|
| 1561 |
"current_model": self.current_model,
|
| 1562 |
"available_models": self.available_models,
|
|
|
|
| 1568 |
}
|
| 1569 |
|
| 1570 |
def get_health_status(self) -> Dict[str, Any]:
|
| 1571 |
+
"""Get system health status."""
|
| 1572 |
try:
|
| 1573 |
total_models = len(self.available_models)
|
| 1574 |
working = sum(1 for s in self.model_performance_stats.values() if s.get("success_rate", 0) > 0)
|
|
|
|
| 1586 |
except Exception as e:
|
| 1587 |
return {"status": "error", "error": str(e), "last_updated": time.time()}
|
| 1588 |
|
| 1589 |
+
|
| 1590 |
+
# ============================================
|
| 1591 |
+
# DEMO
|
| 1592 |
+
# ============================================
|
| 1593 |
+
|
| 1594 |
if __name__ == "__main__":
    # Smoke-test driver: exercises one query per adaptive formatting mode
    # (quick answer, syllabus, how-to, comparison, roadmap, list) against a
    # fresh counselor instance and prints each streamed answer.
    async def demo():
        c = UltraAdvancedHybridCounselor()

        # (query, human-readable test description) pairs.
        test_queries = [
            ("What is machine learning?", "Quick answer test"),
            ("Give me syllabus for machine learning", "Syllabus test"),
            ("How to become a data scientist?", "How-to test"),
            ("Python vs JavaScript for web development", "Comparison test"),
            ("Give me roadmap for becoming a full stack developer", "Roadmap test"),
            ("Top 5 programming languages to learn in 2025", "List test"),
        ]

        for query, desc in test_queries:
            print(f"\n{'='*60}")
            print(f"TEST: {desc}")
            print(f"Query: {query}")
            print(f"{'='*60}")
            # get_comprehensive_answer is an async generator; consume every
            # yielded chunk (currently a single final answer per query).
            async for out in c.get_comprehensive_answer(query, session_id="demo"):
                print(out)
            print()

    try:
        asyncio.run(demo())
    except Exception as e:
        # Demo failures (missing API keys, no network, etc.) are logged
        # rather than crashing with a traceback.
        logger.error(f"Demo failed: {e}")
|