"""Question answering over retrieved chunks: context assembly, Gemini call, citation formatting."""

from __future__ import annotations

import re

from app.config import DEFAULT_TOP_K, GEMINI_MODEL, RAW_DIR
from app.context_compression import compress_chunk_text
from app.intermarket import build_intermarket_context
from app.market_snapshot import (
    build_market_snapshot_context,
    compressed_quick_summary_context,
    format_market_snapshot_answer,
    format_quick_stock_summary,
    is_market_snapshot_query,
    is_quick_summary_query,
    load_latest_market_snapshot,
)
from app.prompts import load_prompt
from app.retriever import hybrid_retrieve
from app.schemas import RetrievedChunk
from app.runtime_auth import get_gemini_api_key
from app.technical_analysis import build_technical_context, is_technical_query

# Loaded once at import time and prepended to every prompt built by build_prompt().
SYSTEM_PROMPT = load_prompt("rag_system.md")


def build_context(chunks: list[RetrievedChunk], question: str = "") -> str:
    """Render retrieved chunks as numbered text blocks.

    Each block has a metadata header line, the source path, the heading
    path joined with " > ", and the chunk text as returned by
    ``compress_chunk_text(question, chunk)``.

    Args:
        chunks: Retrieved chunks, rendered in the given order and numbered from 1.
        question: Passed through to ``compress_chunk_text``.

    Returns:
        The blocks joined by a ``---`` separator line; an empty string when
        ``chunks`` is empty.
    """
    blocks: list[str] = []
    for index, chunk in enumerate(chunks, start=1):
        blocks.append(
            "\n".join(
                [
                    (
                        f"[{index}] ticker={chunk.ticker} modality={chunk.modality} "
                        f"scope={chunk.scope or chunk.ticker} structure={chunk.structure_type} score={chunk.score:.4f}"
                    ),
                    f"source={chunk.source_path}",
                    f"heading={' > '.join(chunk.heading_path)}",
                    compress_chunk_text(question, chunk),
                ]
            )
        )
    return "\n\n---\n\n".join(blocks)


def build_prompt(question: str, chunks: list[RetrievedChunk], extra_context: str = "") -> str:
    """Assemble the full LLM prompt.

    Order of sections: system prompt, optional structured context, retrieved
    context, the user question, and a trailing ``Answer:`` cue.

    Args:
        question: The user's question.
        chunks: Retrieved chunks to render via ``build_context``.
        extra_context: Optional additional context; its section is omitted when empty.

    Returns:
        The prompt sections joined by blank lines.
    """
    context_parts: list[str] = []
    if extra_context:
        context_parts.extend(["Additional structured context:", extra_context])
    context_parts.extend(["Retrieved context:", build_context(chunks, question)])
    return "\n\n".join(
        [
            SYSTEM_PROMPT,
            *context_parts,
            f"User question: {question}",
            "Answer:",
        ]
    )


def generate_answer(question: str, chunks: list[RetrievedChunk], extra_context: str = "") -> str:
    """Answer ``question`` with Gemini, falling back to the raw context.

    Args:
        question: The user's question.
        chunks: Retrieved chunks used as context.
        extra_context: Optional additional context placed ahead of the chunks.

    Returns:
        The model's text (empty string if the response has no text); a fixed
        notice when there is neither chunk nor extra context; or a message
        followed by the raw context when no API key is available or the
        Gemini call raises.
    """
    if not chunks and not extra_context:
        return "Chưa tìm thấy context phù hợp trong vector DB. Hãy index `data/raw` trước."

    api_key = get_gemini_api_key()
    if not api_key:
        return (
            "Bạn chưa nhập Gemini API key nên hệ thống trả về context liên quan nhất:\n\n"
            f"{extra_context}\n\n{build_context(chunks, question)}"
        )

    # Built only once a key is known to exist: the no-key path above never
    # sends a prompt, so constructing one there was wasted work.
    prompt = build_prompt(question, chunks, extra_context)
    try:
        # Imported lazily so the module stays importable without the SDK.
        from google import genai

        client = genai.Client(api_key=api_key)
        response = client.models.generate_content(model=GEMINI_MODEL, contents=prompt)
        return response.text or ""
    except Exception as exc:
        # Deliberate best-effort boundary: any failure (missing SDK, network,
        # API error) degrades to showing the context instead of crashing.
        return (
            f"Không gọi được Gemini ({exc}). Context liên quan:\n\n"
            f"{extra_context}\n\n{build_context(chunks, question)}"
        )


