# visaverse-chatbot / faq_services.py
# Last commit: "Added Rate Limit and Other things" (0c0d03a) by shakauthossain
# updated faq_services.py
import os
import re
import pandas as pd
import uuid
import hashlib
import time
import markdown
from dotenv import load_dotenv
# LangChain / Vector DB
from langchain.vectorstores import Milvus
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage
from langchain.docstore.document import Document
from langchain_community.document_loaders import CSVLoader
from langchain.schema import SystemMessage, HumanMessage
from langchain_community.embeddings import HuggingFaceEmbeddings
from pymilvus import connections, utility, Collection
from pymilvus.orm.schema import FieldSchema
# ---------------------- Environment Setup ----------------------
load_dotenv()
os.environ["HF_HOME"] = "/tmp/hf_cache" # Optional cleanup
# ---------------------- File & Model Config ----------------------
faq_path = "faqs.csv"
# πŸ’° FREE Embeddings - Sentence Transformers (saves ~$18/day!)
print("πŸš€ Loading FREE embedding model (all-mpnet-base-v2)...")
embedding_model = HuggingFaceEmbeddings(
model_name="sentence-transformers/all-mpnet-base-v2",
model_kwargs={'device': 'cpu'},
encode_kwargs={'normalize_embeddings': True}
)
print("βœ… FREE embedding model loaded successfully!")
# NOTE: Dimension changed from 1536 (OpenAI) to 768 (mpnet)
EMBEDDING_DIM = 768
# Keywords that get visually emphasized in rendered answers.
IMPORTANT_KEYWORDS = [
    "visa", "permanent residency", "PR", "study permit", "work permit", "immigration",
    "Canada", "eligibility", "requirement", "application", "contact us", "support"
]
# Pre-compiled patterns: existing anchors, bare emails, bare URLs.
ANCHOR_RE = re.compile(r'<a\b[^>]*>.*?<\/a>', flags=re.IGNORECASE | re.DOTALL)
EMAIL_RE = re.compile(r'(?<!["\w])([A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,})')
URL_RE = re.compile(r'(?<!["\w])(https?://[^\s<>"\)]+)')


def style_response(text: str) -> str:
    """Turn a raw model reply into a single HTML paragraph.

    Steps: strip stray markdown emphasis, linkify bare emails/URLs,
    bold IMPORTANT_KEYWORDS (never inside anchors), wrap in <p>…</p>.
    """
    # Strip markdown emphasis markers the model may sneak in.
    text = re.sub(r'\*(.*?)\*', r'\1', text)
    # Linkify plain emails and URLs as HTML anchors (no Markdown).
    text = EMAIL_RE.sub(r'<a href="mailto:\1">\1</a>', text)
    text = URL_RE.sub(r'<a href="\1">\1</a>', text)
    # Hide every anchor behind a placeholder so keyword bolding can
    # never rewrite text inside a link (href or link text).
    stash: list[str] = []

    def _hide(match):
        stash.append(match.group(0))
        return f'__LINK{len(stash) - 1}__'

    text = ANCHOR_RE.sub(_hide, text)
    # Bold keywords, longest phrase first so multi-word terms win over substrings.
    for keyword in sorted(IMPORTANT_KEYWORDS, key=len, reverse=True):
        pattern = re.compile(rf'\b({re.escape(keyword)})\b', flags=re.IGNORECASE)
        text = pattern.sub(r'<strong>\1</strong>', text)
    # Swap the placeholders back for the original anchors.
    for idx, anchor in enumerate(stash):
        text = text.replace(f'__LINK{idx}__', anchor)
    # Single paragraph wrapper — callers must not wrap again.
    return f"<p>{text}</p>"


def style_response_as_html(text: str) -> str:
    """Style the reply, then let `markdown` render any remaining markup to HTML."""
    return markdown.markdown(style_response(text))
