quanho114 committed
Commit: ebb8326
Parent(s): 0d3f194

Deploy VietQA API

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the rest.
- Dockerfile +18 -0
- api.py +164 -0
- requirements-prod.txt +28 -0
- src/__init__.py +2 -0
- src/__pycache__/__init__.cpython-312.pyc +0 -0
- src/__pycache__/__init__.cpython-314.pyc +0 -0
- src/__pycache__/config.cpython-312.pyc +0 -0
- src/__pycache__/config.cpython-314.pyc +0 -0
- src/__pycache__/graph.cpython-312.pyc +0 -0
- src/__pycache__/pipeline.cpython-312.pyc +0 -0
- src/__pycache__/state.cpython-312.pyc +0 -0
- src/config.py +110 -0
- src/data_processing/__init__.py +26 -0
- src/data_processing/__pycache__/__init__.cpython-312.pyc +0 -0
- src/data_processing/__pycache__/__init__.cpython-314.pyc +0 -0
- src/data_processing/__pycache__/answer.cpython-312.pyc +0 -0
- src/data_processing/__pycache__/answer.cpython-314.pyc +0 -0
- src/data_processing/__pycache__/formatting.cpython-312.pyc +0 -0
- src/data_processing/__pycache__/formatting.cpython-314.pyc +0 -0
- src/data_processing/__pycache__/loaders.cpython-312.pyc +0 -0
- src/data_processing/__pycache__/loaders.cpython-314.pyc +0 -0
- src/data_processing/__pycache__/models.cpython-312.pyc +0 -0
- src/data_processing/__pycache__/models.cpython-314.pyc +0 -0
- src/data_processing/answer.py +151 -0
- src/data_processing/formatting.py +37 -0
- src/data_processing/loaders.py +151 -0
- src/data_processing/models.py +29 -0
- src/graph.py +47 -0
- src/nodes/__init__.py +15 -0
- src/nodes/__pycache__/__init__.cpython-312.pyc +0 -0
- src/nodes/__pycache__/direct.cpython-312.pyc +0 -0
- src/nodes/__pycache__/logic.cpython-312.pyc +0 -0
- src/nodes/__pycache__/rag.cpython-312.pyc +0 -0
- src/nodes/__pycache__/router.cpython-312.pyc +0 -0
- src/nodes/direct.py +42 -0
- src/nodes/logic.py +253 -0
- src/nodes/rag.py +141 -0
- src/nodes/router.py +112 -0
- src/pipeline.py +215 -0
- src/state.py +16 -0
- src/templates/direct_answer.j2 +19 -0
- src/templates/logic_solver.j2 +37 -0
- src/templates/rag.j2 +25 -0
- src/templates/router.j2 +43 -0
- src/utils/__init__.py +47 -0
- src/utils/__pycache__/__init__.cpython-312.pyc +0 -0
- src/utils/__pycache__/__init__.cpython-314.pyc +0 -0
- src/utils/__pycache__/checkpointing.cpython-312.pyc +0 -0
- src/utils/__pycache__/checkpointing.cpython-314.pyc +0 -0
- src/utils/__pycache__/common.cpython-312.pyc +0 -0
Dockerfile
ADDED
@@ -0,0 +1,18 @@
+# Dockerfile for Hugging Face Spaces
+FROM python:3.11-slim
+
+WORKDIR /app
+
+# Install dependencies
+COPY requirements-prod.txt .
+RUN pip install --no-cache-dir -r requirements-prod.txt
+
+# Copy application code
+COPY api.py .
+COPY src/ ./src/
+
+# Expose port (HF Spaces uses port 7860)
+EXPOSE 7860
+
+# Run the API
+CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "7860"]
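To sanity-check the container before pushing to Spaces, a minimal smoke test can hit the health route (a sketch; assumes the image was built and started locally, e.g. with a hypothetical tag via `docker build -t vietqa-api . && docker run -p 7860:7860 vietqa-api`):

import httpx  # same HTTP client pinned in requirements-prod.txt

# /health is defined in api.py below and should return {"status": "ok"}
resp = httpx.get("http://localhost:7860/health", timeout=10)
print(resp.status_code, resp.json())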
api.py
ADDED
@@ -0,0 +1,164 @@
+"""FastAPI Backend for VietQA Multi-Agent System."""
+
+from contextlib import asynccontextmanager
+from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
+
+from src.data_processing.models import QuestionInput
+from src.data_processing.formatting import question_to_state
+from src.data_processing.answer import normalize_answer
+from src.graph import get_graph
+from src.utils.llm import set_large_model_override, get_available_large_models
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """Fast startup - lazy load models on first request."""
+    print("[Startup] Server starting (models will load on first request)...")
+    print("[Startup] Server ready!")
+    yield
+
+
+app = FastAPI(
+    title="VietQA Multi-Agent API",
+    description="API cho hệ thống trả lời câu hỏi trắc nghiệm tiếng Việt",
+    version="1.0.0",
+    lifespan=lifespan
+)
+
+# CORS for frontend
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+class SolveRequest(BaseModel):
+    question: str
+    choices: list[str]
+    model: str | None = None
+
+
+class SolveResponse(BaseModel):
+    answer: str
+    route: str
+    reasoning: str
+    context: str
+
+
+def clean_thinking_tags(text: str) -> str:
+    """Remove <think>...</think> tags from model response."""
+    import re
+    # Remove think tags and their content
+    cleaned = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
+    return cleaned.strip()
+
+
+class ChatRequest(BaseModel):
+    message: str
+    model: str | None = None
+
+
+class ChatResponse(BaseModel):
+    response: str
+    route: str
+
+
+class ModelsResponse(BaseModel):
+    models: list[dict]
+
+
+@app.get("/")
+async def root():
+    return {"message": "VietQA Multi-Agent API", "status": "running"}
+
+
+@app.get("/health")
+async def health():
+    """Health check endpoint for Render."""
+    return {"status": "ok"}
+
+
+@app.get("/api/models", response_model=ModelsResponse)
+async def get_models():
+    """Get available large models."""
+    models = get_available_large_models()
+    return {
+        "models": [
+            {"id": m, "name": m.split("/")[-1]}
+            for m in models
+        ]
+    }
+
+
+@app.post("/api/solve", response_model=SolveResponse)
+async def solve_question(req: SolveRequest):
+    """Solve a multiple-choice question."""
+    if not req.question.strip():
+        raise HTTPException(400, "Question is required")
+    if len(req.choices) < 2:
+        raise HTTPException(400, "At least 2 choices required")
+
+    set_large_model_override(req.model)
+
+    try:
+        q = QuestionInput(qid="api", question=req.question, choices=req.choices)
+        state = question_to_state(q)
+        graph = get_graph()
+
+        result = await graph.ainvoke(state)
+
+        answer = normalize_answer(
+            answer=result.get("answer", "A"),
+            num_choices=len(req.choices),
+            question_id="api",
+            default="A"
+        )
+
+        return SolveResponse(
+            answer=answer,
+            route=result.get("route", "unknown"),
+            reasoning=clean_thinking_tags(result.get("raw_response", "")),
+            context=result.get("context", "")
+        )
+    except Exception as e:
+        import traceback
+        traceback.print_exc()
+        raise HTTPException(500, str(e))
+    finally:
+        set_large_model_override(None)
+
+
+@app.post("/api/chat", response_model=ChatResponse)
+async def chat(req: ChatRequest):
+    """Free-form chat (routes through pipeline without choices)."""
+    if not req.message.strip():
+        raise HTTPException(400, "Message is required")
+
+    set_large_model_override(req.model)
+
+    try:
+        # Use empty choices for chat mode
+        q = QuestionInput(qid="chat", question=req.message, choices=[])
+        state = question_to_state(q)
+        graph = get_graph()
+
+        result = await graph.ainvoke(state)
+
+        return ChatResponse(
+            response=clean_thinking_tags(result.get("raw_response", "")),
+            route=result.get("route", "unknown")
+        )
+    except Exception as e:
+        raise HTTPException(500, str(e))
+    finally:
+        set_large_model_override(None)
+
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000)
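For reference, a minimal client for the /api/solve endpoint above (a sketch; the base URL, question, and choices are illustrative, and it assumes the server is running locally on port 7860):

import httpx

payload = {
    "question": "Thủ đô của Việt Nam là gì?",  # illustrative question
    "choices": ["Hà Nội", "Huế", "Đà Nẵng", "Cần Thơ"],
    "model": None,  # None keeps the default large model
}
resp = httpx.post("http://localhost:7860/api/solve", json=payload, timeout=120)
resp.raise_for_status()
data = resp.json()
print(data["answer"], data["route"])  # answer letter plus the route the router chose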
requirements-prod.txt
ADDED
@@ -0,0 +1,28 @@
+# Production dependencies - minimal for API only
+fastapi==0.115.0
+uvicorn==0.30.6
+pydantic==2.12.5
+pydantic-settings==2.12.0
+python-dotenv==1.2.1
+
+# LangChain
+langchain==1.1.0
+langchain-core==1.1.0
+langchain-community==0.4.1
+langchain-text-splitters==1.0.0
+langgraph==1.0.4
+
+# LangChain integrations
+langchain-openai
+
+# HTTP client
+httpx==0.28.1
+requests==2.32.5
+
+# Jinja2 for templates
+jinja2==3.1.6
+
+# Other essentials
+tenacity==9.1.2
+pyyaml==6.0.3
+jsonpatch==1.33
src/__init__.py
ADDED
@@ -0,0 +1,2 @@
+"""VNPT AI RAG Pipeline for Vietnamese multiple-choice questions."""
+
src/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (225 Bytes)

src/__pycache__/__init__.cpython-314.pyc
ADDED
Binary file (227 Bytes)

src/__pycache__/config.cpython-312.pyc
ADDED
Binary file (3.91 kB)

src/__pycache__/config.cpython-314.pyc
ADDED
Binary file (3.88 kB)

src/__pycache__/graph.cpython-312.pyc
ADDED
Binary file (1.82 kB)

src/__pycache__/pipeline.cpython-312.pyc
ADDED
Binary file (10.9 kB)

src/__pycache__/state.cpython-312.pyc
ADDED
Binary file (726 Bytes)
src/config.py
ADDED
@@ -0,0 +1,110 @@
+import os
+from pathlib import Path
+
+from dotenv import load_dotenv
+from pydantic import Field
+from pydantic_settings import BaseSettings
+
+load_dotenv()
+
+PROJECT_ROOT = Path(__file__).resolve().parent.parent
+
+DATA_DIR = Path(os.getenv("DATA_DIR", PROJECT_ROOT / "data"))
+DATA_INPUT_DIR = Path(os.getenv("DATA_INPUT_DIR", PROJECT_ROOT / "test_data"))
+DATA_OUTPUT_DIR = Path(os.getenv("DATA_OUTPUT_DIR", PROJECT_ROOT / "output"))
+DATA_CRAWLED_DIR = Path(os.getenv("DATA_CRAWLED_DIR", DATA_DIR / "crawl"))
+BATCH_SIZE = 1
+
+
+class Settings(BaseSettings):
+    """Application settings with environment variable support."""
+
+    # MegaLLM API settings (for small model)
+    megallm_api_key: str = Field(
+        default="",
+        alias="MEGALLM_API_KEY",
+        description="API key for MegaLLM",
+    )
+    megallm_base_url: str = Field(
+        default="https://ai.megallm.io/v1",
+        alias="MEGALLM_BASE_URL",
+    )
+
+    # Groq API settings (for large model)
+    groq_api_key: str = Field(
+        default="",
+        alias="GROQ_API_KEY",
+        description="API key for Groq",
+    )
+    groq_base_url: str = Field(
+        default="https://api.groq.com/openai/v1",
+        alias="GROQ_BASE_URL",
+    )
+
+    # OpenRouter API (fallback)
+    openrouter_api_key: str = Field(
+        default="",
+        alias="OPENROUTER_API_KEY",
+        description="API key for OpenRouter (fallback)",
+    )
+
+    # Model names
+    model_small: str = Field(
+        default="qwen/qwen3-32b",
+        alias="MODEL_SMALL",
+        description="Small model for routing, reranking, and RAG",
+    )
+    model_large: str = Field(
+        default="meta-llama/llama-4-scout-17b-16e-instruct",
+        alias="MODEL_LARGE",
+        description="Large model for logic/direct answering",
+    )
+
+    # Available large models for testing
+    available_large_models: list[str] = [
+        "llama-3.3-70b-versatile",
+        "meta-llama/llama-4-scout-17b-16e-instruct",
+        "moonshotai/kimi-k2-instruct-0905",
+        "openai/gpt-oss-120b"
+    ]
+
+    # Local embedding model (Vietnamese)
+    embedding_model: str = Field(
+        default="bkai-foundation-models/vietnamese-bi-encoder",
+        alias="EMBEDDING_MODEL",
+    )
+
+    # Vector database
+    qdrant_collection: str = Field(
+        default="vnpt_knowledge_base",
+        alias="QDRANT_COLLECTION",
+    )
+    vector_db_path: str = Field(
+        default="",
+        alias="VECTOR_DB_PATH",
+        description="Path to Qdrant storage. Defaults to DATA_DIR/qdrant_storage if empty.",
+    )
+
+    chunk_size: int = 1000
+    chunk_overlap: int = 200
+    top_k_retrieval: int = 10
+    top_k_rerank: int = 3
+
+    @property
+    def vector_db_path_resolved(self) -> Path:
+        """Resolve vector database path, defaulting to DATA_DIR/qdrant_storage."""
+        if self.vector_db_path:
+            return Path(self.vector_db_path)
+        return DATA_DIR / "qdrant_storage"
+
+    class Config:
+        env_file = ".env"
+        extra = "ignore"
+
+
+settings = Settings()
+
+# Validate API key on import
+if not settings.megallm_api_key:
+    import warnings
+    warnings.warn("MEGALLM_API_KEY not set. LLM calls will fail.")
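Because Settings is instantiated at import time and reads its aliases from the environment (and .env), fields can be overridden without touching code; a minimal sketch, assuming the variables are exported before src.config is first imported:

import os

# Must run before `from src.config import settings`, since Settings()
# is constructed when the module is imported.
os.environ["MODEL_LARGE"] = "llama-3.3-70b-versatile"
os.environ["VECTOR_DB_PATH"] = "/tmp/qdrant_storage"  # illustrative path

from src.config import settings

print(settings.model_large)              # llama-3.3-70b-versatile
print(settings.vector_db_path_resolved)  # /tmp/qdrant_storage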
src/data_processing/__init__.py
ADDED
@@ -0,0 +1,26 @@
+"""Data processing utilities for the RAG pipeline."""
+
+from src.data_processing.answer import (
+    extract_answer,
+    extract_and_normalize,
+    normalize_answer,
+    validate_answer,
+)
+from src.data_processing.formatting import format_choices, format_choices_display, question_to_state
+from src.data_processing.loaders import load_test_data_from_csv, load_test_data_from_json
+from src.data_processing.models import InferenceLogEntry, PredictionOutput, QuestionInput
+
+__all__ = [
+    "QuestionInput",
+    "PredictionOutput",
+    "InferenceLogEntry",
+    "load_test_data_from_json",
+    "load_test_data_from_csv",
+    "question_to_state",
+    "format_choices",
+    "format_choices_display",
+    "extract_answer",
+    "validate_answer",
+    "normalize_answer",
+    "extract_and_normalize",
+]
src/data_processing/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (846 Bytes)

src/data_processing/__pycache__/__init__.cpython-314.pyc
ADDED
Binary file (843 Bytes)

src/data_processing/__pycache__/answer.cpython-312.pyc
ADDED
Binary file (5.15 kB)

src/data_processing/__pycache__/answer.cpython-314.pyc
ADDED
Binary file (5.61 kB)

src/data_processing/__pycache__/formatting.cpython-312.pyc
ADDED
Binary file (2.2 kB)

src/data_processing/__pycache__/formatting.cpython-314.pyc
ADDED
Binary file (2.72 kB)

src/data_processing/__pycache__/loaders.cpython-312.pyc
ADDED
Binary file (6.14 kB)

src/data_processing/__pycache__/loaders.cpython-314.pyc
ADDED
Binary file (7.11 kB)

src/data_processing/__pycache__/models.cpython-312.pyc
ADDED
Binary file (2.11 kB)

src/data_processing/__pycache__/models.cpython-314.pyc
ADDED
Binary file (2.94 kB)
src/data_processing/answer.py
ADDED
@@ -0,0 +1,151 @@
+"""Answer extraction and validation utilities.
+
+Consolidates answer-related logic:
+- Extraction from LLM responses (CoT format)
+- Validation against valid choices
+- Normalization with fallback defaults
+"""
+
+import re
+import string
+
+from src.utils.logging import print_log
+
+
+def extract_answer(response: str, num_choices: int = 4, require_end: bool = False) -> str | None:
+    """Extract answer letter from LLM response using strict explicit answer lines.
+
+    Only accepts answers from explicit final-answer lines with colon:
+    - "Đáp án: A", "Answer: B" (preferred)
+    - "Lựa chọn: C" (secondary)
+
+    Returns the LAST valid explicit answer line found (later lines override earlier).
+
+    Args:
+        response: Response text from LLM
+        num_choices: Number of valid choices
+        require_end: If True, only extract answer from last 20% of response
+
+    Returns:
+        Answer letter (A, B, C, D) or None if no explicit answer found
+    """
+    if not response:
+        return None
+
+    valid_labels = string.ascii_uppercase[:num_choices]
+
+    # If require_end, only look at last 20% of response
+    search_text = response
+    if require_end and len(response) > 100:
+        cutoff = int(len(response) * 0.8)
+        search_text = response[cutoff:]
+
+    # Pattern for primary labels: "Đáp án:" or "Answer:" (highest priority)
+    primary_pattern = r"(?:Đáp\s*án|Answer)[ \t]*[::][ \t]*\**([A-Z])\b"
+
+    # Pattern for secondary label: "Lựa chọn:" (lower priority)
+    secondary_pattern = r"Lựa\s*chọn[ \t]*[::][ \t]*\**([A-Z])\b"
+
+    # Find all matches for both patterns
+    primary_matches = re.findall(primary_pattern, search_text, flags=re.IGNORECASE)
+    secondary_matches = re.findall(secondary_pattern, search_text, flags=re.IGNORECASE)
+
+    if primary_matches:
+        answer = primary_matches[-1].upper()
+        if answer in valid_labels:
+            return answer
+
+    if secondary_matches:
+        answer = secondary_matches[-1].upper()
+        if answer in valid_labels:
+            return answer
+
+    # Single letter response (entire response is just a letter)
+    clean_response = search_text.strip()
+    if len(clean_response) == 1 and clean_response.upper() in valid_labels:
+        return clean_response.upper()
+
+    return None
+
+
+def validate_answer(answer: str, num_choices: int) -> tuple[bool, str]:
+    """Validate if answer is within valid range and normalize it.
+
+    Args:
+        answer: Raw answer string from model
+        num_choices: Number of choices available (A, B, C, D, ...)
+
+    Returns:
+        Tuple of (is_valid, normalized_answer)
+    """
+    valid_answers = string.ascii_uppercase[:num_choices]
+    if answer and answer.upper() in valid_answers:
+        return True, answer.upper()
+
+    return False, answer or ""
+
+
+def normalize_answer(
+    answer: str | None,
+    num_choices: int,
+    question_id: str | None = None,
+    default: str = "A",
+) -> str:
+    """Normalize and validate answer with fallback to default.
+
+    Combines extraction, validation, and normalization:
+    - Validates answer is within valid range (A, B, C, D, ...)
+    - Normalizes refusal responses
+    - Falls back to default for invalid answers
+
+    Args:
+        answer: Raw answer string from model (can be None)
+        num_choices: Number of choices available
+        question_id: Optional question ID for logging warnings
+        default: Default answer if validation fails
+
+    Returns:
+        Normalized answer string
+    """
+    if answer is None:
+        if question_id:
+            print_log(
+                f" [Warning] No answer extracted for {question_id}, "
+                f"defaulting to {default}"
+            )
+        return default
+
+    is_valid, normalized = validate_answer(answer, num_choices)
+
+    if not is_valid:
+        if question_id:
+            print_log(
+                f" [Warning] Invalid answer '{answer}' for {question_id}, "
+                f"defaulting to {default}"
+            )
+        return default
+
+    return normalized
+
+
+def extract_and_normalize(
+    response: str,
+    num_choices: int,
+    question_id: str | None = None,
+    default: str = "A",
+) -> str:
+    """Extract answer from response and normalize it (convenience function).
+
+    Combines extract_answer() and normalize_answer() into a single call.
+
+    Args:
+        response: Raw LLM response text
+        num_choices: Number of valid choices
+        question_id: Optional question ID for logging
+        default: Default answer if extraction/validation fails
+
+    Returns:
+        Normalized answer string
+    """
+    extracted = extract_answer(response, num_choices=num_choices)
+    return normalize_answer(extracted, num_choices, question_id, default)
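To make the extraction contract above concrete (the last explicit answer line wins, and anything outside the valid label range falls back to the default), a small sketch:

from src.data_processing.answer import extract_answer, extract_and_normalize

text = "Thử lần đầu: Đáp án: C\nSau khi kiểm tra lại: Đáp án: B"
print(extract_answer(text, num_choices=4))                # B (last explicit line overrides)
print(extract_answer("Tôi không chắc.", num_choices=4))   # None (no explicit answer line)
print(extract_and_normalize("Đáp án: E", num_choices=4))  # A (E is out of range, default used)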
src/data_processing/formatting.py
ADDED
@@ -0,0 +1,37 @@
+import string
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from src.data_processing.models import QuestionInput
+    from src.state import GraphState
+
+
+def question_to_state(q: "QuestionInput") -> "GraphState":
+    """Convert QuestionInput to GraphState for pipeline processing."""
+    state: "GraphState" = {
+        "question_id": q.qid,
+        "question": q.question,
+        "all_choices": q.choices,
+    }
+    return state
+
+
+def format_choices(choices: list[str]) -> str:
+    """Format choices for LLM prompts (A. ..., B. ..., etc.)."""
+    return "\n".join(f"{label}. {text}" for label, text in zip(string.ascii_uppercase, choices))
+
+
+def format_choices_display(choices: list[str]) -> str:
+    """Format choices for console display (2 columns)."""
+    labels = string.ascii_uppercase
+    lines = []
+    for i in range(0, len(choices), 2):
+        parts = []
+        for j in range(2):
+            idx = i + j
+            if idx < len(choices):
+                label = labels[idx] if idx < len(labels) else str(idx)
+                parts.append(f"{label}. {choices[idx]:<30}")
+        if parts:
+            lines.append(" " + " ".join(parts))
+    return "\n".join(lines)
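For example, format_choices yields the labelled block that the prompt templates consume (it is passed as choices= to load_prompt in the node modules):

from src.data_processing.formatting import format_choices

print(format_choices(["Hà Nội", "Huế", "Đà Nẵng"]))
# A. Hà Nội
# B. Huế
# C. Đà Nẵng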
src/data_processing/loaders.py
ADDED
@@ -0,0 +1,151 @@
+"""Data loading utilities for test questions."""
+
+import csv
+import json
+from pathlib import Path
+
+from src.data_processing.models import QuestionInput
+
+# Standard column mappings for choice columns
+_CHOICE_COLUMN_MAPPINGS = {
+    "choice_a": 0, "choice_b": 1, "choice_c": 2, "choice_d": 3,
+    "option_a": 0, "option_b": 1, "option_c": 2, "option_d": 3,
+    "a": 0, "b": 1, "c": 2, "d": 3,
+}
+
+
+def load_test_data_from_json(file_path: Path) -> list[QuestionInput]:
+    """Load test questions from JSON file.
+
+    Expected format: List of dicts with qid, question, choices, answer (optional)
+
+    Args:
+        file_path: Path to JSON file
+
+    Returns:
+        List of QuestionInput objects
+
+    Raises:
+        FileNotFoundError: If file doesn't exist
+        ValueError: If file format is invalid
+    """
+    if not file_path.exists():
+        raise FileNotFoundError(f"Test data file not found: {file_path}")
+
+    if file_path.suffix.lower() != ".json":
+        raise ValueError(f"Only JSON files are supported: {file_path}")
+
+    with open(file_path, encoding="utf-8") as f:
+        data = json.load(f)
+
+    if not isinstance(data, list):
+        raise ValueError(f"JSON file must contain a list of questions: {file_path}")
+
+    questions = []
+    for item in data:
+        if "choices" not in item or not isinstance(item["choices"], list):
+            raise ValueError(f"Question {item.get('qid', 'unknown')} must have 'choices' as a list")
+
+        questions.append(QuestionInput(
+            qid=item["qid"],
+            question=item["question"],
+            choices=item["choices"],
+            answer=item.get("answer"),
+        ))
+
+    return questions
+
+
+def _normalize_row_keys(row: dict[str, str]) -> dict[str, str]:
+    """Normalize row keys to lowercase and strip whitespace."""
+    return {k.lower().strip(): v for k, v in row.items()}
+
+
+def _extract_choices_from_row(row: dict[str, str]) -> list[str]:
+    """Extract choices from a normalized CSV row.
+
+    Tries multiple strategies:
+    1. Individual choice columns (choice_a/option_a/a, etc.)
+    2. JSON array in 'choices' column
+    3. Comma/semicolon separated string in 'choices' column
+
+    Args:
+        row: Normalized row dict with lowercase keys
+
+    Returns:
+        List of choice strings (may contain empty strings)
+    """
+    # Strategy 1: Individual columns (choice_a, option_a, a, etc.)
+    choices = ["", "", "", ""]
+    found_individual = False
+
+    for col_name, idx in _CHOICE_COLUMN_MAPPINGS.items():
+        if col_name in row and row[col_name]:
+            choices[idx] = row[col_name].strip()
+            found_individual = True
+
+    if found_individual:
+        return [c for c in choices if c]
+
+    # Strategy 2 & 3: Parse 'choices' column
+    choices_raw = row.get("choices", "")
+    if not choices_raw:
+        return []
+
+    # Try JSON parse first
+    try:
+        parsed = json.loads(choices_raw)
+        if isinstance(parsed, list):
+            return [str(c).strip() for c in parsed if str(c).strip()]
+    except (json.JSONDecodeError, TypeError):
+        pass
+
+    # Fallback: split by comma or semicolon
+    return [c.strip() for c in choices_raw.replace(";", ",").split(",") if c.strip()]
+
+
+def load_test_data_from_csv(file_path: Path) -> list[QuestionInput]:
+    """Load test questions from CSV file.
+
+    Supports multiple CSV formats:
+    - Columns: qid, question, choice_a, choice_b, choice_c, choice_d
+    - Columns: qid, question, option_a, option_b, option_c, option_d
+    - Columns: qid, question, A, B, C, D
+    - Columns: qid, question, choices (JSON array or comma-separated)
+
+    Args:
+        file_path: Path to CSV file
+
+    Returns:
+        List of QuestionInput objects
+
+    Raises:
+        FileNotFoundError: If file doesn't exist
+    """
+    if not file_path.exists():
+        raise FileNotFoundError(f"Test data file not found: {file_path}")
+
+    questions = []
+    with open(file_path, encoding="utf-8") as f:
+        reader = csv.DictReader(f)
+        for row in reader:
+            norm_row = _normalize_row_keys(row)
+
+            qid = norm_row.get("qid", "").strip()
+            question = norm_row.get("question", "").strip()
+
+            if not qid or not question:
+                continue
+
+            choices = _extract_choices_from_row(norm_row)
+            if not choices:
+                choices = ["", "", "", ""]
+
+            questions.append(QuestionInput(
+                qid=qid,
+                question=question,
+                choices=choices,
+                answer=norm_row.get("answer", "").strip() or None,
+            ))
+
+    return questions
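A sketch of the JSON shape load_test_data_from_json expects (the file path and contents are illustrative):

import json
from pathlib import Path

from src.data_processing.loaders import load_test_data_from_json

sample = [
    {
        "qid": "q1",
        "question": "Thủ đô của Việt Nam là gì?",
        "choices": ["Hà Nội", "Huế", "Đà Nẵng", "Cần Thơ"],
        "answer": "A",
    },
]
path = Path("/tmp/sample_questions.json")  # illustrative location
path.write_text(json.dumps(sample, ensure_ascii=False), encoding="utf-8")

questions = load_test_data_from_json(path)
print(questions[0].qid, questions[0].answer)  # q1 A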
src/data_processing/models.py
ADDED
@@ -0,0 +1,29 @@
+from pydantic import BaseModel, Field
+
+
+class QuestionInput(BaseModel):
+    """Input schema for a multiple-choice question."""
+
+    qid: str = Field(description="Question identifier")
+    question: str = Field(description="Question text in Vietnamese")
+    choices: list[str] = Field(description="List of answer choices")
+    answer: str | None = Field(default=None, description="Correct answer (A, B, C, ...)")
+
+
+class PredictionOutput(BaseModel):
+    """Output schema for a prediction."""
+
+    qid: str = Field(description="Question identifier")
+    answer: str = Field(description="Predicted answer: A, B, C, D, ...")
+
+
+class InferenceLogEntry(BaseModel):
+    """Schema for JSONL inference log entry (used for checkpointing)."""
+
+    qid: str = Field(description="Question identifier")
+    question: str = Field(description="Original question text")
+    choices: list[str] = Field(description="List of answer choices")
+    final_answer: str = Field(description="Final predicted answer")
+    raw_response: str = Field(default="", description="Raw LLM response")
+    route: str = Field(default="unknown", description="Pipeline route taken")
+    retrieved_context: str = Field(default="", description="Retrieved context from RAG")
src/graph.py
ADDED
@@ -0,0 +1,47 @@
+"""LangGraph definition for the RAG pipeline."""
+
+from langgraph.graph import END, StateGraph
+
+from src.state import GraphState
+from src.nodes.logic import logic_solver_node
+from src.nodes.rag import knowledge_rag_node
+from src.nodes.router import route_question, router_node
+from src.nodes.direct import direct_answer_node
+
+
+def build_graph() -> StateGraph:
+    """Build and compile the LangGraph pipeline."""
+
+    workflow = StateGraph(GraphState)
+
+    workflow.add_node("router", router_node)
+    workflow.add_node("knowledge_rag", knowledge_rag_node)
+    workflow.add_node("logic_solver", logic_solver_node)
+    workflow.add_node("direct_answer", direct_answer_node)
+
+    workflow.set_entry_point("router")
+
+    workflow.add_conditional_edges(
+        "router",
+        route_question,
+        {
+            "knowledge_rag": "knowledge_rag",
+            "logic_solver": "logic_solver",
+            "direct_answer": "direct_answer",
+            "__end__": END,
+        },
+    )
+
+    workflow.add_edge("knowledge_rag", END)
+    workflow.add_edge("logic_solver", END)
+    workflow.add_edge("direct_answer", END)
+    return workflow.compile()
+
+graph = None
+
+def get_graph():
+    """Get or create the compiled graph singleton."""
+    global graph
+    if graph is None:
+        graph = build_graph()
+    return graph
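End to end, the compiled graph is driven exactly as api.py does it; a minimal sketch (assumes the API keys from src/config.py are set, since the router node calls an LLM):

import asyncio

from src.data_processing.formatting import question_to_state
from src.data_processing.models import QuestionInput
from src.graph import get_graph

async def main() -> None:
    q = QuestionInput(qid="demo", question="1 + 1 = ?", choices=["1", "2", "3", "4"])
    result = await get_graph().ainvoke(question_to_state(q))
    print(result.get("route"), result.get("answer"))

asyncio.run(main())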
src/nodes/__init__.py
ADDED
@@ -0,0 +1,15 @@
+"""Node implementations for the LangGraph pipeline."""
+
+from src.nodes.direct import direct_answer_node
+from src.nodes.logic import logic_solver_node
+from src.nodes.rag import knowledge_rag_node
+from src.nodes.router import route_question, router_node
+
+__all__ = [
+    "direct_answer_node",
+    "knowledge_rag_node",
+    "logic_solver_node",
+    "route_question",
+    "router_node",
+]
+
src/nodes/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (526 Bytes)

src/nodes/__pycache__/direct.cpython-312.pyc
ADDED
Binary file (2.17 kB)

src/nodes/__pycache__/logic.cpython-312.pyc
ADDED
Binary file (11.8 kB)

src/nodes/__pycache__/rag.cpython-312.pyc
ADDED
Binary file (6.47 kB)

src/nodes/__pycache__/router.cpython-312.pyc
ADDED
Binary file (5.9 kB)
src/nodes/direct.py
ADDED
@@ -0,0 +1,42 @@
+"""Direct Answer node for reading comprehension or general questions without RAG."""
+
+from langchain_core.prompts import ChatPromptTemplate
+
+from src.data_processing.answer import extract_answer
+from src.data_processing.formatting import format_choices
+from src.state import GraphState
+from src.utils.llm import get_large_model
+from src.utils.logging import print_log
+from src.utils.prompts import load_prompt
+
+
+def direct_answer_node(state: GraphState) -> dict:
+    """Answer questions directly using Large Model (Skip Retrieval)."""
+    print_log(" [Direct] Processing Reading Comprehension/General Question...")
+
+    all_choices = state["all_choices"]
+    choices_text = format_choices(all_choices)
+
+    llm = get_large_model()
+
+    system_prompt = load_prompt("direct_answer.j2", "system")
+    user_prompt = load_prompt("direct_answer.j2", "user", question=state["question"], choices=choices_text)
+
+    # Escape curly braces to prevent LangChain from parsing them as variables
+    system_prompt = system_prompt.replace("{", "{{").replace("}", "}}")
+    user_prompt = user_prompt.replace("{", "{{").replace("}", "}}")
+
+    prompt = ChatPromptTemplate.from_messages([
+        ("system", system_prompt),
+        ("human", user_prompt),
+    ])
+
+    chain = prompt | llm
+    response = chain.invoke({})
+
+    content = response.content.strip()
+    print_log(f" [Direct] Reasoning: {content}...")
+
+    answer = extract_answer(content, num_choices=len(all_choices) or 4)
+    print_log(f" [Direct] Final Answer: {answer}")
+    return {"answer": answer, "raw_response": content}
src/nodes/logic.py
ADDED
@@ -0,0 +1,253 @@
+"""Logic solver node implementing a Manual Code Execution workflow."""
+
+import re
+import string
+
+from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
+from langchain_experimental.utilities import PythonREPL
+
+from src.data_processing.answer import extract_answer
+from src.data_processing.formatting import format_choices
+from src.state import GraphState
+from src.utils.llm import get_large_model
+from src.utils.logging import print_log
+from src.utils.prompts import load_prompt
+
+_python_repl = PythonREPL()
+
+
+def extract_python_code(text: str) -> str | None:
+    """Find and extract Python code from block ``` python ... ```"""
+    match = re.search(r"```(?:python)?\s*(.*?)```", text, re.DOTALL | re.IGNORECASE)
+    if match:
+        return match.group(1).strip()
+    return None
+
+
+def _validate_code_syntax(code: str) -> tuple[bool, str]:
+    """Check if code has valid Python syntax. Returns (is_valid, error_message)."""
+    try:
+        compile(code, "<string>", "exec")
+        return True, ""
+    except SyntaxError as e:
+        return False, str(e)
+
+
+def _is_placeholder_code(code: str) -> bool:
+    """Check if code contains placeholders or is incomplete."""
+    if not code or len(code.strip()) < 10:
+        return True
+    if "..." in code:
+        return True
+    # Check for {key}-style placeholders (but not f-string or dict literals)
+    if re.search(r"\{[a-zA-Z_][a-zA-Z0-9_]*\}", code):
+        # Exclude common dict/set patterns and f-strings
+        if not re.search(r'["\'][^"\']*\{[a-zA-Z_]', code):
+            return True
+    return False
+
+
+def _indent_code(code: str) -> str:
+    """Format code to make it easier to read in the terminal."""
+    return "\n".join(f" {line}" for line in code.splitlines())
+
+
+def _fallback_text_reasoning(llm, question: str, choices_text: str) -> dict:
+    """Fallback to CoT reasoning when code execution fails."""
+    print_log(" [Logic] Falling back to CoT reasoning...")
+
+    fallback_system = (
+        "Nhiệm vụ của bạn là trả lời câu hỏi "
+        "được đưa ra bằng khả năng phân tích và suy luận logic. "
+        "Hãy phân tích vấn đề và suy luận đề từng bước một. "
+        "Cuối cùng, hãy trả lời theo đúng định dạng: 'Đáp án: X' "
+        "trong đó X là ký tự đại diện cho lựa chọn đúng (A, B, C, D, ...)."
+    )
+
+    fallback_user = (
+        f"Câu hỏi: {question}\n"
+        f"{choices_text}"
+    )
+
+    fallback_messages: list[BaseMessage] = [
+        SystemMessage(content=fallback_system),
+        HumanMessage(content=fallback_user)
+    ]
+
+    fallback_response = llm.invoke(fallback_messages)
+    fallback_content = fallback_response.content
+    print_log(f" [Logic] Fallback response received.")
+
+    return {"text": fallback_content}
+
+
+def _request_final_answer(llm, question: str, choices_text: str, computed_results: str) -> str:
+    """Request a strict final answer from the model."""
+    system_prompt = (
+        "Bạn là trợ lý AI. Dựa vào kết quả tính toán được cung cấp, "
+        "hãy đưa ra đáp án cuối cùng. CHỈ trả lời đúng một dòng: Đáp án: X "
+        "(trong đó X là A, B, C hoặc D)."
+    )
+    user_prompt = (
+        f"Câu hỏi: {question}\n"
+        f"{choices_text}\n"
+        f"Kết quả tính toán: {computed_results}\n\n"
+        "Trả lời đúng một dòng: Đáp án: X"
+    )
+
+    messages: list[BaseMessage] = [
+        SystemMessage(content=system_prompt),
+        HumanMessage(content=user_prompt)
+    ]
+
+    response = llm.invoke(messages)
+    return response.content
+
+
+def logic_solver_node(state: GraphState) -> dict:
+    """Solve math/logic questions using Python code execution."""
+    llm = get_large_model()
+    all_choices = state["all_choices"]
+    num_choices = len(all_choices)
+    choices_text = format_choices(all_choices)
+
+    system_prompt = load_prompt("logic_solver.j2", "system")
+    user_prompt = load_prompt("logic_solver.j2", "user", question=state["question"], choices=choices_text)
+
+    messages: list[BaseMessage] = [
+        SystemMessage(content=system_prompt),
+        HumanMessage(content=user_prompt)
+    ]
+
+    step_texts: list[str] = []
+    computed_outputs: list[str] = []
+
+    max_steps = 5
+    for step in range(max_steps):
+        response = llm.invoke(messages)
+        content = response.content
+        step_texts.append(content)
+        messages.append(response)
+
+        code_block = extract_python_code(content)
+
+        if code_block:
+            if _is_placeholder_code(code_block):
+                print_log(f" [Logic] Step {step+1}: Placeholder code detected. Requesting complete code...")
+                regen_msg = (
+                    "Code không hợp lệ (chứa placeholder hoặc không đầy đủ). "
+                    "Hãy cung cấp code Python hoàn chỉnh, có thể chạy được, không chứa '...' hay placeholder. "
+                    "In ra các giá trị tính toán được. "
+                    "Cuối cùng, kết thúc bằng một dòng duy nhất: Đáp án: X (X là A, B, C hoặc D)."
+                )
+                messages.append(HumanMessage(content=regen_msg))
+                continue
+
+            print_log(f" [Logic] Step {step+1}: Found Python code. Executing...")
+
+            # Validate syntax before execution
+            is_valid, syntax_error = _validate_code_syntax(code_block)
+            if not is_valid:
+                print_log(f" [Error] Syntax error detected: {syntax_error}")
+                error_msg = f"SyntaxError: {syntax_error}. "
+                error_msg += "Lưu ý: KHÔNG sử dụng các từ khóa Python như 'lambda', 'class', 'def' làm tên biến. "
+                error_msg += "Hãy đổi tên biến và thử lại."
+                messages.append(HumanMessage(content=error_msg))
+                continue
+
+            print_log(f" [Logic] Code:\n{_indent_code(code_block)}")
+
+            try:
+                if "print" not in code_block:
+                    lines = code_block.splitlines()
+                    if lines:
+                        last_line = lines[-1]
+                        if "=" in last_line:
+                            var_name = last_line.split("=")[0].strip()
+                        else:
+                            var_name = last_line.strip()
+                        code_block += f"\nprint({var_name})"
+
+                output = _python_repl.run(code_block)
+                output = output.strip() if output else "No output."
+                print_log(f" [Logic] Code output: {output}")
+                computed_outputs.append(output)
+
+                # Do NOT extract answer from code output directly
+                # Instead, feed output back to model and ask for final answer line
+                feedback_msg = (
+                    f"Kết quả thực thi code: {output}\n\n"
+                    "Dựa vào kết quả trên, hãy so sánh với các đáp án và đưa ra câu trả lời cuối cùng. "
+                    "Kết thúc bằng đúng một dòng: Đáp án: X (X là A, B, C hoặc D)."
+                )
+                messages.append(HumanMessage(content=feedback_msg))
+
+            except Exception as e:
+                error_msg = f"Error running code: {str(e)}"
+                print_log(f" [Error] {error_msg}")
+                messages.append(HumanMessage(content=f"{error_msg}. Hãy kiểm tra logic và sửa lại code."))
+
+            continue
+
+        # Check if current step contains an explicit answer (only at end of response)
+        step_answer = extract_answer(content, num_choices=num_choices, require_end=True)
+        if step_answer:
+            print_log(f" [Logic] Step {step+1}: Found explicit answer: {step_answer}")
+            combined_raw = "\n---STEP---\n".join(step_texts)
+            return {"answer": step_answer, "raw_response": combined_raw, "route": "math"}
+
+        # Also check if response contains clear conclusion without "Đáp án:" format
+        if any(phrase in content.lower() for phrase in ["kết luận", "vậy đáp án", "do đó", "vì vậy"]):
+            # Try to extract any single letter at end of response
+            lines = content.strip().split('\n')
+            for line in reversed(lines[-3:]):  # Check last 3 lines
+                line = line.strip()
+                if len(line) == 1 and line.upper() in string.ascii_uppercase[:num_choices]:
+                    print_log(f" [Logic] Step {step+1}: Found implicit answer: {line.upper()}")
+                    combined_raw = "\n---STEP---\n".join(step_texts)
+                    return {"answer": line.upper(), "raw_response": combined_raw, "route": "math"}
+
+        if step < max_steps - 1:
+            print_log(" [Warning] No code or answer found. Reminding model...")
+            messages.append(HumanMessage(content="Lưu ý: Bạn vẫn chưa đưa ra đáp án cuối cùng. Hãy kết thúc bằng: Đáp án: X"))
+
+    # Max steps reached - build combined_raw and try to extract answer
+    print_log(" [Warning] Max steps reached. Attempting answer extraction from combined text...")
+
+    # Build combined_raw from all steps
+    combined_raw = "\n---STEP---\n".join(step_texts) if step_texts else ""
+
+    # Try fallback text reasoning with error handling
+    try:
+        fallback_result = _fallback_text_reasoning(llm, state["question"], choices_text)
+        fallback_text = fallback_result["text"]
+        if fallback_text:
+            combined_raw += "\n---FALLBACK---\n" + fallback_text
+    except Exception as e:
+        print_log(f" [Error] Fallback reasoning failed: {e}")
+        fallback_text = ""
+
+    # Extract answer from the entire combined text (takes LAST explicit answer)
+    final_answer = extract_answer(combined_raw, num_choices=num_choices)
+
+    if final_answer:
+        print_log(f" [Logic] Extracted final answer from combined text: {final_answer}")
+        return {"answer": final_answer, "raw_response": combined_raw, "route": "math"}
+
+    # Still no answer - do one final strict LLM call with error handling
+    print_log(" [Logic] No explicit answer found. Requesting strict final answer...")
+    computed_str = "; ".join(computed_outputs) if computed_outputs else "Không có kết quả tính toán"
+    try:
+        strict_response = _request_final_answer(llm, state["question"], choices_text, computed_str)
+        combined_raw += "\n---FINAL---\n" + strict_response
+
+        final_answer = extract_answer(strict_response, num_choices=num_choices)
+        if final_answer:
+            print_log(f" [Logic] Final strict answer: {final_answer}")
+            return {"answer": final_answer, "raw_response": combined_raw, "route": "math"}
+    except Exception as e:
+        print_log(f" [Error] Final answer request failed: {e}")
+
+    # Absolute fallback - default to A
+    print_log(" [Warning] All extraction attempts failed. Defaulting to A.")
+    return {"answer": "A", "raw_response": combined_raw, "route": "math"}
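The extract_python_code / _is_placeholder_code pair gates what reaches the REPL; their behaviour on typical model output can be checked in isolation (a sketch with illustrative replies):

from src.nodes.logic import _is_placeholder_code, extract_python_code

reply = "Tôi sẽ tính bằng Python:\n```python\nresult = 12 * 9\nprint(result)\n```"
code = extract_python_code(reply)
print(code)                        # the two extracted code lines
print(_is_placeholder_code(code))  # False: complete, runnable code

stub = "```python\nx = ...  # điền vào đây\n```"
print(_is_placeholder_code(extract_python_code(stub)))  # True: contains '...'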
src/nodes/rag.py
ADDED
@@ -0,0 +1,141 @@
+"""RAG node for knowledge-based question answering with Retrieve & Rerank."""
+
+import re
+
+from langchain_core.prompts import ChatPromptTemplate
+
+from src.config import settings
+from src.data_processing.answer import extract_answer
+from src.data_processing.formatting import format_choices
+from src.state import GraphState
+from src.utils.ingestion import get_vector_store
+from src.utils.llm import get_small_model
+from src.utils.logging import print_log
+from src.utils.prompts import load_prompt
+from src.nodes.direct import direct_answer_node
+
+
+def _rerank_documents(query: str, docs: list, top_k: int = 3) -> list:
+    """Rerank retrieved documents using the small LLM.
+
+    Args:
+        query: The user question
+        docs: List of retrieved documents
+        top_k: Number of top documents to return after reranking
+
+    Returns:
+        List of reranked documents (top_k most relevant)
+    """
+    if len(docs) <= top_k:
+        return docs
+
+    llm = get_small_model()
+
+    # Build document list for reranking prompt
+    doc_list = ""
+    for i, doc in enumerate(docs):
+        content_preview = doc.page_content[:350].replace("\n", " ")
+        doc_list += f"[{i}] {content_preview}...\n\n"
+
+    rerank_system = (
+        "/no_think\n"
+        "Bạn là chuyên gia đánh giá độ liên quan của văn bản. "
+        "Nhiệm vụ: Chọn ra các đoạn văn bản LIÊN QUAN NHẤT với câu hỏi.\n"
+        "Chỉ trả về danh sách các số ID (ví dụ: 0, 3, 5), không giải thích."
+    )
+
+    rerank_user = (
+        f"Câu hỏi: {query}\n\n"
+        f"Các đoạn văn bản:\n{doc_list}\n"
+        f"Hãy chọn {top_k} đoạn văn bản LIÊN QUAN NHẤT với câu hỏi. "
+        f"Trả về danh sách ID (số từ 0 đến {len(docs)-1}), cách nhau bởi dấu phẩy."
+    )
+
+    prompt = ChatPromptTemplate.from_messages([
+        ("system", rerank_system),
+        ("human", rerank_user),
+    ])
+
+    try:
+        chain = prompt | llm
+        response = chain.invoke({})
+        content = response.content.strip()
+        print_log(f" [RAG] Reranker response: {content}")
+
+        # Parse selected IDs from response
+        selected_ids = []
+        numbers = re.findall(r'\d+', content)
+        for num_str in numbers:
+            idx = int(num_str)
+            if 0 <= idx < len(docs) and idx not in selected_ids:
+                selected_ids.append(idx)
+                if len(selected_ids) >= top_k:
+                    break
+
+        if selected_ids:
+            reranked = [docs[i] for i in selected_ids]
+            print_log(f" [RAG] Reranked: selected {len(reranked)} docs from {len(docs)}")
+            return reranked
+
+        print_log(" [RAG] Rerank parsing failed, using first top_k docs")
+        return docs[:top_k]
+
+    except Exception as e:
+        print_log(f" [RAG] Reranking failed: {e}. Falling back to first top_k docs.")
+        return docs[:top_k]
+
+
+def knowledge_rag_node(state: GraphState) -> dict:
+    """Retrieve relevant context, rerank, and answer knowledge-based questions."""
+    vector_store = get_vector_store()
+    query = state["question"]
+    print_log(f" [RAG] Retrieving context for: '{query}'")
+
+    docs = vector_store.similarity_search(query, k=settings.top_k_retrieval)
+    print_log(f" [RAG] Retrieved {len(docs)} documents")
+
+    if not docs:
+        print_log(" [Warning] No relevant documents found in Knowledge Base.")
+        context = ""
+    else:
+        reranked_docs = _rerank_documents(query, docs, top_k=settings.top_k_rerank)
+
+        context = "\n\n---\n\n".join([doc.page_content for doc in reranked_docs])
+
+        if reranked_docs:
+            print_log(f" [RAG] Using {len(reranked_docs)} reranked docs. Top: \"{reranked_docs[0].page_content[:80]}...\"")
+
+    all_choices = state["all_choices"]
+    choices_text = format_choices(all_choices)
+
+    llm = get_small_model()
+
+    system_prompt = load_prompt("rag.j2", "system", context=context)
+    user_prompt = load_prompt("rag.j2", "user", question=state["question"], choices=choices_text)
+
+    # Escape curly braces to prevent LangChain from parsing them as variables
+    system_prompt = system_prompt.replace("{", "{{").replace("}", "}}")
+    user_prompt = user_prompt.replace("{", "{{").replace("}", "}}")
+
+    prompt = ChatPromptTemplate.from_messages([
+        ("system", system_prompt),
+        ("human", user_prompt),
+    ])
+
+    chain = prompt | llm
+    response = chain.invoke({})
+    content = response.content.strip()
+    print_log(f" [RAG] Reasoning: {content}")
+
+    answer = extract_answer(content, num_choices=len(all_choices) or 4)
+    print_log(f" [RAG] Final Answer: {answer}")
+
+    # Fallback to direct mode if RAG context was not helpful
+    if answer is None:
+        print_log(" [RAG] Context not relevant, falling back to direct mode...")
+        direct_result = direct_answer_node(state)
+        direct_result["route"] = "rag->direct"  # Track the fallback
+        return direct_result
+
+    return {"answer": answer, "context": context, "raw_response": content}
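A small, self-contained sketch of `_rerank_documents`'s short-circuit path: when `len(docs) <= top_k` the candidates pass through unchanged and no LLM call is made. `FakeDoc` is a hypothetical stand-in for a LangChain `Document`.

```
from dataclasses import dataclass

@dataclass
class FakeDoc:
    page_content: str  # the only attribute _rerank_documents reads

docs = [FakeDoc(f"passage {i}") for i in range(2)]
# Two candidates, top_k=3: the guard at the top returns them as-is,
# so no reranking prompt is ever sent.
assert _rerank_documents("câu hỏi", docs, top_k=3) == docs
```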
src/nodes/router.py
ADDED
@@ -0,0 +1,112 @@
+"""Router node for classifying questions and directing to appropriate handlers."""
+
+import string
+from typing import Literal
+
+from langchain_core.prompts import ChatPromptTemplate
+
+from src.data_processing.formatting import format_choices
+from src.state import GraphState
+from src.utils.llm import get_small_model
+from src.utils.logging import print_log
+from src.utils.prompts import load_prompt
+
+
+def _find_refusal_option(state: GraphState) -> str | None:
+    """Find refusal option in choices and return corresponding letter."""
+    all_choices = state["all_choices"]
+    option_labels = list(string.ascii_uppercase[:len(all_choices)])
+
+    refusal_patterns = [
+        "tôi không thể", "không thể trả lời", "không thể cung cấp", "không thể chia sẻ",
+        "từ chối trả lời", "từ chối cung cấp",
+        "nằm ngoài phạm vi", "không thuộc phạm vi", "tôi là mô hình ngôn ngữ",
+        "hành vi vi phạm", "trái pháp luật", "không hỗ trợ",
+    ]
+
+    for i, choice in enumerate(all_choices):
+        txt = choice.lower().strip()
+        if any(p in txt for p in refusal_patterns):
+            return option_labels[i]
+
+    return None
+
+
+def _classify_with_llm(state: GraphState) -> str:
+    """Classify question using LLM."""
+    choices_text = format_choices(state["all_choices"])
+    llm = get_small_model()
+
+    system_prompt = load_prompt("router.j2", "system")
+    user_prompt = load_prompt("router.j2", "user", question=state["question"], choices=choices_text)
+
+    # Escape curly braces to prevent LangChain from parsing them as variables
+    system_prompt = system_prompt.replace("{", "{{").replace("}", "}}")
+    user_prompt = user_prompt.replace("{", "{{").replace("}", "}}")
+
+    prompt = ChatPromptTemplate.from_messages([
+        ("system", system_prompt),
+        ("human", user_prompt),
+    ])
+    chain = prompt | llm
+    response = chain.invoke({})
+    return response.content.strip().lower()
+
+
+def router_node(state: GraphState) -> dict:
+    """Analyze question and determine routing path. Returns answer immediately for toxic content."""
+    question = state["question"].lower()
+
+    # Fast-track: Direct answer for reading comprehension
+    direct_keywords = ["đoạn thông tin", "đoạn văn", "bài đọc", "căn cứ vào đoạn", "theo đoạn"]
+    if any(k in question for k in direct_keywords) and len(question.split()) > 50:
+        print_log(" [Router] Fast-track: Direct Answer (Found Context block)")
+        return {"route": "direct"}
+
+    # Fast-track: Math/Logic for LaTeX or math keywords
+    math_signals = [
+        "$", "\\frac", "^", "=", "tính giá trị", "biểu thức", "phương trình",
+        "hàm số", "đạo hàm", "xác suất", "lãi suất", "vận tốc", "gia tốc",
+        "điện trở", "gam", "mol", "nguyên tử khối", "gdp", "lạm phát", "công suất"
+    ]
+    if any(s in question for s in math_signals):
+        print_log(" [Router] Fast-track: Math (Keywords/LaTeX detected)")
+        return {"route": "math"}
+
+    print_log(" [Router] Slow-track: Using LLM to classify...")
+    try:
+        route = _classify_with_llm(state)
+        print_log(f" [Router] LLM Decision: {route}")
+
+        if "direct" in route:
+            route_type = "direct"
+        elif "math" in route or "logic" in route:
+            route_type = "math"
+        elif "toxic" in route:
+            refusal_answer = _find_refusal_option(state)
+            if refusal_answer:
+                print_log(f" [Router] Toxic detected, found refusal option: {refusal_answer}")
+                return {"route": "toxic", "answer": refusal_answer}
+            print_log(" [Router] Toxic detected, no refusal option found, defaulting to A")
+            return {"route": "toxic", "answer": "A"}
+        else:
+            route_type = "rag"
+
+        return {"route": route_type}
+    except Exception as e:
+        print_log(f" [Router] Error: {e}. Fallback to RAG.")
+        return {"route": "rag"}
+
+
+def route_question(state: GraphState) -> Literal["knowledge_rag", "logic_solver", "direct_answer", "__end__"]:
+    """Conditional edge function to route to appropriate node based on state route."""
+    route = state.get("route", "rag")
+    answer = state.get("answer")
+
+    if route == "toxic":
+        return "__end__"
+    if route == "direct":
+        return "direct_answer"
+    if route == "math":
+        return "logic_solver"
+    return "knowledge_rag"
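For orientation, a sketch of how `router_node` and `route_question` plausibly plug into the LangGraph graph. The actual wiring lives in src/graph.py (also added in this commit) and may use different node names; `logic_solver_node` is an assumed export of src/nodes/logic.py.

```
from langgraph.graph import StateGraph, END

builder = StateGraph(GraphState)
builder.add_node("router", router_node)
builder.add_node("knowledge_rag", knowledge_rag_node)
builder.add_node("logic_solver", logic_solver_node)    # assumed name
builder.add_node("direct_answer", direct_answer_node)
builder.set_entry_point("router")
# route_question's return value selects the next node; "__end__" (toxic)
# terminates the graph since router_node already set the answer.
builder.add_conditional_edges("router", route_question, {
    "knowledge_rag": "knowledge_rag",
    "logic_solver": "logic_solver",
    "direct_answer": "direct_answer",
    "__end__": END,
})
for node in ("knowledge_rag", "logic_solver", "direct_answer"):
    builder.add_edge(node, END)
graph = builder.compile()
```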
src/pipeline.py
ADDED
@@ -0,0 +1,215 @@
+"""Core pipeline execution logic for the RAG system."""
+
+import asyncio
+import csv
+import sys
+import time
+from pathlib import Path
+
+from src.config import BATCH_SIZE, DATA_OUTPUT_DIR
+from src.data_processing.answer import normalize_answer
+from src.data_processing.formatting import format_choices_display, question_to_state
+from src.data_processing.models import InferenceLogEntry, PredictionOutput, QuestionInput
+from src.graph import get_graph
+from src.utils.checkpointing import (
+    append_log_entry,
+    consolidate_log_file,
+    generate_csv_from_log,
+    is_rate_limit_error,
+)
+from src.utils.common import sort_qids
+from src.utils.ingestion import get_vector_store
+from src.utils.logging import log_done, log_pipeline, log_stats, print_log
+
+
+def sort_questions_by_qid(questions: list[QuestionInput]) -> list[QuestionInput]:
+    """Sort questions by qid using natural sorting."""
+    qid_to_question = {q.qid: q for q in questions}
+    sorted_qids = sort_qids(list(qid_to_question.keys()))
+    return [qid_to_question[qid] for qid in sorted_qids]
+
+
+async def run_pipeline_async(
+    questions: list[QuestionInput],
+    batch_size: int = BATCH_SIZE,
+) -> list[PredictionOutput]:
+    """Run pipeline for inference (assumes pre-built Vector DB).
+
+    Args:
+        questions: List of questions to process
+        batch_size: Number of concurrent questions to process
+
+    Returns:
+        List of PredictionOutput objects sorted by qid
+    """
+    log_pipeline("Loading pre-built vector store...")
+    get_vector_store()
+
+    questions = sort_questions_by_qid(questions)
+
+    graph = get_graph()
+    total = len(questions)
+    start_time = time.perf_counter()
+
+    sem = asyncio.Semaphore(batch_size)
+    results: dict[str, PredictionOutput] = {}
+
+    async def process_single_question(q: QuestionInput) -> None:
+        async with sem:
+            print_log(f"\n[{q.qid}] {q.question}")
+            print_log(format_choices_display(q.choices))
+            state = question_to_state(q)
+            result = await graph.ainvoke(state)
+
+            answer = result.get("answer", "A")
+            route = result.get("route", "unknown")
+            num_choices = len(q.choices)
+
+            normalized_answer = normalize_answer(
+                answer=answer,
+                num_choices=num_choices,
+                question_id=q.qid,
+                default="A",
+            )
+
+            log_done(f"{q.qid}: {normalized_answer} (Route: {route})")
+            results[q.qid] = PredictionOutput(qid=q.qid, answer=normalized_answer)
+
+    tasks = [process_single_question(q) for q in questions]
+    await asyncio.gather(*tasks)
+
+    elapsed = time.perf_counter() - start_time
+    throughput = total / elapsed if elapsed > 0 else 0
+    log_stats(f"Completed {total} questions in {elapsed:.2f}s ({throughput:.2f} req/s)")
+
+    sorted_qids = sort_qids(list(results.keys()))
+    return [results[qid] for qid in sorted_qids]
+
+
+async def run_pipeline_with_checkpointing(
+    questions: list[QuestionInput],
+    log_path: Path,
+    batch_size: int = BATCH_SIZE,
+) -> int:
+    """Run pipeline with JSONL checkpointing for resume capability.
+
+    Questions are processed in qid order. Results are appended to log file
+    immediately for fault tolerance, then consolidated at the end.
+
+    Args:
+        questions: List of questions to process (already filtered for unprocessed)
+        log_path: Path to JSONL log file for checkpointing
+        batch_size: Number of concurrent questions to process
+
+    Returns:
+        Count of newly processed questions
+    """
+    log_pipeline("Loading pre-built vector store...")
+    get_vector_store()
+
+    questions = sort_questions_by_qid(questions)
+    log_pipeline(f"Processing {len(questions)} questions in qid order...")
+
+    graph = get_graph()
+    total = len(questions)
+    start_time = time.perf_counter()
+    processed_count = 0
+
+    sem = asyncio.Semaphore(batch_size)
+    stop_event = asyncio.Event()
+
+    async def process_single_question(q: QuestionInput) -> None:
+        nonlocal processed_count
+        if stop_event.is_set():
+            return
+
+        async with sem:
+            if stop_event.is_set():
+                return
+            print_log(f"\n[{q.qid}] {q.question}")
+            print_log(format_choices_display(q.choices))
+            state = question_to_state(q)
+
+            try:
+                result = await graph.ainvoke(state)
+                route = result.get("route", "unknown")
+                raw_response = result.get("raw_response", "")
+                context = result.get("context", "")
+
+                answer = normalize_answer(
+                    answer=result.get("answer"),
+                    num_choices=len(q.choices),
+                    question_id=q.qid,
+                    default="A",
+                )
+
+                log_entry = InferenceLogEntry(
+                    qid=q.qid,
+                    question=q.question,
+                    choices=q.choices,
+                    final_answer=answer,
+                    raw_response=raw_response,
+                    route=route,
+                    retrieved_context=context,
+                )
+                await append_log_entry(log_path, log_entry)
+
+                log_done(f"{q.qid}: {answer} (Route: {route})")
+                processed_count += 1
+                # await asyncio.sleep(150)
+
+            except Exception as e:
+                if is_rate_limit_error(e):
+                    print_log(f" [CRITICAL] Rate Limit Detected on {q.qid}: {e}")
+                    stop_event.set()
+                else:
+                    print_log(f" [Error] Failed to process {q.qid}: {e}")
+
+    tasks = [asyncio.create_task(process_single_question(q)) for q in questions]
+    await asyncio.gather(*tasks)
+
+    if stop_event.is_set():
+        log_pipeline("!!! PIPELINE STOPPED DUE TO RATE LIMIT !!!")
+        log_pipeline("Consolidating logs and generating emergency submission...")
+        consolidate_log_file(log_path)
+
+        output_file = DATA_OUTPUT_DIR / "submission_emergency.csv"
+        total_entries = generate_csv_from_log(log_path, output_file)
+        log_pipeline(f"Saved emergency submission with {total_entries} entries to: {output_file}")
+
+        sys.exit(0)
+
+    log_pipeline("Consolidating log file...")
+    consolidate_log_file(log_path)
+
+    elapsed = time.perf_counter() - start_time
+    throughput = total / elapsed if elapsed > 0 else 0
+    log_stats(f"Processed {processed_count}/{total} questions in {elapsed:.2f}s ({throughput:.2f} req/s)")
+
+    return processed_count
+
+
+def save_predictions(
+    predictions: list[PredictionOutput],
+    output_path: Path,
+    ensure_dir: bool = True,
+) -> None:
+    """Save predictions to CSV file, sorted by qid.
+
+    Args:
+        predictions: List of prediction outputs
+        output_path: Path to output CSV file
+        ensure_dir: If True, create parent directory if it doesn't exist
+    """
+    if ensure_dir:
+        output_path.parent.mkdir(parents=True, exist_ok=True)
+
+    sorted_qids = sort_qids([p.qid for p in predictions])
+    pred_dict = {p.qid: p for p in predictions}
+
+    with open(output_path, "w", newline="", encoding="utf-8") as f:
+        writer = csv.DictWriter(f, fieldnames=["qid", "answer"])
+        writer.writeheader()
+        for qid in sorted_qids:
+            writer.writerow({"qid": qid, "answer": pred_dict[qid].answer})
+    log_pipeline(f"Predictions saved to: {output_path}")
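A hedged end-to-end usage sketch for the two entry points above. `QuestionInput`'s field names (`qid`, `question`, `choices`) are taken from how this file reads them; its exact constructor lives in src/data_processing/models.py and may differ.

```
import asyncio
from pathlib import Path

from src.data_processing.models import QuestionInput
from src.pipeline import run_pipeline_async, save_predictions

# Assumed keyword construction of the QuestionInput model.
questions = [
    QuestionInput(qid="q001", question="1 + 1 = ?", choices=["1", "2", "3", "4"]),
]
preds = asyncio.run(run_pipeline_async(questions, batch_size=4))
save_predictions(preds, Path("data/output/submission.csv"))
```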
src/state.py
ADDED
@@ -0,0 +1,16 @@
+"""State schema definitions for the RAG pipeline graph."""
+
+from typing import TypedDict
+
+
+class GraphState(TypedDict, total=False):
+    """State schema for the RAG pipeline graph."""
+
+    question_id: str
+    question: str
+    all_choices: list[str]
+    route: str
+    context: str
+    answer: str
+    raw_response: str
+
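Because `GraphState` is declared with `total=False`, a partial dict is a valid state; each node merges in its own keys as the graph runs. For example:

```
state: GraphState = {
    "question_id": "q001",
    "question": "Thủ đô của Việt Nam là thành phố nào?",
    "all_choices": ["Hà Nội", "TP. Hồ Chí Minh", "Đà Nẵng", "Huế"],
}
# router_node later merges in "route"; the answer nodes merge in
# "answer", "raw_response", and (for the RAG path) "context".
```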
src/templates/direct_answer.j2
ADDED
@@ -0,0 +1,19 @@
+{# Direct Answer Node Prompt Templates #}
+{% block system %}
+/no_think
+Bạn là một chuyên gia trả lời câu hỏi trắc nghiệm. Nhiệm vụ của bạn là phân tích và chọn đáp án đúng nhất cho câu hỏi.
+
+NGÔN NGỮ: Toàn bộ suy luận, giải thích PHẢI bằng TIẾNG VIỆT 100%. KHÔNG dùng tiếng Anh.
+
+Lưu ý:
+1. Nếu đề bài có đoạn văn, CHỈ dựa vào đoạn văn đó để suy luận.
+2. Suy luận từng bước logic.
+   - Với câu hỏi về ngày tháng, con số: So sánh chính xác từng ký tự.
+   - Nếu câu hỏi yêu cầu tìm từ sai/đúng: Đối chiếu từng phương án với văn bản.
+3. Trả lời bằng: "Đáp án: X" (X là một trong các lựa chọn A, B, C, D, ...).
+{% endblock %}
+
+{% block user %}
+Câu hỏi: {{ question }}
+{{ choices }}
+{% endblock %}
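Each template defines `system` and `user` Jinja2 blocks, which the nodes fetch via `load_prompt("<name>.j2", "<block>", **vars)`. A sketch of one plausible way to render a single named block; the real src/utils/prompts.py (not shown in this diff) may be implemented differently:

```
from jinja2 import Environment, FileSystemLoader

def load_prompt_sketch(template_name: str, block: str, **variables) -> str:
    # Hypothetical re-implementation for illustration only.
    env = Environment(loader=FileSystemLoader("src/templates"))
    template = env.get_template(template_name)
    ctx = template.new_context(variables)
    # Template.blocks maps block names to render generators.
    return "".join(template.blocks[block](ctx)).strip()
```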
src/templates/logic_solver.j2
ADDED
@@ -0,0 +1,37 @@
+{# Logic Solver (Code Agent) Prompt Templates #}
+{% block system %}
+/no_think
+Bạn là chuyên gia giải toán và logic. Trả lời NGẮN GỌN, SÚC TÍCH.
+
+NGÔN NGỮ: Toàn bộ suy luận, giải thích PHẢI bằng TIẾNG VIỆT 100%. KHÔNG dùng tiếng Anh.
+
+QUY TẮC:
+1. Suy luận ngắn gọn, đi thẳng vào vấn đề
+2. Chỉ nêu các bước quan trọng, bỏ qua chi tiết thừa
+3. Dòng "Đáp án: X" PHẢI là dòng CUỐI CÙNG
+4. Tối đa 5-7 dòng suy luận
+
+CẤU TRÚC:
+1. Phân tích ngắn gọn
+2. Suy luận chính
+3. Kết luận
+4. Đáp án: X
+
+VÍ DỤ TỐT (NGẮN GỌN):
+```
+Phân tích: Tính 2 + 3 * 4
+Thứ tự: 3 * 4 = 12, sau đó 2 + 12 = 14
+Kết luận: 14 tương ứng đáp án B
+
+Đáp án: B
+```
+
+NHẮC LẠI: NGẮN GỌN, SÚC TÍCH! Chỉ 5-7 dòng! TIẾNG VIỆT 100%!
+{% endblock %}
+
+{% block user %}
+{{ question }}
+{{ choices }}
+
+Suy luận ngắn gọn:
+{% endblock %}
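All four templates force the model to end with a literal "Đáp án: X" line, which is presumably what `extract_answer` (src/data_processing/answer.py, not shown here) anchors on. A minimal sketch of that contract:

```
import re

def extract_answer_sketch(text: str, num_choices: int = 4) -> str | None:
    # Hypothetical parser for illustration; the real extract_answer may differ.
    letters = "ABCDEFGHIJ"[:num_choices]
    matches = re.findall(rf"Đáp án:\s*([{letters}])", text)
    return matches[-1] if matches else None  # last occurrence wins

assert extract_answer_sketch("Kết luận: 14 tương ứng đáp án B\n\nĐáp án: B") == "B"
```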
src/templates/rag.j2
ADDED
@@ -0,0 +1,25 @@
+{# RAG Node Prompt Templates #}
+{% block system %}
+/no_think
+Bạn là một chuyên gia phân tích thông tin và đọc hiểu văn bản chính xác tuyệt đối.
+Nhiệm vụ: Trả lời câu hỏi trắc nghiệm CHỈ dựa trên thông tin trong phần Văn bản được cung cấp bên dưới.
+
+NGÔN NGỮ: Toàn bộ suy luận, giải thích PHẢI bằng TIẾNG VIỆT 100%. KHÔNG dùng tiếng Anh.
+
+Văn bản:
+{{ context }}
+
+Quy tắc bắt buộc:
+1. Đọc kỹ văn bản để tìm các từ khóa liên quan đến câu hỏi.
+2. So sánh từng lựa chọn với thông tin tìm được trong văn bản.
+3. Suy luận từng bước:
+   - Nếu văn bản chứa câu trả lời trực tiếp: Trích dẫn ý đó để xác nhận.
+   - Nếu văn bản KHÔNG chứa câu trả lời trực tiếp: Sử dụng phương pháp loại trừ các đáp án sai để chọn đáp án phù hợp và
+     đúng nhất.
+4. Trả lời cuối cùng theo định dạng: "Đáp án: X" (trong đó X là ký tự lựa chọn). Ví dụ: "Đáp án: A"
+{% endblock %}
+
+{% block user %}
+Câu hỏi: {{ question }}
+{{ choices }}
+{% endblock %}
src/templates/router.j2
ADDED
@@ -0,0 +1,43 @@
+{# Router Node Prompt Templates #}
+{% block system %}
+/no_think
+Bạn là hệ thống phân loại câu hỏi tiếng Việt. LUÔN trả lời bằng TIẾNG VIỆT.
+
+Nhiệm vụ: Phân loại câu hỏi vào duy nhất 1 trong 4 nhóm: "toxic", "direct", "math", hoặc "rag".
+
+Hãy thực hiện theo quy trình kiểm tra thứ tự ưu tiên sau đây (QUAN TRỌNG):
+
+Ưu tiên 1: Kiểm tra "toxic" (An toàn là trên hết)
+- Nếu câu hỏi yêu cầu hướng dẫn thực hiện hành vi vi phạm pháp luật (trốn thuế, làm giả giấy tờ, tham nhũng, buôn lậu,
+  chế tạo vũ khí...).
+- Nếu câu hỏi mang tính chất phản động, chống phá nhà nước, bôi nhọ lãnh tụ, hoặc vi phạm thuần phong mỹ tục.
+-> Trả về: toxic
+
+Ưu tiên 2: Kiểm tra "direct" (Đọc hiểu văn bản có sẵn)
+- Hãy nhìn vào dữ liệu đầu vào. Nếu nó chứa các từ khóa đánh dấu văn bản như: "Đoạn thông tin:", "Văn bản:", "Document",
+  "Title:", "Nội dung:", hoặc một đoạn văn dài đi kèm trước câu hỏi.
+- Bất kể nội dung là Lịch sử hay Khoa học, nếu ĐÃ CÓ đoạn văn bản đi kèm để trả lời -> Phải chọn nhóm này.
+-> Trả về: direct
+
+Ưu tiên 3: Kiểm tra "math" (Tư duy logic & Tính toán & Lập trình)
+- Các bài tập Toán, Lý, Hóa, Sinh yêu cầu tính toán ra số liệu cụ thể (không phải lý thuyết suông).
+- Các câu hỏi chứa công thức toán học (LaTeX, dấu $, phương trình).
+- Các câu hỏi về Lập trình.
+- Các câu hỏi tư duy logic, chuỗi số, xác suất thống kê.
+-> Trả về: math
+
+Ưu tiên 4: Kiểm tra "rag" (Tra cứu kiến thức)
+- Các câu hỏi kiến thức về Lịch sử, Địa lý, Luật pháp, Văn hóa, Xã hội.
+- Các câu hỏi lý thuyết khoa học (không cần tính toán).
+- Câu hỏi mà KHÔNG CÓ đoạn văn bản đi kèm.
+-> Trả về: rag
+
+QUAN TRỌNG: Chỉ trả về đúng 1 từ kết quả (toxic/direct/math/rag). Không giải thích thêm.
+{% endblock %}
+
+{% block user %}
+Câu hỏi: {{ question }}
+{{ choices }}
+
+Kết quả phân loại (chỉ 1 từ):
+{% endblock %}
src/utils/__init__.py
ADDED
@@ -0,0 +1,47 @@
+"""Utility functions for the RAG pipeline."""
+
+from src.utils.checkpointing import (
+    append_log_entry,
+    consolidate_log_file,
+    generate_csv_from_log,
+    is_rate_limit_error,
+    load_log_entries,
+    load_processed_qids,
+)
+from src.utils.common import normalize_text, remove_diacritics, sort_qids
+from src.utils.ingestion import (
+    get_embeddings,
+    get_qdrant_client,
+    get_vector_store,
+    ingest_all_data,
+    ingest_files,
+)
+from src.utils.llm import get_large_model, get_small_model
+from src.utils.web_crawler import WebCrawler, crawl_website, save_crawled_data
+
+__all__ = [
+    # Checkpointing
+    "load_processed_qids",
+    "load_log_entries",
+    "append_log_entry",
+    "consolidate_log_file",
+    "generate_csv_from_log",
+    "is_rate_limit_error",
+    "sort_qids",
+    # Ingestion
+    "get_embeddings",
+    "get_qdrant_client",
+    "get_vector_store",
+    "ingest_all_data",
+    "ingest_files",
+    # LLM
+    "get_small_model",
+    "get_large_model",
+    # Text utilities
+    "normalize_text",
+    "remove_diacritics",
+    # Web crawler
+    "WebCrawler",
+    "crawl_website",
+    "save_crawled_data",
+]
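These re-exports let the rest of the codebase pull utilities from the package root in one line:

```
from src.utils import get_vector_store, get_small_model, sort_qids
```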
src/utils/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (1.03 kB)

src/utils/__pycache__/__init__.cpython-314.pyc
ADDED
Binary file (1.03 kB)

src/utils/__pycache__/checkpointing.cpython-312.pyc
ADDED
Binary file (5.38 kB)

src/utils/__pycache__/checkpointing.cpython-314.pyc
ADDED
Binary file (6.5 kB)

src/utils/__pycache__/common.cpython-312.pyc
ADDED
Binary file (3.31 kB)