def generate_quick_summary(
    question: str,
    snapshot,
    chunks: list[RetrievedChunk],
    ticker: str,
) -> str:
    """Produce a quick stock summary, preferring Gemini over the local formatter.

    Args:
        question: The user's question, substituted into the prompt template.
        snapshot: Market snapshot object forwarded to the market_snapshot helpers.
        chunks: Retrieved chunks forwarded to the same helpers.
        ticker: Ticker symbol; upper-cased for the prompt.

    Returns:
        The stripped model answer, or the output of
        ``format_quick_stock_summary`` when there is no API key, the model
        returns nothing, or any step of the model call raises.
    """
    fallback = format_quick_stock_summary(snapshot, chunks, ticker=ticker)
    api_key = get_gemini_api_key()
    if not api_key:
        return fallback

    try:
        from google import genai

        client = genai.Client(api_key=api_key)
        prompt = load_prompt("quick_summary.md").format(
            ticker=ticker.upper(),
            question=question,
            compressed_context=compressed_quick_summary_context(snapshot, chunks),
        )
        response = client.models.generate_content(model=GEMINI_MODEL, contents=prompt)
        answer = (response.text or "").strip()
        return answer or fallback
    except Exception:
        # Best-effort: the locally formatted summary is always a valid answer.
        return fallback


def chunk_source(chunk: RetrievedChunk) -> dict:
    """Convert a retrieved chunk into the source dict returned to callers.

    ``url`` and ``artifact_path`` are read from ``chunk.metadata`` and are
    ``None`` when the key is absent.
    """
    return {
        "score": chunk.score,
        "ticker": chunk.ticker,
        "scope": chunk.scope,
        "modality": chunk.modality,
        "structure_type": chunk.structure_type,
        "heading_path": chunk.heading_path,
        "source_path": chunk.source_path,
        "url": chunk.metadata.get("url"),
        "artifact_path": chunk.metadata.get("artifact_path"),
    }


def friendly_source_label(source: dict) -> str:
    """Return a human-readable label for a source dict.

    The label is chosen from the source's ``structure_type`` and from
    substrings of its lower-cased ``source_path`` (or ``artifact_path`` when
    the former is empty). The ticker shown is ``ticker``, else ``scope``,
    else a generic market placeholder, upper-cased.
    """
    source_path = str(source.get("source_path") or source.get("artifact_path") or "").lower()
    ticker = str(source.get("ticker") or source.get("scope") or "thị trường").upper()

    if source.get("structure_type") == "market_snapshot":
        return f"Bảng giá {ticker}"
    if "analysis_report" in source_path:
        return f"Báo cáo phân tích {ticker}"
    # "financial_document" is a substring of "financial_documents", so this
    # single test covers both spellings.
    if "financial_document" in source_path:
        return f"Báo cáo tài chính {ticker}"
    if "ticker_news" in source_path or "news_events" in source_path:
        return f"Tin tức và sự kiện {ticker}"
    if "stock_overview" in source_path:
        return f"Tổng quan cổ phiếu {ticker}"
    if "world_market" in source_path:
        return "Thị trường thế giới"
    return f"Nguồn dữ liệu {ticker}"


