Spaces:
Runtime error
Runtime error
| # updated faq_services.py | |
| import os | |
| import re | |
| import pandas as pd | |
| import uuid | |
| import hashlib | |
| import time | |
| import markdown | |
| from dotenv import load_dotenv | |
| # LangChain / Vector DB | |
| from langchain.vectorstores import Milvus | |
| from langchain.chat_models import ChatOpenAI | |
| from langchain.schema import HumanMessage | |
| from langchain.docstore.document import Document | |
| from langchain_community.document_loaders import CSVLoader | |
| from langchain.schema import SystemMessage, HumanMessage | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from pymilvus import connections, utility, Collection | |
| from pymilvus.orm.schema import FieldSchema | |
# ---------------------- Environment Setup ----------------------
load_dotenv()
# Keep the HuggingFace model cache in /tmp instead of the default home dir.
os.environ["HF_HOME"] = "/tmp/hf_cache"

# ---------------------- File & Model Config ----------------------
faq_path = "faqs.csv"  # local CSV backing store for the FAQ corpus

# π° FREE Embeddings - Sentence Transformers (saves ~$18/day!)
print("π Loading FREE embedding model (all-mpnet-base-v2)...")
_EMBED_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
embedding_model = HuggingFaceEmbeddings(
    model_name=_EMBED_MODEL_NAME,
    model_kwargs={"device": "cpu"},
    # Normalized vectors so cosine/L2 scores behave consistently.
    encode_kwargs={"normalize_embeddings": True},
)
print("β FREE embedding model loaded successfully!")
# NOTE: Dimension changed from 1536 (OpenAI) to 768 (mpnet)
EMBEDDING_DIM = 768
IMPORTANT_KEYWORDS = [
    "visa", "permanent residency", "PR", "study permit", "work permit", "immigration",
    "Canada", "eligibility", "requirement", "application", "contact us", "support"
]

# Pre-compiled patterns: existing <a>...</a> tags, bare emails, bare http(s) URLs.
# The lookbehind (?<!["\w]) keeps us from re-matching text already inside
# an href="..." attribute.
ANCHOR_RE = re.compile(r'<a\b[^>]*>.*?<\/a>', flags=re.IGNORECASE | re.DOTALL)
EMAIL_RE = re.compile(r'(?<!["\w])([A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,})')
URL_RE = re.compile(r'(?<!["\w])(https?://[^\s<>"\)]+)')


def style_response(text: str) -> str:
    """Decorate raw model text as HTML: linkify emails/URLs, bold keywords.

    Returns the text wrapped exactly once in a <p> element; callers must
    not wrap it again.
    """
    # Drop stray markdown emphasis markers the model sometimes emits.
    text = re.sub(r'\*(.*?)\*', r'\1', text)

    # Convert bare emails and URLs into plain HTML anchors (no Markdown).
    text = EMAIL_RE.sub(r'<a href="mailto:\1">\1</a>', text)
    text = URL_RE.sub(r'<a href="\1">\1</a>', text)

    # Swap every anchor for a placeholder so keyword bolding can never
    # touch the inside of a link.
    stashed: list[str] = []

    def _hide(match):
        stashed.append(match.group(0))
        return f'__LINK{len(stashed) - 1}__'

    text = ANCHOR_RE.sub(_hide, text)

    # Bold keywords outside links; longest phrases first so e.g.
    # "permanent residency" wins over "PR".
    for keyword in sorted(IMPORTANT_KEYWORDS, key=len, reverse=True):
        pattern = re.compile(rf'\b({re.escape(keyword)})\b', flags=re.IGNORECASE)
        text = pattern.sub(r'<strong>\1</strong>', text)

    # Put the anchors back in place of their placeholders.
    for idx, anchor in enumerate(stashed):
        text = text.replace(f'__LINK{idx}__', anchor)

    return f"<p>{text}</p>"
def style_response_as_html(text: str) -> str:
    """Apply style_response(), then convert any remaining Markdown to HTML."""
    styled = style_response(text)
    return markdown.markdown(styled)
# Zilliz (Milvus) Cloud Config
milvus_uri = os.getenv("ZILLIZ_URI")
milvus_token = os.getenv("ZILLIZ_TOKEN")
collection_name = os.getenv("ZILLIZ_COLLECTION", "visaverse_faqs3")