# Zilliz (Milvus) Cloud Config
milvus_uri = os.getenv("ZILLIZ_URI")      # serverless cluster endpoint
milvus_token = os.getenv("ZILLIZ_TOKEN")  # API key for that cluster
collection_name = os.getenv("ZILLIZ_COLLECTION", "visaverse_faqs3")
# Open the default pymilvus connection used by utility/Collection calls below.
# NOTE(review): this runs at import time — the module fails to import if the
# cluster is unreachable; confirm that is acceptable for the deployment.
connections.connect(
    alias="default",
    uri=os.getenv("ZILLIZ_URI"),
    token=os.getenv("ZILLIZ_TOKEN")
)
def ensure_collection_matches_schema(expected_dim: int, collection_name: str, uri: str, token: str) -> None:
    """Drop the Milvus collection if its vector dimension differs from expected_dim.

    Run before (re)building the vector store so that switching embedding models
    (e.g. OpenAI 1536-d -> mpnet 768-d) never leaves a collection with an
    incompatible schema. Any failure is logged and swallowed: the collection is
    recreated later by Milvus.from_documents anyway.

    Args:
        expected_dim: vector dimension the current embedding model produces.
        collection_name: Milvus collection to validate (shadows the module global).
        uri: Zilliz/Milvus endpoint.
        token: Zilliz/Milvus API key.
    """
    try:
        # NOTE(review): connects again even though the module already connected at
        # import time — presumably idempotent for the default alias; verify.
        connections.connect(uri=uri, token=token)
        if utility.has_collection(collection_name):
            schema = Collection(collection_name).schema
            for field in schema.fields:
                if field.dtype.name == "FLOAT_VECTOR":
                    if field.params and "dim" in field.params:
                        actual_dim = int(field.params["dim"])
                        if actual_dim != expected_dim:
                            # Wrong size: the old collection is unusable with the new model.
                            print(f"⚠️ Collection dim mismatch: {actual_dim} vs {expected_dim}. Dropping old collection.")
                            utility.drop_collection(collection_name)
                            return
                        else:
                            print(f"✅ Collection '{collection_name}' has correct dimension: {expected_dim}.")
                            return
            # Loop fell through: no FLOAT_VECTOR field with a readable "dim" param.
            print(f"⚠️ Could not find vector field in collection '{collection_name}'. Dropping for safety.")
            utility.drop_collection(collection_name)
        else:
            print(f"📝 Collection '{collection_name}' does not exist. It will be created.")
    except Exception as e:
        # Best-effort check: never block startup on a validation failure.
        print(f"Failed to validate or drop collection: {e}")
# ---------------------- Load FAQ Vector DB ----------------------
def load_faqs():
    """Build (or rebuild) the Milvus vector store from the FAQ CSV and return it.

    Seeds an empty faqs.csv if missing, validates the collection's vector
    dimension against the current embedding model, then embeds every CSV row
    into the collection (drop_old=True rebuilds from scratch on each startup).

    Returns:
        A langchain Milvus vector store connected to `collection_name`.
    """
    if not os.path.exists(faq_path):
        # Seed an empty CSV so CSVLoader has something to read.
        pd.DataFrame(columns=["id", "prompt", "response"]).to_csv(faq_path, index=False, encoding="utf-8")
    # Fix: use the single source of truth EMBEDDING_DIM instead of a second
    # hard-coded 768 literal that could silently drift from the model config.
    ensure_collection_matches_schema(
        expected_dim=EMBEDDING_DIM,
        collection_name=collection_name,
        uri=milvus_uri,
        token=milvus_token
    )
    docs = CSVLoader(faq_path, encoding="utf-8").load()
    if not docs:
        # Milvus.from_documents needs at least one document to infer the schema.
        docs = [Document(page_content="This is a placeholder FAQ")]
    return Milvus.from_documents(
        documents=docs,
        embedding=embedding_model,
        connection_args={
            "uri": milvus_uri,
            "token": milvus_token,
        },
        collection_name=collection_name,
        drop_old=True,  # rebuild the collection from the CSV on every startup
    )
db = load_faqs()  # module-level vector store, shared by retrieve_faqs()
# ---------------------- Response Cache ----------------------
# In-memory cache: { query_hash: { "response": str, "timestamp": float } }
RESPONSE_CACHE: dict[str, dict] = {}
CACHE_TTL = 3600 # cache responses for 1 hour (in seconds)
CACHE_MAX_SIZE = 500 # max cached entries to prevent memory bloat
def _cache_key(query: str) -> str:
"""Normalize and hash the query for cache lookup."""
normalized = query.strip().lower()
return hashlib.sha256(normalized.encode()).hexdigest()
def get_cached_response(query: str) -> str | None:
    """Return the cached answer for *query*, or None on a miss or expired entry."""
    entry = RESPONSE_CACHE.get(_cache_key(query))
    if entry is None:
        return None
    age = time.time() - entry["timestamp"]
    if age < CACHE_TTL:
        print(f"[Cache] HIT — saved an LLM call")
        return entry["response"]
    # Expired — drop it so the cache doesn't accumulate stale entries.
    del RESPONSE_CACHE[_cache_key(query)]
    return None