def format_answer_citations(answer: str, sources: list[dict]) -> str:
    """Replace raw source paths in ``answer`` with friendly citations.

    For every source, its ``source_path`` and ``artifact_path`` (backslashes
    normalised to ``/``) are matched both with and without a leading
    ``data/`` prefix. Each occurrence — wrapped in backticks, wrapped in
    parentheses, or bare — becomes a markdown link ``[label](url)`` when the
    source has a URL, otherwise a bold ``**label**``.

    Args:
        answer: The answer text; coerced with ``str()``.
        sources: Source dicts as produced by ``chunk_source``.

    Returns:
        The answer with paths substituted.
    """
    formatted = str(answer)
    replacements: dict[str, str] = {}
    for source in sources:
        label = friendly_source_label(source)
        url = str(source.get("url") or "").strip()
        replacement = f"[{label}]({url})" if url else f"**{label}**"
        for value in (source.get("source_path"), source.get("artifact_path")):
            path = str(value or "").strip().replace("\\", "/")
            if not path:
                continue
            if path.startswith("data/"):
                variants = {path, path.removeprefix("data/")}
            else:
                variants = {path, f"data/{path}"}
            for variant in variants:
                replacements[variant] = replacement

    # Longest paths first so a "data/..." variant is consumed before its
    # shorter suffix variant can match inside it.
    for path in sorted(replacements, key=len, reverse=True):
        replacement = replacements[path]
        escaped = re.escape(path)
        # Wrapped forms go first so their delimiters are consumed with the path.
        for pattern in (rf"`{escaped}`", rf"\({escaped}\)", escaped):
            # A callable replacement is inserted verbatim. Passing the string
            # directly would make re.sub treat backslashes in the label/URL
            # as escapes or group references.
            formatted = re.sub(pattern, lambda _match: replacement, formatted)
    return formatted


def answer_question(question: str, ticker: str | None = None, top_k: int = DEFAULT_TOP_K) -> dict:
    """Run the full pipeline for one question.

    Retrieves chunks, gathers optional intermarket / market-snapshot /
    technical context depending on the query classifiers, picks the answer
    generator, and rewrites path citations in the answer.

    Args:
        question: The user's question.
        ticker: Optional ticker passed to retrieval and the context builders.
        top_k: Number of chunks requested from ``hybrid_retrieve``.

    Returns:
        A dict with keys ``question``, ``ticker``, ``answer`` and ``sources``.
    """
    chunks = hybrid_retrieve(question, top_k=top_k, ticker=ticker)
    extra_context_parts: list[str] = []
    sources = [chunk_source(chunk) for chunk in chunks]
    quick_summary_query = is_quick_summary_query(question)

    # Intermarket context is only attached when no ticker was specified.
    intermarket_context, intermarket_source = build_intermarket_context(question)
    if intermarket_context and not ticker:
        extra_context_parts.append(intermarket_context)
        if intermarket_source:
            sources.insert(0, intermarket_source)

    wants_snapshot = is_market_snapshot_query(question) or quick_summary_query
    snapshot = load_latest_market_snapshot(ticker) if wants_snapshot else None
    snapshot_context = build_market_snapshot_context(snapshot)
    if snapshot_context and snapshot:
        extra_context_parts.append(snapshot_context)
        # Computed once; used for both path fields of the snapshot source.
        snapshot_path = snapshot.path.relative_to(RAW_DIR.parent).as_posix()
        sources.insert(
            0,
            {
                "score": 1.0,
                "ticker": snapshot.ticker,
                "modality": "table",
                "structure_type": "market_snapshot",
                "heading_path": [],
                "source_path": snapshot_path,
                "url": f"https://24hmoney.vn/stock/{snapshot.ticker}",
                "artifact_path": snapshot_path,
            },
        )

    if is_technical_query(question):
        technical_context = build_technical_context(ticker)
        if technical_context:
            extra_context_parts.append(technical_context)

    extra_context = "\n\n---\n\n".join(extra_context_parts)

    # Generator precedence: quick summary (needs a ticker), then the
    # snapshot table answer, then the general RAG answer.
    if quick_summary_query and ticker:
        answer = generate_quick_summary(question, snapshot, chunks, ticker=ticker)
    elif snapshot and snapshot_context:
        answer = format_market_snapshot_answer(snapshot)
    else:
        answer = generate_answer(question, chunks, extra_context)

    answer = format_answer_citations(answer, sources)
    return {
        "question": question,
        "ticker": ticker,
        "answer": answer,
        "sources": sources,
    }