# Open the default Milvus connection up front so module-level helpers
# (schema checks, vector-store construction) can rely on it.
connections.connect(
    alias="default",
    uri=milvus_uri,
    token=milvus_token,
)
def ensure_collection_matches_schema(expected_dim: int, collection_name: str, uri: str, token: str):
    """Drop the remote collection if its vector dimension differs from *expected_dim*.

    Best-effort: any failure is logged and swallowed so startup can proceed.
    """
    try:
        connections.connect(uri=uri, token=token)
        if not utility.has_collection(collection_name):
            print(f"π Collection '{collection_name}' does not exist. It will be created.")
            return
        for field in Collection(collection_name).schema.fields:
            # Only the vector field with an explicit "dim" param is checked.
            if field.dtype.name != "FLOAT_VECTOR":
                continue
            if not (field.params and "dim" in field.params):
                continue
            actual_dim = int(field.params["dim"])
            if actual_dim != expected_dim:
                print(f"β οΈ Collection dim mismatch: {actual_dim} vs {expected_dim}. Dropping old collection.")
                utility.drop_collection(collection_name)
            else:
                print(f"β Collection '{collection_name}' has correct dimension: {expected_dim}.")
            return
        # No usable vector field found β safer to rebuild from scratch.
        print(f"β οΈ Could not find vector field in collection '{collection_name}'. Dropping for safety.")
        utility.drop_collection(collection_name)
    except Exception as e:
        print(f"Failed to validate or drop collection: {e}")
# ---------------------- Load FAQ Vector DB ----------------------
def load_faqs():
    """Build (or rebuild) the Milvus vector store from the local FAQ CSV."""
    # Bootstrap an empty CSV so the loader always has a file to read.
    if not os.path.exists(faq_path):
        pd.DataFrame(columns=["id", "prompt", "response"]).to_csv(faq_path, index=False, encoding="utf-8")

    # Drop the remote collection if its vector dimension no longer matches
    # the 768-dim mpnet embeddings (OpenAI embeddings were 1536).
    ensure_collection_matches_schema(
        expected_dim=768,  # FREE model dimension
        collection_name=collection_name,
        uri=milvus_uri,
        token=milvus_token,
    )

    docs = CSVLoader(faq_path, encoding="utf-8").load()
    if not docs:
        # Never index an empty corpus β Milvus needs at least one document.
        docs = [Document(page_content="This is a placeholder FAQ")]

    # drop_old=True: the collection is rebuilt from the CSV on every startup.
    return Milvus.from_documents(
        documents=docs,
        embedding=embedding_model,
        connection_args={"uri": milvus_uri, "token": milvus_token},
        collection_name=collection_name,
        drop_old=True,
    )


db = load_faqs()
| # ---------------------- Response Cache ---------------------- | |
| # In-memory cache: { query_hash: { "response": str, "timestamp": float } } | |
| RESPONSE_CACHE: dict[str, dict] = {} | |
| CACHE_TTL = 3600 # cache responses for 1 hour (in seconds) | |
| CACHE_MAX_SIZE = 500 # max cached entries to prevent memory bloat | |
| def _cache_key(query: str) -> str: | |
| """Normalize and hash the query for cache lookup.""" | |
| normalized = query.strip().lower() | |
| return hashlib.sha256(normalized.encode()).hexdigest() | |
| def get_cached_response(query: str) -> str | None: | |
| """Return cached response if it exists and is not expired.""" | |
| key = _cache_key(query) | |
| entry = RESPONSE_CACHE.get(key) | |
| if entry and (time.time() - entry["timestamp"]) < CACHE_TTL: | |
| print(f"[Cache] HIT β saved an LLM call") | |
| return entry["response"] | |
| if entry: | |
| del RESPONSE_CACHE[key] # expired, clean up | |
| return None | |
| def set_cached_response(query: str, response: str): | |
| """Store a response in the cache.""" | |
| # Evict oldest entries if cache is full | |
| if len(RESPONSE_CACHE) >= CACHE_MAX_SIZE: | |
| oldest_key = min(RESPONSE_CACHE, key=lambda k: RESPONSE_CACHE[k]["timestamp"]) | |
| del RESPONSE_CACHE[oldest_key] | |
| key = _cache_key(query) | |
| RESPONSE_CACHE[key] = {"response": response, "timestamp": time.time()} | |
# ---------------------- Cost Tracker ----------------------
# GPT-4o-mini pricing (per 1M tokens)
INPUT_COST_PER_M = 0.15  # $0.15 per 1M input tokens
OUTPUT_COST_PER_M = 0.60  # $0.60 per 1M output tokens