def set_cached_response(query: str, response: str):
    """Store *response* under the normalized query hash, evicting the oldest entry when full."""
    if len(RESPONSE_CACHE) >= CACHE_MAX_SIZE:
        # Drop the entry with the smallest timestamp (oldest insert).
        stalest = min(RESPONSE_CACHE, key=lambda k: RESPONSE_CACHE[k]["timestamp"])
        del RESPONSE_CACHE[stalest]
    RESPONSE_CACHE[_cache_key(query)] = {
        "response": response,
        "timestamp": time.time(),
    }
# ---------------------- Cost Tracker ----------------------
# GPT-4o-mini pricing (per 1M tokens)
INPUT_COST_PER_M = 0.15 # $0.15 per 1M input tokens
OUTPUT_COST_PER_M = 0.60 # $0.60 per 1M output tokens
# Session-wide running totals; mutated only by _log_cost / _log_cache_hit.
_cost_tracker = {
    "total_queries": 0,        # paid LLM calls
    "cached_queries": 0,       # served from RESPONSE_CACHE (free)
    "total_input_tokens": 0,
    "total_output_tokens": 0,
    "total_cost": 0.0,         # USD
    "session_start": time.time(),
}
def _log_cost(input_tokens: int, output_tokens: int, query_preview: str):
    """Accumulate token usage into _cost_tracker and print a per-query cost report."""
    cost_in = (input_tokens / 1_000_000) * INPUT_COST_PER_M
    cost_out = (output_tokens / 1_000_000) * OUTPUT_COST_PER_M
    query_cost = cost_in + cost_out
    tracker = _cost_tracker
    tracker["total_queries"] += 1
    tracker["total_input_tokens"] += input_tokens
    tracker["total_output_tokens"] += output_tokens
    tracker["total_cost"] += query_cost
    uptime_hrs = (time.time() - tracker["session_start"]) / 3600
    # Truncate long queries to keep the log line readable.
    preview = f'"{query_preview[:60]}..."' if len(query_preview) > 60 else f'"{query_preview}"'
    print("")
    print("💰 ─── COST DEBUG ─────────────────────────────────")
    print(f" Query: {preview}")
    print(f" Input: {input_tokens:,} tokens → ${cost_in:.6f}")
    print(f" Output: {output_tokens:,} tokens → ${cost_out:.6f}")
    print(f" This query cost: ${query_cost:.6f}")
    print(f" ── Session Totals ({uptime_hrs:.1f}h uptime) ──")
    print(f" Total queries: {tracker['total_queries']} paid | {tracker['cached_queries']} cached (free)")
    print(f" Total tokens: {tracker['total_input_tokens']:,} in | {tracker['total_output_tokens']:,} out")
    print(f" Total session cost: ${tracker['total_cost']:.6f}")
    print(f" Projected daily: ${(tracker['total_cost'] / max(uptime_hrs, 0.01)) * 24:.4f}")
    print("──────────────────────────────────────────────────")
    print("")
def _log_cache_hit(query_preview: str):
    """Record a cache hit in _cost_tracker and print the savings summary."""
    _cost_tracker["cached_queries"] += 1
    total = _cost_tracker["total_queries"] + _cost_tracker["cached_queries"]
    hit_rate = (_cost_tracker["cached_queries"] / total * 100) if total > 0 else 0
    # Truncate long queries to keep the log line readable.
    preview = f'"{query_preview[:60]}..."' if len(query_preview) > 60 else f'"{query_preview}"'
    print("")
    print("✅ ─── CACHE HIT (FREE) ───────────────────────────")
    print(f" Query: {preview}")
    print(" Saved: ~$0.000275")
    print(f" Cache hit rate: {hit_rate:.0f}% ({_cost_tracker['cached_queries']}/{total})")
    print(f" Session cost still: ${_cost_tracker['total_cost']:.6f}")
    print("──────────────────────────────────────────────────")
    print("")
# ---------------------- Vector DB Retrieval ----------------------
# Relevance threshold — L2 distance; lower = more relevant.
# Docs with score > this are too dissimilar and get dropped.
# NOTE(review): tuned for this embedding setup (normalized mpnet vectors per
# the encode_kwargs above) — re-tune if the embedding model changes.
RELEVANCE_THRESHOLD = 0.85
def retrieve_faqs(question: str, k: int = 3):
    """Return the top-k FAQ Documents whose distance clears RELEVANCE_THRESHOLD.

    Any retrieval failure is logged and yields an empty list so the caller can
    fall back to an uncontextualized answer.
    """
    try:
        scored = db.similarity_search_with_score(question, k=k)
        keep = [pair for pair in scored if pair[1] <= RELEVANCE_THRESHOLD]
        best = f"{scored[0][1]:.3f}" if scored else "n/a"
        if keep:
            print(f"[retrieve] {len(keep)}/{len(scored)} docs passed threshold "
                  f"(best={best}, cutoff={RELEVANCE_THRESHOLD})")
        else:
            print(f"[retrieve] 0/{len(scored)} docs relevant "
                  f"(best={best}, cutoff={RELEVANCE_THRESHOLD}) — skipping context")
        return [doc for doc, _ in keep]
    except Exception as e:
        print(f"[retrieve_faqs] retrieval failed: {e}")
        return []