# Running session totals, reset on process restart.
_cost_tracker = {
    "total_queries": 0,
    "cached_queries": 0,
    "total_input_tokens": 0,
    "total_output_tokens": 0,
    "total_cost": 0.0,
    "session_start": time.time(),
}


def _log_cost(input_tokens: int, output_tokens: int, query_preview: str):
    """Log per-query cost and running totals to console."""
    input_cost = (input_tokens / 1_000_000) * INPUT_COST_PER_M
    output_cost = (output_tokens / 1_000_000) * OUTPUT_COST_PER_M
    query_cost = input_cost + output_cost

    stats = _cost_tracker
    stats["total_queries"] += 1
    stats["total_input_tokens"] += input_tokens
    stats["total_output_tokens"] += output_tokens
    stats["total_cost"] += query_cost

    uptime_hrs = (time.time() - stats["session_start"]) / 3600
    preview = f'"{query_preview[:60]}..."' if len(query_preview) > 60 else f'"{query_preview}"'
    print("")
    print(f"π° βββ COST DEBUG βββββββββββββββββββββββββββββββββ")
    print(f" Query: {preview}")
    print(f" Input: {input_tokens:,} tokens β ${input_cost:.6f}")
    print(f" Output: {output_tokens:,} tokens β ${output_cost:.6f}")
    print(f" This query cost: ${query_cost:.6f}")
    print(f" ββ Session Totals ({uptime_hrs:.1f}h uptime) ββ")
    print(f" Total queries: {stats['total_queries']} paid | {stats['cached_queries']} cached (free)")
    print(f" Total tokens: {stats['total_input_tokens']:,} in | {stats['total_output_tokens']:,} out")
    print(f" Total session cost: ${stats['total_cost']:.6f}")
    print(f" Projected daily: ${(stats['total_cost'] / max(uptime_hrs, 0.01)) * 24:.4f}")
    print(f"ββββββββββββββββββββββββββββββββββββββββββββββββββ")
    print("")


def _log_cache_hit(query_preview: str):
    """Log cache hit to console."""
    stats = _cost_tracker
    stats["cached_queries"] += 1
    total = stats["total_queries"] + stats["cached_queries"]
    hit_rate = (stats["cached_queries"] / total * 100) if total > 0 else 0
    preview = f'"{query_preview[:60]}..."' if len(query_preview) > 60 else f'"{query_preview}"'
    print("")
    print(f"β βββ CACHE HIT (FREE) βββββββββββββββββββββββββββ")
    print(f" Query: {preview}")
    print(f" Saved: ~$0.000275")
    print(f" Cache hit rate: {hit_rate:.0f}% ({stats['cached_queries']}/{total})")
    print(f" Session cost still: ${stats['total_cost']:.6f}")
    print(f"ββββββββββββββββββββββββββββββββββββββββββββββββββ")
    print("")
# ---------------------- Vector DB Retrieval ----------------------
# Relevance threshold β L2 distance; lower = more relevant.
# Docs scoring above this are too dissimilar and get dropped.
RELEVANCE_THRESHOLD = 0.85


def retrieve_faqs(question: str, k: int = 3):
    """Return top-k *relevant* FAQ Documents (score-filtered).

    Best-effort: any retrieval failure is logged and yields an empty list.
    """
    try:
        scored = db.similarity_search_with_score(question, k=k)
        kept = [doc for doc, score in scored if score <= RELEVANCE_THRESHOLD]
        best = f"{scored[0][1]:.3f}" if scored else "n/a"
        if kept:
            print(f"[retrieve] {len(kept)}/{len(scored)} docs passed threshold "
                  f"(best={best}, cutoff={RELEVANCE_THRESHOLD})")
        else:
            print(f"[retrieve] 0/{len(scored)} docs relevant "
                  f"(best={best}, cutoff={RELEVANCE_THRESHOLD}) β skipping context")
        return kept
    except Exception as e:
        print(f"[retrieve_faqs] retrieval failed: {e}")
        return []