# ---------------------- LLM Wrapper ----------------------
def _collect_source_links(docs) -> list:
    """Pull {'url', 'title'} citation entries out of each document's metadata."""
    links = []
    for d in docs:
        meta = d.metadata if hasattr(d, 'metadata') and d.metadata else {}
        if 'source_url' in meta and 'page_title' in meta:
            links.append({'url': meta['source_url'], 'title': meta['page_title']})
    return links


def _build_messages(prompt: str, docs) -> list:
    """Assemble the chat messages; the system prompt depends on whether context exists."""
    if docs:
        # Full prompt with source-citation instructions
        system_msg = SystemMessage(content=(
            "You are Noah, a friendly assistant by VisaVerse (https://visaverse.ca). "
            "Answer concisely from the provided excerpts. "
            "Cite IRCC sources at the end as: **Sources:** - [Title](URL)"
        ))
    else:
        # Lightweight prompt — no context, no citation rules needed
        system_msg = SystemMessage(content=(
            "You are Noah, a friendly assistant by VisaVerse. "
            "Answer briefly. For immigration questions, suggest visiting https://visaverse.ca."
        ))
    messages = [system_msg]
    if docs:
        context_text = "\n---\n".join(d.page_content for d in docs)
        messages.append(SystemMessage(content=f"Excerpts:\n{context_text}"))
    messages.append(HumanMessage(content=prompt))
    return messages


def _append_sources(resp: str, source_urls: list) -> str:
    """Append up to 3 de-duplicated markdown source links to the response."""
    unique_sources = []
    seen_urls = set()
    for src in source_urls:
        if src['url'] not in seen_urls:
            unique_sources.append(src)
            seen_urls.add(src['url'])
    if unique_sources:
        resp += "\n\n**Sources:**"
        for src in unique_sources[:3]:  # Limit to top 3 sources
            resp += f"\n- [{src['title']}]({src['url']})"
    return resp


def ask_openai(prompt: str, use_faqs: bool = True, k: int = 3) -> str:
    """Answer *prompt* with gpt-4o-mini, optionally grounded in retrieved FAQ excerpts.

    Flow: cache lookup -> (on miss) retrieval -> prompt assembly -> LLM call ->
    cost logging -> source-link appending -> cache store.

    Args:
        prompt: the user's question.
        use_faqs: when True, retrieve FAQ context and cite sources.
        k: number of FAQ documents to retrieve.

    Returns:
        The (possibly cached) markdown answer text.
    """
    # Serve identical questions from cache — avoids the LLM call entirely.
    # Fix: compare against None so a legitimately-empty cached answer still
    # counts as a hit instead of triggering a fresh paid call.
    cached = get_cached_response(prompt)
    if cached is not None:
        _log_cache_hit(prompt)
        return cached
    chat = ChatOpenAI(
        model_name="gpt-4o-mini",
        temperature=0.3,
        max_tokens=500,
        openai_api_key=os.getenv("OPENAI_API_KEY")
    )
    # Retrieve relevant docs first so we can choose the right prompt.
    docs = retrieve_faqs(prompt, k=k) if use_faqs else []
    source_urls = _collect_source_links(docs) if docs else []
    result = chat.generate([_build_messages(prompt, docs)])
    resp = result.generations[0][0].text.strip()
    # Extract token usage and log cost.
    token_usage = result.llm_output.get("token_usage", {}) if result.llm_output else {}
    _log_cost(token_usage.get("prompt_tokens", 0),
              token_usage.get("completion_tokens", 0), prompt)
    if source_urls:
        resp = _append_sources(resp, source_urls)
    # Cache the response for future identical questions.
    set_cached_response(prompt, resp)
    return resp
# ---------------------- Append New FAQ to CSV ----------------------
def add_faq_to_csv(question: str, answer: str):
df = pd.read_csv(faq_path, encoding="utf-8")
if not ((df["prompt"] == question) & (df["response"] == answer)).any():
new_row = pd.DataFrame([{"id": str(uuid.uuid4()), "prompt": question, "response": answer}])
df = pd.concat([df, new_row], ignore_index=True)
df.to_csv(faq_path, index=False, encoding="utf-8")