# ---------------------- LLM Wrapper ----------------------
def ask_openai(prompt: str, use_faqs: bool = True, k: int = 3) -> str:
    """Answer *prompt* with GPT-4o-mini, optionally grounded on FAQ excerpts.

    Flow: cache lookup β (optional) vector retrieval β prompt assembly β
    LLM call β cost logging β source-link stitching β cache store.

    Args:
        prompt: The user's question.
        use_faqs: When True, retrieve FAQ context and cite sources.
        k: Number of FAQ documents to retrieve.

    Returns:
        The response text, possibly followed by a markdown "**Sources:**" list.
    """
    # Check cache first β avoid LLM call entirely
    cached = get_cached_response(prompt)
    if cached:
        _log_cache_hit(prompt)
        return cached
    chat = ChatOpenAI(
        model_name="gpt-4o-mini",
        temperature=0.3,
        max_tokens=500,
        openai_api_key=os.getenv("OPENAI_API_KEY")
    )
    # Retrieve relevant docs first so we can choose the right prompt
    docs = []
    source_urls = []
    if use_faqs:
        docs = retrieve_faqs(prompt, k=k)
        if docs:
            # Collect citation metadata; docs without both keys are skipped.
            # NOTE(review): assumes the ingestion pipeline sets 'source_url'
            # and 'page_title' β confirm against the CSV/loader schema.
            for d in docs:
                if hasattr(d, 'metadata') and d.metadata:
                    if 'source_url' in d.metadata and 'page_title' in d.metadata:
                        source_urls.append({
                            'url': d.metadata['source_url'],
                            'title': d.metadata['page_title']
                        })
    # Choose system prompt based on whether we have relevant context
    if docs:
        # Full prompt with source-citation instructions
        system_msg = SystemMessage(content=(
            "You are Noah, a friendly assistant by VisaVerse (https://visaverse.ca). "
            "Answer concisely from the provided excerpts. "
            "Cite IRCC sources at the end as: **Sources:** - [Title](URL)"
        ))
    else:
        # Lightweight prompt β no context, no citation rules needed
        system_msg = SystemMessage(content=(
            "You are Noah, a friendly assistant by VisaVerse. "
            "Answer briefly. For immigration questions, suggest visiting https://visaverse.ca."
        ))
    messages = [system_msg]
    # Add FAQ context only if relevant docs were found
    if docs:
        excerpts = [d.page_content for d in docs]
        context_text = "\n---\n".join(excerpts)
        messages.append(SystemMessage(content=f"Excerpts:\n{context_text}"))
    messages.append(HumanMessage(content=prompt))
    result = chat.generate([messages])
    resp = result.generations[0][0].text.strip()
    # Extract token usage and log cost
    token_usage = result.llm_output.get("token_usage", {}) if result.llm_output else {}
    input_tokens = token_usage.get("prompt_tokens", 0)
    output_tokens = token_usage.get("completion_tokens", 0)
    _log_cost(input_tokens, output_tokens, prompt)
    # Append source links if available
    if source_urls:
        # De-duplicate by URL, preserving first-seen order.
        unique_sources = []
        seen_urls = set()
        for src in source_urls:
            if src['url'] not in seen_urls:
                unique_sources.append(src)
                seen_urls.add(src['url'])
        if unique_sources:
            resp += "\n\n**Sources:**"
            for src in unique_sources[:3]:  # Limit to top 3 sources
                resp += f"\n- [{src['title']}]({src['url']})"
    # Cache the response for future identical questions
    set_cached_response(prompt, resp)
    return resp
| # ---------------------- Append New FAQ to CSV ---------------------- | |
| def add_faq_to_csv(question: str, answer: str): | |
| df = pd.read_csv(faq_path, encoding="utf-8") | |
| if not ((df["prompt"] == question) & (df["response"] == answer)).any(): | |
| new_row = pd.DataFrame([{"id": str(uuid.uuid4()), "prompt": question, "response": answer}]) | |
| df = pd.concat([df, new_row], ignore_index=True) | |
| df.to_csv(faq_path, index=False, encoding="utf-8") | |