quanho114 committed
Commit ebb8326 · 1 Parent(s): 0d3f194

Deploy VietQA API

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. Dockerfile +18 -0
  2. api.py +164 -0
  3. requirements-prod.txt +28 -0
  4. src/__init__.py +2 -0
  5. src/__pycache__/__init__.cpython-312.pyc +0 -0
  6. src/__pycache__/__init__.cpython-314.pyc +0 -0
  7. src/__pycache__/config.cpython-312.pyc +0 -0
  8. src/__pycache__/config.cpython-314.pyc +0 -0
  9. src/__pycache__/graph.cpython-312.pyc +0 -0
  10. src/__pycache__/pipeline.cpython-312.pyc +0 -0
  11. src/__pycache__/state.cpython-312.pyc +0 -0
  12. src/config.py +110 -0
  13. src/data_processing/__init__.py +26 -0
  14. src/data_processing/__pycache__/__init__.cpython-312.pyc +0 -0
  15. src/data_processing/__pycache__/__init__.cpython-314.pyc +0 -0
  16. src/data_processing/__pycache__/answer.cpython-312.pyc +0 -0
  17. src/data_processing/__pycache__/answer.cpython-314.pyc +0 -0
  18. src/data_processing/__pycache__/formatting.cpython-312.pyc +0 -0
  19. src/data_processing/__pycache__/formatting.cpython-314.pyc +0 -0
  20. src/data_processing/__pycache__/loaders.cpython-312.pyc +0 -0
  21. src/data_processing/__pycache__/loaders.cpython-314.pyc +0 -0
  22. src/data_processing/__pycache__/models.cpython-312.pyc +0 -0
  23. src/data_processing/__pycache__/models.cpython-314.pyc +0 -0
  24. src/data_processing/answer.py +151 -0
  25. src/data_processing/formatting.py +37 -0
  26. src/data_processing/loaders.py +151 -0
  27. src/data_processing/models.py +29 -0
  28. src/graph.py +47 -0
  29. src/nodes/__init__.py +15 -0
  30. src/nodes/__pycache__/__init__.cpython-312.pyc +0 -0
  31. src/nodes/__pycache__/direct.cpython-312.pyc +0 -0
  32. src/nodes/__pycache__/logic.cpython-312.pyc +0 -0
  33. src/nodes/__pycache__/rag.cpython-312.pyc +0 -0
  34. src/nodes/__pycache__/router.cpython-312.pyc +0 -0
  35. src/nodes/direct.py +42 -0
  36. src/nodes/logic.py +253 -0
  37. src/nodes/rag.py +141 -0
  38. src/nodes/router.py +112 -0
  39. src/pipeline.py +215 -0
  40. src/state.py +16 -0
  41. src/templates/direct_answer.j2 +19 -0
  42. src/templates/logic_solver.j2 +37 -0
  43. src/templates/rag.j2 +25 -0
  44. src/templates/router.j2 +43 -0
  45. src/utils/__init__.py +47 -0
  46. src/utils/__pycache__/__init__.cpython-312.pyc +0 -0
  47. src/utils/__pycache__/__init__.cpython-314.pyc +0 -0
  48. src/utils/__pycache__/checkpointing.cpython-312.pyc +0 -0
  49. src/utils/__pycache__/checkpointing.cpython-314.pyc +0 -0
  50. src/utils/__pycache__/common.cpython-312.pyc +0 -0
Dockerfile ADDED
@@ -0,0 +1,18 @@
+# Dockerfile for Hugging Face Spaces
+FROM python:3.11-slim
+
+WORKDIR /app
+
+# Install dependencies
+COPY requirements-prod.txt .
+RUN pip install --no-cache-dir -r requirements-prod.txt
+
+# Copy application code
+COPY api.py .
+COPY src/ ./src/
+
+# Expose port (HF Spaces uses port 7860)
+EXPOSE 7860
+
+# Run the API
+CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "7860"]
api.py ADDED
@@ -0,0 +1,164 @@
+"""FastAPI Backend for VietQA Multi-Agent System."""
+
+from contextlib import asynccontextmanager
+from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
+
+from src.data_processing.models import QuestionInput
+from src.data_processing.formatting import question_to_state
+from src.data_processing.answer import normalize_answer
+from src.graph import get_graph
+from src.utils.llm import set_large_model_override, get_available_large_models
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """Fast startup - lazy load models on first request."""
+    print("[Startup] Server starting (models will load on first request)...")
+    print("[Startup] Server ready!")
+    yield
+
+
+app = FastAPI(
+    title="VietQA Multi-Agent API",
+    description="API cho hệ thống trả lời câu hỏi trắc nghiệm tiếng Việt",
+    version="1.0.0",
+    lifespan=lifespan
+)
+
+# CORS for frontend
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+class SolveRequest(BaseModel):
+    question: str
+    choices: list[str]
+    model: str | None = None
+
+
+class SolveResponse(BaseModel):
+    answer: str
+    route: str
+    reasoning: str
+    context: str
+
+
+def clean_thinking_tags(text: str) -> str:
+    """Remove <think>...</think> tags from model response."""
+    import re
+    # Remove think tags and their content
+    cleaned = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
+    return cleaned.strip()
+
+
+class ChatRequest(BaseModel):
+    message: str
+    model: str | None = None
+
+
+class ChatResponse(BaseModel):
+    response: str
+    route: str
+
+
+class ModelsResponse(BaseModel):
+    models: list[dict]
+
+
+@app.get("/")
+async def root():
+    return {"message": "VietQA Multi-Agent API", "status": "running"}
+
+
+@app.get("/health")
+async def health():
+    """Health check endpoint for Render."""
+    return {"status": "ok"}
+
+
+@app.get("/api/models", response_model=ModelsResponse)
+async def get_models():
+    """Get available large models."""
+    models = get_available_large_models()
+    return {
+        "models": [
+            {"id": m, "name": m.split("/")[-1]}
+            for m in models
+        ]
+    }
+
+
+@app.post("/api/solve", response_model=SolveResponse)
+async def solve_question(req: SolveRequest):
+    """Solve a multiple-choice question."""
+    if not req.question.strip():
+        raise HTTPException(400, "Question is required")
+    if len(req.choices) < 2:
+        raise HTTPException(400, "At least 2 choices required")
+
+    set_large_model_override(req.model)
+
+    try:
+        q = QuestionInput(qid="api", question=req.question, choices=req.choices)
+        state = question_to_state(q)
+        graph = get_graph()
+
+        result = await graph.ainvoke(state)
+
+        answer = normalize_answer(
+            answer=result.get("answer", "A"),
+            num_choices=len(req.choices),
+            question_id="api",
+            default="A"
+        )
+
+        return SolveResponse(
+            answer=answer,
+            route=result.get("route", "unknown"),
+            reasoning=clean_thinking_tags(result.get("raw_response", "")),
+            context=result.get("context", "")
+        )
+    except Exception as e:
+        import traceback
+        traceback.print_exc()
+        raise HTTPException(500, str(e))
+    finally:
+        set_large_model_override(None)
+
+
+@app.post("/api/chat", response_model=ChatResponse)
+async def chat(req: ChatRequest):
+    """Free-form chat (routes through pipeline without choices)."""
+    if not req.message.strip():
+        raise HTTPException(400, "Message is required")
+
+    set_large_model_override(req.model)
+
+    try:
+        # Use empty choices for chat mode
+        q = QuestionInput(qid="chat", question=req.message, choices=[])
+        state = question_to_state(q)
+        graph = get_graph()
+
+        result = await graph.ainvoke(state)
+
+        return ChatResponse(
+            response=clean_thinking_tags(result.get("raw_response", "")),
+            route=result.get("route", "unknown")
+        )
+    except Exception as e:
+        raise HTTPException(500, str(e))
+    finally:
+        set_large_model_override(None)
+
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000)
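For reference, a minimal client sketch for the /api/solve endpoint (not part of this commit; the base URL is a placeholder, and httpx is already pinned in requirements-prod.txt):

import httpx

BASE_URL = "http://localhost:7860"  # placeholder - substitute the deployed Space URL

payload = {
    "question": "Thủ đô của Việt Nam là thành phố nào?",
    "choices": ["Hà Nội", "Đà Nẵng", "Huế", "Cần Thơ"],
    "model": None,  # None selects the default large model
}

resp = httpx.post(f"{BASE_URL}/api/solve", json=payload, timeout=120.0)
resp.raise_for_status()
data = resp.json()
print(data["answer"], data["route"])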
requirements-prod.txt ADDED
@@ -0,0 +1,28 @@
+# Production dependencies - minimal for API only
+fastapi==0.115.0
+uvicorn==0.30.6
+pydantic==2.12.5
+pydantic-settings==2.12.0
+python-dotenv==1.2.1
+
+# LangChain
+langchain==1.1.0
+langchain-core==1.1.0
+langchain-community==0.4.1
+langchain-text-splitters==1.0.0
+langgraph==1.0.4
+
+# LangChain integrations
+langchain-openai
+
+# HTTP client
+httpx==0.28.1
+requests==2.32.5
+
+# Jinja2 for templates
+jinja2==3.1.6
+
+# Other essentials
+tenacity==9.1.2
+pyyaml==6.0.3
+jsonpatch==1.33
src/__init__.py ADDED
@@ -0,0 +1,2 @@
+"""VNPT AI RAG Pipeline for Vietnamese multiple-choice questions."""
+
src/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (225 Bytes).

src/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (227 Bytes).

src/__pycache__/config.cpython-312.pyc ADDED
Binary file (3.91 kB).

src/__pycache__/config.cpython-314.pyc ADDED
Binary file (3.88 kB).

src/__pycache__/graph.cpython-312.pyc ADDED
Binary file (1.82 kB).

src/__pycache__/pipeline.cpython-312.pyc ADDED
Binary file (10.9 kB).

src/__pycache__/state.cpython-312.pyc ADDED
Binary file (726 Bytes).
 
src/config.py ADDED
@@ -0,0 +1,110 @@
+import os
+from pathlib import Path
+
+from dotenv import load_dotenv
+from pydantic import Field
+from pydantic_settings import BaseSettings
+
+load_dotenv()
+
+PROJECT_ROOT = Path(__file__).resolve().parent.parent
+
+DATA_DIR = Path(os.getenv("DATA_DIR", PROJECT_ROOT / "data"))
+DATA_INPUT_DIR = Path(os.getenv("DATA_INPUT_DIR", PROJECT_ROOT / "test_data"))
+DATA_OUTPUT_DIR = Path(os.getenv("DATA_OUTPUT_DIR", PROJECT_ROOT / "output"))
+DATA_CRAWLED_DIR = Path(os.getenv("DATA_CRAWLED_DIR", DATA_DIR / "crawl"))
+BATCH_SIZE = 1
+
+
+class Settings(BaseSettings):
+    """Application settings with environment variable support."""
+
+    # MegaLLM API settings (for small model)
+    megallm_api_key: str = Field(
+        default="",
+        alias="MEGALLM_API_KEY",
+        description="API key for MegaLLM",
+    )
+    megallm_base_url: str = Field(
+        default="https://ai.megallm.io/v1",
+        alias="MEGALLM_BASE_URL",
+    )
+
+    # Groq API settings (for large model)
+    groq_api_key: str = Field(
+        default="",
+        alias="GROQ_API_KEY",
+        description="API key for Groq",
+    )
+    groq_base_url: str = Field(
+        default="https://api.groq.com/openai/v1",
+        alias="GROQ_BASE_URL",
+    )
+
+    # OpenRouter API (fallback)
+    openrouter_api_key: str = Field(
+        default="",
+        alias="OPENROUTER_API_KEY",
+        description="API key for OpenRouter (fallback)",
+    )
+
+    # Model names
+    model_small: str = Field(
+        default="qwen/qwen3-32b",
+        alias="MODEL_SMALL",
+        description="Small model for routing, reranking, and RAG",
+    )
+    model_large: str = Field(
+        default="meta-llama/llama-4-scout-17b-16e-instruct",
+        alias="MODEL_LARGE",
+        description="Large model for logic/direct answering",
+    )
+
+    # Available large models for testing
+    available_large_models: list[str] = [
+        "llama-3.3-70b-versatile",
+        "meta-llama/llama-4-scout-17b-16e-instruct",
+        "moonshotai/kimi-k2-instruct-0905",
+        "openai/gpt-oss-120b"
+    ]
+
+    # Local embedding model (Vietnamese)
+    embedding_model: str = Field(
+        default="bkai-foundation-models/vietnamese-bi-encoder",
+        alias="EMBEDDING_MODEL",
+    )
+
+    # Vector database
+    qdrant_collection: str = Field(
+        default="vnpt_knowledge_base",
+        alias="QDRANT_COLLECTION",
+    )
+    vector_db_path: str = Field(
+        default="",
+        alias="VECTOR_DB_PATH",
+        description="Path to Qdrant storage. Defaults to DATA_DIR/qdrant_storage if empty.",
+    )
+
+    chunk_size: int = 1000
+    chunk_overlap: int = 200
+    top_k_retrieval: int = 10
+    top_k_rerank: int = 3
+
+    @property
+    def vector_db_path_resolved(self) -> Path:
+        """Resolve vector database path, defaulting to DATA_DIR/qdrant_storage."""
+        if self.vector_db_path:
+            return Path(self.vector_db_path)
+        return DATA_DIR / "qdrant_storage"
+
+    class Config:
+        env_file = ".env"
+        extra = "ignore"
+
+
+settings = Settings()
+
+# Validate API key on import
+if not settings.megallm_api_key:
+    import warnings
+    warnings.warn("MEGALLM_API_KEY not set. LLM calls will fail.")
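Because each field carries an environment alias, settings can be overridden without code changes. A small sketch (the override value is one of the entries already listed in available_large_models):

import os

# Hypothetical override - must be set before Settings() is instantiated
os.environ["MODEL_LARGE"] = "llama-3.3-70b-versatile"

from src.config import Settings

print(Settings().model_large)  # picks up the environment override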
src/data_processing/__init__.py ADDED
@@ -0,0 +1,26 @@
+"""Data processing utilities for the RAG pipeline."""
+
+from src.data_processing.answer import (
+    extract_answer,
+    extract_and_normalize,
+    normalize_answer,
+    validate_answer,
+)
+from src.data_processing.formatting import format_choices, format_choices_display, question_to_state
+from src.data_processing.loaders import load_test_data_from_csv, load_test_data_from_json
+from src.data_processing.models import InferenceLogEntry, PredictionOutput, QuestionInput
+
+__all__ = [
+    "QuestionInput",
+    "PredictionOutput",
+    "InferenceLogEntry",
+    "load_test_data_from_json",
+    "load_test_data_from_csv",
+    "question_to_state",
+    "format_choices",
+    "format_choices_display",
+    "extract_answer",
+    "validate_answer",
+    "normalize_answer",
+    "extract_and_normalize",
+]
src/data_processing/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (846 Bytes).

src/data_processing/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (843 Bytes).

src/data_processing/__pycache__/answer.cpython-312.pyc ADDED
Binary file (5.15 kB).

src/data_processing/__pycache__/answer.cpython-314.pyc ADDED
Binary file (5.61 kB).

src/data_processing/__pycache__/formatting.cpython-312.pyc ADDED
Binary file (2.2 kB).

src/data_processing/__pycache__/formatting.cpython-314.pyc ADDED
Binary file (2.72 kB).

src/data_processing/__pycache__/loaders.cpython-312.pyc ADDED
Binary file (6.14 kB).

src/data_processing/__pycache__/loaders.cpython-314.pyc ADDED
Binary file (7.11 kB).

src/data_processing/__pycache__/models.cpython-312.pyc ADDED
Binary file (2.11 kB).

src/data_processing/__pycache__/models.cpython-314.pyc ADDED
Binary file (2.94 kB).
 
src/data_processing/answer.py ADDED
@@ -0,0 +1,151 @@
+"""Answer extraction and validation utilities.
+
+Consolidates answer-related logic:
+- Extraction from LLM responses (CoT format)
+- Validation against valid choices
+- Normalization with fallback defaults
+"""
+
+import re
+import string
+
+from src.utils.logging import print_log
+
+
+def extract_answer(response: str, num_choices: int = 4, require_end: bool = False) -> str | None:
+    """Extract answer letter from LLM response using strict explicit answer lines.
+
+    Only accepts answers from explicit final-answer lines with colon:
+    - "Đáp án: A", "Answer: B" (preferred)
+    - "Lựa chọn: C" (secondary)
+
+    Returns the LAST valid explicit answer line found (later lines override earlier).
+
+    Args:
+        response: Response text from LLM
+        num_choices: Number of valid choices
+        require_end: If True, only extract answer from last 20% of response
+
+    Returns:
+        Answer letter (A, B, C, D) or None if no explicit answer found
+    """
+    if not response:
+        return None
+
+    valid_labels = string.ascii_uppercase[:num_choices]
+
+    # If require_end, only look at last 20% of response
+    search_text = response
+    if require_end and len(response) > 100:
+        cutoff = int(len(response) * 0.8)
+        search_text = response[cutoff:]
+
+    # Pattern for primary labels: "Đáp án:" or "Answer:" (highest priority)
+    primary_pattern = r"(?:Đáp\s*án|Answer)[ \t]*[:：][ \t]*\**([A-Z])\b"
+
+    # Pattern for secondary label: "Lựa chọn:" (lower priority)
+    secondary_pattern = r"Lựa\s*chọn[ \t]*[:：][ \t]*\**([A-Z])\b"
+
+    # Find all matches for both patterns
+    primary_matches = re.findall(primary_pattern, search_text, flags=re.IGNORECASE)
+    secondary_matches = re.findall(secondary_pattern, search_text, flags=re.IGNORECASE)
+
+    if primary_matches:
+        answer = primary_matches[-1].upper()
+        if answer in valid_labels:
+            return answer
+
+    if secondary_matches:
+        answer = secondary_matches[-1].upper()
+        if answer in valid_labels:
+            return answer
+
+    # Single letter response (entire response is just a letter)
+    clean_response = search_text.strip()
+    if len(clean_response) == 1 and clean_response.upper() in valid_labels:
+        return clean_response.upper()
+
+    return None
+
+
+def validate_answer(answer: str, num_choices: int) -> tuple[bool, str]:
+    """Validate if answer is within valid range and normalize it.
+
+    Args:
+        answer: Raw answer string from model
+        num_choices: Number of choices available (A, B, C, D, ...)
+
+    Returns:
+        Tuple of (is_valid, normalized_answer)
+    """
+    valid_answers = string.ascii_uppercase[:num_choices]
+    if answer and answer.upper() in valid_answers:
+        return True, answer.upper()
+
+    return False, answer or ""
+
+
+def normalize_answer(
+    answer: str | None,
+    num_choices: int,
+    question_id: str | None = None,
+    default: str = "A",
+) -> str:
+    """Normalize and validate answer with fallback to default.
+
+    Combines extraction, validation, and normalization:
+    - Validates answer is within valid range (A, B, C, D, ...)
+    - Normalizes refusal responses
+    - Falls back to default for invalid answers
+
+    Args:
+        answer: Raw answer string from model (can be None)
+        num_choices: Number of choices available
+        question_id: Optional question ID for logging warnings
+        default: Default answer if validation fails
+
+    Returns:
+        Normalized answer string
+    """
+    if answer is None:
+        if question_id:
+            print_log(
+                f" [Warning] No answer extracted for {question_id}, "
+                f"defaulting to {default}"
+            )
+        return default
+
+    is_valid, normalized = validate_answer(answer, num_choices)
+
+    if not is_valid:
+        if question_id:
+            print_log(
+                f" [Warning] Invalid answer '{answer}' for {question_id}, "
+                f"defaulting to {default}"
+            )
+        return default
+
+    return normalized
+
+
+def extract_and_normalize(
+    response: str,
+    num_choices: int,
+    question_id: str | None = None,
+    default: str = "A",
+) -> str:
+    """Extract answer from response and normalize it (convenience function).
+
+    Combines extract_answer() and normalize_answer() into a single call.
+
+    Args:
+        response: Raw LLM response text
+        num_choices: Number of valid choices
+        question_id: Optional question ID for logging
+        default: Default answer if extraction/validation fails
+
+    Returns:
+        Normalized answer string
+    """
+    extracted = extract_answer(response, num_choices=num_choices)
+    return normalize_answer(extracted, num_choices, question_id, default)
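To illustrate the extraction rules above (the sample responses are made up; behavior follows the docstrings):

from src.data_processing.answer import extract_answer, extract_and_normalize

response = "Phân tích: B sai vì ...\nLựa chọn: D\nĐáp án: C"
print(extract_answer(response, num_choices=4))  # "C" - the primary "Đáp án:" label wins

# No explicit answer line anywhere -> falls back to the default with a warning
print(extract_and_normalize("Tôi không chắc.", num_choices=4, question_id="q1"))  # "A"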
src/data_processing/formatting.py ADDED
@@ -0,0 +1,37 @@
+import string
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from src.data_processing.models import QuestionInput
+    from src.state import GraphState
+
+
+def question_to_state(q: "QuestionInput") -> "GraphState":
+    """Convert QuestionInput to GraphState for pipeline processing."""
+    state: "GraphState" = {
+        "question_id": q.qid,
+        "question": q.question,
+        "all_choices": q.choices,
+    }
+    return state
+
+
+def format_choices(choices: list[str]) -> str:
+    """Format choices for LLM prompts (A. ..., B. ..., etc.)."""
+    return "\n".join(f"{label}. {text}" for label, text in zip(string.ascii_uppercase, choices))
+
+
+def format_choices_display(choices: list[str]) -> str:
+    """Format choices for console display (2 columns)."""
+    labels = string.ascii_uppercase
+    lines = []
+    for i in range(0, len(choices), 2):
+        parts = []
+        for j in range(2):
+            idx = i + j
+            if idx < len(choices):
+                label = labels[idx] if idx < len(labels) else str(idx)
+                parts.append(f"{label}. {choices[idx]:<30}")
+        if parts:
+            lines.append(" " + " ".join(parts))
+    return "\n".join(lines)
src/data_processing/loaders.py ADDED
@@ -0,0 +1,151 @@
+"""Data loading utilities for test questions."""
+
+import csv
+import json
+from pathlib import Path
+
+from src.data_processing.models import QuestionInput
+
+# Standard column mappings for choice columns
+_CHOICE_COLUMN_MAPPINGS = {
+    "choice_a": 0, "choice_b": 1, "choice_c": 2, "choice_d": 3,
+    "option_a": 0, "option_b": 1, "option_c": 2, "option_d": 3,
+    "a": 0, "b": 1, "c": 2, "d": 3,
+}
+
+
+def load_test_data_from_json(file_path: Path) -> list[QuestionInput]:
+    """Load test questions from JSON file.
+
+    Expected format: List of dicts with qid, question, choices, answer (optional)
+
+    Args:
+        file_path: Path to JSON file
+
+    Returns:
+        List of QuestionInput objects
+
+    Raises:
+        FileNotFoundError: If file doesn't exist
+        ValueError: If file format is invalid
+    """
+    if not file_path.exists():
+        raise FileNotFoundError(f"Test data file not found: {file_path}")
+
+    if file_path.suffix.lower() != ".json":
+        raise ValueError(f"Only JSON files are supported: {file_path}")
+
+    with open(file_path, encoding="utf-8") as f:
+        data = json.load(f)
+
+    if not isinstance(data, list):
+        raise ValueError(f"JSON file must contain a list of questions: {file_path}")
+
+    questions = []
+    for item in data:
+        if "choices" not in item or not isinstance(item["choices"], list):
+            raise ValueError(f"Question {item.get('qid', 'unknown')} must have 'choices' as a list")
+
+        questions.append(QuestionInput(
+            qid=item["qid"],
+            question=item["question"],
+            choices=item["choices"],
+            answer=item.get("answer"),
+        ))
+
+    return questions
+
+
+def _normalize_row_keys(row: dict[str, str]) -> dict[str, str]:
+    """Normalize row keys to lowercase and strip whitespace."""
+    return {k.lower().strip(): v for k, v in row.items()}
+
+
+def _extract_choices_from_row(row: dict[str, str]) -> list[str]:
+    """Extract choices from a normalized CSV row.
+
+    Tries multiple strategies:
+    1. Individual choice columns (choice_a/option_a/a, etc.)
+    2. JSON array in 'choices' column
+    3. Comma/semicolon separated string in 'choices' column
+
+    Args:
+        row: Normalized row dict with lowercase keys
+
+    Returns:
+        List of choice strings (may contain empty strings)
+    """
+    # Strategy 1: Individual columns (choice_a, option_a, a, etc.)
+    choices = ["", "", "", ""]
+    found_individual = False
+
+    for col_name, idx in _CHOICE_COLUMN_MAPPINGS.items():
+        if col_name in row and row[col_name]:
+            choices[idx] = row[col_name].strip()
+            found_individual = True
+
+    if found_individual:
+        return [c for c in choices if c]
+
+    # Strategy 2 & 3: Parse 'choices' column
+    choices_raw = row.get("choices", "")
+    if not choices_raw:
+        return []
+
+    # Try JSON parse first
+    try:
+        parsed = json.loads(choices_raw)
+        if isinstance(parsed, list):
+            return [str(c).strip() for c in parsed if str(c).strip()]
+    except (json.JSONDecodeError, TypeError):
+        pass
+
+    # Fallback: split by comma or semicolon
+    return [c.strip() for c in choices_raw.replace(";", ",").split(",") if c.strip()]
+
+
+def load_test_data_from_csv(file_path: Path) -> list[QuestionInput]:
+    """Load test questions from CSV file.
+
+    Supports multiple CSV formats:
+    - Columns: qid, question, choice_a, choice_b, choice_c, choice_d
+    - Columns: qid, question, option_a, option_b, option_c, option_d
+    - Columns: qid, question, A, B, C, D
+    - Columns: qid, question, choices (JSON array or comma-separated)
+
+    Args:
+        file_path: Path to CSV file
+
+    Returns:
+        List of QuestionInput objects
+
+    Raises:
+        FileNotFoundError: If file doesn't exist
+    """
+    if not file_path.exists():
+        raise FileNotFoundError(f"Test data file not found: {file_path}")
+
+    questions = []
+    with open(file_path, encoding="utf-8") as f:
+        reader = csv.DictReader(f)
+        for row in reader:
+            norm_row = _normalize_row_keys(row)
+
+            qid = norm_row.get("qid", "").strip()
+            question = norm_row.get("question", "").strip()
+
+            if not qid or not question:
+                continue
+
+            choices = _extract_choices_from_row(norm_row)
+            if not choices:
+                choices = ["", "", "", ""]
+
+            questions.append(QuestionInput(
+                qid=qid,
+                question=question,
+                choices=choices,
+                answer=norm_row.get("answer", "").strip() or None,
+            ))
+
+    return questions
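A quick sketch of the CSV strategies (the file path is hypothetical): a row with individual choice_a..choice_d columns hits strategy 1, while a choices column holding a JSON array hits strategy 2:

from pathlib import Path
from src.data_processing.loaders import load_test_data_from_csv

# e.g. a CSV with header: qid,question,choices
#   q1,"2 + 2 = ?","[""3"", ""4"", ""5"", ""6""]"
questions = load_test_data_from_csv(Path("test_data/questions.csv"))  # hypothetical path
for q in questions:
    print(q.qid, q.choices)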
src/data_processing/models.py ADDED
@@ -0,0 +1,29 @@
+from pydantic import BaseModel, Field
+
+
+class QuestionInput(BaseModel):
+    """Input schema for a multiple-choice question."""
+
+    qid: str = Field(description="Question identifier")
+    question: str = Field(description="Question text in Vietnamese")
+    choices: list[str] = Field(description="List of answer choices")
+    answer: str | None = Field(default=None, description="Correct answer (A, B, C, ...)")
+
+
+class PredictionOutput(BaseModel):
+    """Output schema for a prediction."""
+
+    qid: str = Field(description="Question identifier")
+    answer: str = Field(description="Predicted answer: A, B, C, D, ...")
+
+
+class InferenceLogEntry(BaseModel):
+    """Schema for JSONL inference log entry (used for checkpointing)."""
+
+    qid: str = Field(description="Question identifier")
+    question: str = Field(description="Original question text")
+    choices: list[str] = Field(description="List of answer choices")
+    final_answer: str = Field(description="Final predicted answer")
+    raw_response: str = Field(default="", description="Raw LLM response")
+    route: str = Field(default="unknown", description="Pipeline route taken")
+    retrieved_context: str = Field(default="", description="Retrieved context from RAG")
src/graph.py ADDED
@@ -0,0 +1,47 @@
+"""LangGraph definition for the RAG pipeline."""
+
+from langgraph.graph import END, StateGraph
+
+from src.state import GraphState
+from src.nodes.logic import logic_solver_node
+from src.nodes.rag import knowledge_rag_node
+from src.nodes.router import route_question, router_node
+from src.nodes.direct import direct_answer_node
+
+
+def build_graph():
+    """Build and compile the LangGraph pipeline."""
+
+    workflow = StateGraph(GraphState)
+
+    workflow.add_node("router", router_node)
+    workflow.add_node("knowledge_rag", knowledge_rag_node)
+    workflow.add_node("logic_solver", logic_solver_node)
+    workflow.add_node("direct_answer", direct_answer_node)
+
+    workflow.set_entry_point("router")
+
+    workflow.add_conditional_edges(
+        "router",
+        route_question,
+        {
+            "knowledge_rag": "knowledge_rag",
+            "logic_solver": "logic_solver",
+            "direct_answer": "direct_answer",
+            "__end__": END,
+        },
+    )
+
+    workflow.add_edge("knowledge_rag", END)
+    workflow.add_edge("logic_solver", END)
+    workflow.add_edge("direct_answer", END)
+    return workflow.compile()
+
+
+graph = None
+
+
+def get_graph():
+    """Get or create the compiled graph singleton."""
+    global graph
+    if graph is None:
+        graph = build_graph()
+    return graph
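The compiled graph can be driven directly, mirroring what api.py does per request; a minimal sketch (requires the API keys and vector store the nodes expect to be configured):

import asyncio

from src.data_processing.formatting import question_to_state
from src.data_processing.models import QuestionInput
from src.graph import get_graph

async def main() -> None:
    q = QuestionInput(qid="demo", question="1 + 1 bằng mấy?", choices=["1", "2", "3", "4"])
    result = await get_graph().ainvoke(question_to_state(q))
    print(result.get("route"), result.get("answer"))

asyncio.run(main())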
src/nodes/__init__.py ADDED
@@ -0,0 +1,15 @@
+"""Node implementations for the LangGraph pipeline."""
+
+from src.nodes.direct import direct_answer_node
+from src.nodes.logic import logic_solver_node
+from src.nodes.rag import knowledge_rag_node
+from src.nodes.router import route_question, router_node
+
+__all__ = [
+    "direct_answer_node",
+    "knowledge_rag_node",
+    "logic_solver_node",
+    "route_question",
+    "router_node",
+]
+
src/nodes/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (526 Bytes).

src/nodes/__pycache__/direct.cpython-312.pyc ADDED
Binary file (2.17 kB).

src/nodes/__pycache__/logic.cpython-312.pyc ADDED
Binary file (11.8 kB).

src/nodes/__pycache__/rag.cpython-312.pyc ADDED
Binary file (6.47 kB).

src/nodes/__pycache__/router.cpython-312.pyc ADDED
Binary file (5.9 kB).
 
src/nodes/direct.py ADDED
@@ -0,0 +1,42 @@
+"""Direct Answer node for reading comprehension or general questions without RAG."""
+
+from langchain_core.prompts import ChatPromptTemplate
+
+from src.data_processing.answer import extract_answer
+from src.data_processing.formatting import format_choices
+from src.state import GraphState
+from src.utils.llm import get_large_model
+from src.utils.logging import print_log
+from src.utils.prompts import load_prompt
+
+
+def direct_answer_node(state: GraphState) -> dict:
+    """Answer questions directly using Large Model (Skip Retrieval)."""
+    print_log(" [Direct] Processing Reading Comprehension/General Question...")
+
+    all_choices = state["all_choices"]
+    choices_text = format_choices(all_choices)
+
+    llm = get_large_model()
+
+    system_prompt = load_prompt("direct_answer.j2", "system")
+    user_prompt = load_prompt("direct_answer.j2", "user", question=state["question"], choices=choices_text)
+
+    # Escape curly braces to prevent LangChain from parsing them as variables
+    system_prompt = system_prompt.replace("{", "{{").replace("}", "}}")
+    user_prompt = user_prompt.replace("{", "{{").replace("}", "}}")
+
+    prompt = ChatPromptTemplate.from_messages([
+        ("system", system_prompt),
+        ("human", user_prompt),
+    ])
+
+    chain = prompt | llm
+    response = chain.invoke({})
+
+    content = response.content.strip()
+    print_log(f" [Direct] Reasoning: {content}...")
+
+    answer = extract_answer(content, num_choices=len(all_choices) or 4)
+    print_log(f" [Direct] Final Answer: {answer}")
+    return {"answer": answer, "raw_response": content}
src/nodes/logic.py ADDED
@@ -0,0 +1,253 @@
+"""Logic solver node implementing a Manual Code Execution workflow."""
+
+import re
+import string
+
+from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
+from langchain_experimental.utilities import PythonREPL
+
+from src.data_processing.answer import extract_answer
+from src.data_processing.formatting import format_choices
+from src.state import GraphState
+from src.utils.llm import get_large_model
+from src.utils.logging import print_log
+from src.utils.prompts import load_prompt
+
+_python_repl = PythonREPL()
+
+
+def extract_python_code(text: str) -> str | None:
+    """Find and extract Python code from a ```python ... ``` block."""
+    match = re.search(r"```(?:python)?\s*(.*?)```", text, re.DOTALL | re.IGNORECASE)
+    if match:
+        return match.group(1).strip()
+    return None
+
+
+def _validate_code_syntax(code: str) -> tuple[bool, str]:
+    """Check if code has valid Python syntax. Returns (is_valid, error_message)."""
+    try:
+        compile(code, "<string>", "exec")
+        return True, ""
+    except SyntaxError as e:
+        return False, str(e)
+
+
+def _is_placeholder_code(code: str) -> bool:
+    """Check if code contains placeholders or is incomplete."""
+    if not code or len(code.strip()) < 10:
+        return True
+    if "..." in code:
+        return True
+    # Check for {key}-style placeholders (but not f-string or dict literals)
+    if re.search(r"\{[a-zA-Z_][a-zA-Z0-9_]*\}", code):
+        # Exclude common dict/set patterns and f-strings
+        if not re.search(r'["\'][^"\']*\{[a-zA-Z_]', code):
+            return True
+    return False
+
+
+def _indent_code(code: str) -> str:
+    """Format code to make it easier to read in the terminal."""
+    return "\n".join(f"    {line}" for line in code.splitlines())
+
+
+def _fallback_text_reasoning(llm, question: str, choices_text: str) -> dict:
+    """Fallback to CoT reasoning when code execution fails."""
+    print_log(" [Logic] Falling back to CoT reasoning...")
+
+    fallback_system = (
+        "Nhiệm vụ của bạn là trả lời câu hỏi "
+        "được đưa ra bằng khả năng phân tích và suy luận logic. "
+        "Hãy phân tích vấn đề và suy luận từng bước một. "
+        "Cuối cùng, hãy trả lời theo đúng định dạng: 'Đáp án: X' "
+        "trong đó X là ký tự đại diện cho lựa chọn đúng (A, B, C, D, ...)."
+    )
+
+    fallback_user = (
+        f"Câu hỏi: {question}\n"
+        f"{choices_text}"
+    )
+
+    fallback_messages: list[BaseMessage] = [
+        SystemMessage(content=fallback_system),
+        HumanMessage(content=fallback_user)
+    ]
+
+    fallback_response = llm.invoke(fallback_messages)
+    fallback_content = fallback_response.content
+    print_log(" [Logic] Fallback response received.")
+
+    return {"text": fallback_content}
+
+
+def _request_final_answer(llm, question: str, choices_text: str, computed_results: str) -> str:
+    """Request a strict final answer from the model."""
+    system_prompt = (
+        "Bạn là trợ lý AI. Dựa vào kết quả tính toán được cung cấp, "
+        "hãy đưa ra đáp án cuối cùng. CHỈ trả lời đúng một dòng: Đáp án: X "
+        "(trong đó X là A, B, C hoặc D)."
+    )
+    user_prompt = (
+        f"Câu hỏi: {question}\n"
+        f"{choices_text}\n"
+        f"Kết quả tính toán: {computed_results}\n\n"
+        "Trả lời đúng một dòng: Đáp án: X"
+    )
+
+    messages: list[BaseMessage] = [
+        SystemMessage(content=system_prompt),
+        HumanMessage(content=user_prompt)
+    ]
+
+    response = llm.invoke(messages)
+    return response.content
+
+
+def logic_solver_node(state: GraphState) -> dict:
+    """Solve math/logic questions using Python code execution."""
+    llm = get_large_model()
+    all_choices = state["all_choices"]
+    num_choices = len(all_choices)
+    choices_text = format_choices(all_choices)
+
+    system_prompt = load_prompt("logic_solver.j2", "system")
+    user_prompt = load_prompt("logic_solver.j2", "user", question=state["question"], choices=choices_text)
+
+    messages: list[BaseMessage] = [
+        SystemMessage(content=system_prompt),
+        HumanMessage(content=user_prompt)
+    ]
+
+    step_texts: list[str] = []
+    computed_outputs: list[str] = []
+
+    max_steps = 5
+    for step in range(max_steps):
+        response = llm.invoke(messages)
+        content = response.content
+        step_texts.append(content)
+        messages.append(response)
+
+        code_block = extract_python_code(content)
+
+        if code_block:
+            if _is_placeholder_code(code_block):
+                print_log(f" [Logic] Step {step+1}: Placeholder code detected. Requesting complete code...")
+                regen_msg = (
+                    "Code không hợp lệ (chứa placeholder hoặc không đầy đủ). "
+                    "Hãy cung cấp code Python hoàn chỉnh, có thể chạy được, không chứa '...' hay placeholder. "
+                    "In ra các giá trị tính toán được. "
+                    "Cuối cùng, kết thúc bằng một dòng duy nhất: Đáp án: X (X là A, B, C hoặc D)."
+                )
+                messages.append(HumanMessage(content=regen_msg))
+                continue
+
+            print_log(f" [Logic] Step {step+1}: Found Python code. Executing...")
+
+            # Validate syntax before execution
+            is_valid, syntax_error = _validate_code_syntax(code_block)
+            if not is_valid:
+                print_log(f" [Error] Syntax error detected: {syntax_error}")
+                error_msg = f"SyntaxError: {syntax_error}. "
+                error_msg += "Lưu ý: KHÔNG sử dụng các từ khóa Python như 'lambda', 'class', 'def' làm tên biến. "
+                error_msg += "Hãy đổi tên biến và thử lại."
+                messages.append(HumanMessage(content=error_msg))
+                continue
+
+            print_log(f" [Logic] Code:\n{_indent_code(code_block)}")
+
+            try:
+                if "print" not in code_block:
+                    lines = code_block.splitlines()
+                    if lines:
+                        last_line = lines[-1]
+                        if "=" in last_line:
+                            var_name = last_line.split("=")[0].strip()
+                        else:
+                            var_name = last_line.strip()
+                        code_block += f"\nprint({var_name})"
+
+                output = _python_repl.run(code_block)
+                output = output.strip() if output else "No output."
+                print_log(f" [Logic] Code output: {output}")
+                computed_outputs.append(output)
+
+                # Do NOT extract answer from code output directly
+                # Instead, feed output back to model and ask for final answer line
+                feedback_msg = (
+                    f"Kết quả thực thi code: {output}\n\n"
+                    "Dựa vào kết quả trên, hãy so sánh với các đáp án và đưa ra câu trả lời cuối cùng. "
+                    "Kết thúc bằng đúng một dòng: Đáp án: X (X là A, B, C hoặc D)."
+                )
+                messages.append(HumanMessage(content=feedback_msg))
+
+            except Exception as e:
+                error_msg = f"Error running code: {str(e)}"
+                print_log(f" [Error] {error_msg}")
+                messages.append(HumanMessage(content=f"{error_msg}. Hãy kiểm tra logic và sửa lại code."))
+
+            continue
+
+        # Check if current step contains an explicit answer (only at end of response)
+        step_answer = extract_answer(content, num_choices=num_choices, require_end=True)
+        if step_answer:
+            print_log(f" [Logic] Step {step+1}: Found explicit answer: {step_answer}")
+            combined_raw = "\n---STEP---\n".join(step_texts)
+            return {"answer": step_answer, "raw_response": combined_raw, "route": "math"}
+
+        # Also check if response contains clear conclusion without "Đáp án:" format
+        if any(phrase in content.lower() for phrase in ["kết luận", "vậy đáp án", "do đó", "vì vậy"]):
+            # Try to extract any single letter at end of response
+            lines = content.strip().split('\n')
+            for line in reversed(lines[-3:]):  # Check last 3 lines
+                line = line.strip()
+                if len(line) == 1 and line.upper() in string.ascii_uppercase[:num_choices]:
+                    print_log(f" [Logic] Step {step+1}: Found implicit answer: {line.upper()}")
+                    combined_raw = "\n---STEP---\n".join(step_texts)
+                    return {"answer": line.upper(), "raw_response": combined_raw, "route": "math"}
+
+        if step < max_steps - 1:
+            print_log(" [Warning] No code or answer found. Reminding model...")
+            messages.append(HumanMessage(content="Lưu ý: Bạn vẫn chưa đưa ra đáp án cuối cùng. Hãy kết thúc bằng: Đáp án: X"))
+
+    # Max steps reached - build combined_raw and try to extract answer
+    print_log(" [Warning] Max steps reached. Attempting answer extraction from combined text...")
+
+    # Build combined_raw from all steps
+    combined_raw = "\n---STEP---\n".join(step_texts) if step_texts else ""
+
+    # Try fallback text reasoning with error handling
+    try:
+        fallback_result = _fallback_text_reasoning(llm, state["question"], choices_text)
+        fallback_text = fallback_result["text"]
+        if fallback_text:
+            combined_raw += "\n---FALLBACK---\n" + fallback_text
+    except Exception as e:
+        print_log(f" [Error] Fallback reasoning failed: {e}")
+        fallback_text = ""
+
+    # Extract answer from the entire combined text (takes LAST explicit answer)
+    final_answer = extract_answer(combined_raw, num_choices=num_choices)
+
+    if final_answer:
+        print_log(f" [Logic] Extracted final answer from combined text: {final_answer}")
+        return {"answer": final_answer, "raw_response": combined_raw, "route": "math"}
+
+    # Still no answer - do one final strict LLM call with error handling
+    print_log(" [Logic] No explicit answer found. Requesting strict final answer...")
+    computed_str = "; ".join(computed_outputs) if computed_outputs else "Không có kết quả tính toán"
+    try:
+        strict_response = _request_final_answer(llm, state["question"], choices_text, computed_str)
+        combined_raw += "\n---FINAL---\n" + strict_response
+
+        final_answer = extract_answer(strict_response, num_choices=num_choices)
+        if final_answer:
+            print_log(f" [Logic] Final strict answer: {final_answer}")
+            return {"answer": final_answer, "raw_response": combined_raw, "route": "math"}
+    except Exception as e:
+        print_log(f" [Error] Final answer request failed: {e}")
+
+    # Absolute fallback - default to A
+    print_log(" [Warning] All extraction attempts failed. Defaulting to A.")
+    return {"answer": "A", "raw_response": combined_raw, "route": "math"}
src/nodes/rag.py ADDED
@@ -0,0 +1,141 @@
+"""RAG node for knowledge-based question answering with Retrieve & Rerank."""
+
+import re
+
+from langchain_core.prompts import ChatPromptTemplate
+
+from src.config import settings
+from src.data_processing.answer import extract_answer
+from src.data_processing.formatting import format_choices
+from src.state import GraphState
+from src.utils.ingestion import get_vector_store
+from src.utils.llm import get_small_model
+from src.utils.logging import print_log
+from src.utils.prompts import load_prompt
+from src.nodes.direct import direct_answer_node
+
+
+def _rerank_documents(query: str, docs: list, top_k: int = 3) -> list:
+    """Rerank retrieved documents using the small LLM.
+
+    Args:
+        query: The user question
+        docs: List of retrieved documents
+        top_k: Number of top documents to return after reranking
+
+    Returns:
+        List of reranked documents (top_k most relevant)
+    """
+    if len(docs) <= top_k:
+        return docs
+
+    llm = get_small_model()
+
+    # Build document list for reranking prompt
+    doc_list = ""
+    for i, doc in enumerate(docs):
+        content_preview = doc.page_content[:350].replace("\n", " ")
+        doc_list += f"[{i}] {content_preview}...\n\n"
+
+    rerank_system = (
+        "/no_think\n"
+        "Bạn là chuyên gia đánh giá độ liên quan của văn bản. "
+        "Nhiệm vụ: Chọn ra các đoạn văn bản LIÊN QUAN NHẤT với câu hỏi.\n"
+        "Chỉ trả về danh sách các số ID (ví dụ: 0, 3, 5), không giải thích."
+    )
+
+    rerank_user = (
+        f"Câu hỏi: {query}\n\n"
+        f"Các đoạn văn bản:\n{doc_list}\n"
+        f"Hãy chọn {top_k} đoạn văn bản LIÊN QUAN NHẤT với câu hỏi. "
+        f"Trả về danh sách ID (số từ 0 đến {len(docs)-1}), cách nhau bởi dấu phẩy."
+    )
+
+    prompt = ChatPromptTemplate.from_messages([
+        ("system", rerank_system),
+        ("human", rerank_user),
+    ])
+
+    try:
+        chain = prompt | llm
+        response = chain.invoke({})
+        content = response.content.strip()
+        print_log(f" [RAG] Reranker response: {content}")
+
+        # Parse selected IDs from response
+        selected_ids = []
+        numbers = re.findall(r'\d+', content)
+        for num_str in numbers:
+            idx = int(num_str)
+            if 0 <= idx < len(docs) and idx not in selected_ids:
+                selected_ids.append(idx)
+            if len(selected_ids) >= top_k:
+                break
+
+        if selected_ids:
+            reranked = [docs[i] for i in selected_ids]
+            print_log(f" [RAG] Reranked: selected {len(reranked)} docs from {len(docs)}")
+            return reranked
+
+        print_log(" [RAG] Rerank parsing failed, using first top_k docs")
+        return docs[:top_k]
+
+    except Exception as e:
+        print_log(f" [RAG] Reranking failed: {e}. Falling back to first top_k docs.")
+        return docs[:top_k]
+
+
+def knowledge_rag_node(state: GraphState) -> dict:
+    """Retrieve relevant context, rerank, and answer knowledge-based questions."""
+    vector_store = get_vector_store()
+    query = state["question"]
+    print_log(f" [RAG] Retrieving context for: '{query}'")
+
+    docs = vector_store.similarity_search(query, k=settings.top_k_retrieval)
+    print_log(f" [RAG] Retrieved {len(docs)} documents")
+
+    if not docs:
+        print_log(" [Warning] No relevant documents found in Knowledge Base.")
+        context = ""
+    else:
+        reranked_docs = _rerank_documents(query, docs, top_k=settings.top_k_rerank)
+
+        context = "\n\n---\n\n".join([doc.page_content for doc in reranked_docs])
+
+        if reranked_docs:
+            print_log(f" [RAG] Using {len(reranked_docs)} reranked docs. Top: \"{reranked_docs[0].page_content[:80]}...\"")
+
+    all_choices = state["all_choices"]
+    choices_text = format_choices(all_choices)
+
+    llm = get_small_model()
+
+    system_prompt = load_prompt("rag.j2", "system", context=context)
+    user_prompt = load_prompt("rag.j2", "user", question=state["question"], choices=choices_text)
+
+    # Escape curly braces to prevent LangChain from parsing them as variables
+    system_prompt = system_prompt.replace("{", "{{").replace("}", "}}")
+    user_prompt = user_prompt.replace("{", "{{").replace("}", "}}")
+
+    prompt = ChatPromptTemplate.from_messages([
+        ("system", system_prompt),
+        ("human", user_prompt),
+    ])
+
+    chain = prompt | llm
+    response = chain.invoke({})
+    content = response.content.strip()
+    print_log(f" [RAG] Reasoning: {content}")
+
+    answer = extract_answer(content, num_choices=len(all_choices) or 4)
+    print_log(f" [RAG] Final Answer: {answer}")
+
+    # Fallback to direct mode if RAG context was not helpful
+    if answer is None:
+        print_log(" [RAG] Context not relevant, falling back to direct mode...")
+        direct_result = direct_answer_node(state)
+        direct_result["route"] = "rag->direct"  # Track the fallback
+        return direct_result
+
+    return {"answer": answer, "context": context, "raw_response": content}
src/nodes/router.py ADDED
@@ -0,0 +1,112 @@
+"""Router node for classifying questions and directing to appropriate handlers."""
+
+import string
+from typing import Literal
+
+from langchain_core.prompts import ChatPromptTemplate
+
+from src.data_processing.formatting import format_choices
+from src.state import GraphState
+from src.utils.llm import get_small_model
+from src.utils.logging import print_log
+from src.utils.prompts import load_prompt
+
+
+def _find_refusal_option(state: GraphState) -> str | None:
+    """Find refusal option in choices and return corresponding letter."""
+    all_choices = state["all_choices"]
+    option_labels = list(string.ascii_uppercase[:len(all_choices)])
+
+    refusal_patterns = [
+        "tôi không thể", "không thể trả lời", "không thể cung cấp", "không thể chia sẻ",
+        "từ chối trả lời", "từ chối cung cấp",
+        "nằm ngoài phạm vi", "không thuộc phạm vi", "tôi là mô hình ngôn ngữ",
+        "hành vi vi phạm", "trái pháp luật", "không hỗ trợ",
+    ]
+
+    for i, choice in enumerate(all_choices):
+        txt = choice.lower().strip()
+        if any(p in txt for p in refusal_patterns):
+            return option_labels[i]
+
+    return None
+
+
+def _classify_with_llm(state: GraphState) -> str:
+    """Classify question using LLM."""
+    choices_text = format_choices(state["all_choices"])
+    llm = get_small_model()
+
+    system_prompt = load_prompt("router.j2", "system")
+    user_prompt = load_prompt("router.j2", "user", question=state["question"], choices=choices_text)
+
+    # Escape curly braces to prevent LangChain from parsing them as variables
+    system_prompt = system_prompt.replace("{", "{{").replace("}", "}}")
+    user_prompt = user_prompt.replace("{", "{{").replace("}", "}}")
+
+    prompt = ChatPromptTemplate.from_messages([
+        ("system", system_prompt),
+        ("human", user_prompt),
+    ])
+    chain = prompt | llm
+    response = chain.invoke({})
+    return response.content.strip().lower()
+
+
+def router_node(state: GraphState) -> dict:
+    """Analyze question and determine routing path. Returns answer immediately for toxic content."""
+    question = state["question"].lower()
+
+    # Fast-track: Direct answer for reading comprehension
+    direct_keywords = ["đoạn thông tin", "đoạn văn", "bài đọc", "căn cứ vào đoạn", "theo đoạn"]
+    if any(k in question for k in direct_keywords) and len(question.split()) > 50:
+        print_log(" [Router] Fast-track: Direct Answer (Found Context block)")
+        return {"route": "direct"}
+
+    # Fast-track: Math/Logic for LaTeX or math keywords
+    math_signals = [
+        "$", "\\frac", "^", "=", "tính giá trị", "biểu thức", "phương trình",
+        "hàm số", "đạo hàm", "xác suất", "lãi suất", "vận tốc", "gia tốc",
+        "điện trở", "gam", "mol", "nguyên tử khối", "gdp", "lạm phát", "công suất"
+    ]
+    if any(s in question for s in math_signals):
+        print_log(" [Router] Fast-track: Math (Keywords/LaTeX detected)")
+        return {"route": "math"}
+
+    print_log(" [Router] Slow-track: Using LLM to classify...")
+    try:
+        route = _classify_with_llm(state)
+        print_log(f" [Router] LLM Decision: {route}")
+
+        if "direct" in route:
+            route_type = "direct"
+        elif "math" in route or "logic" in route:
+            route_type = "math"
+        elif "toxic" in route:
+            refusal_answer = _find_refusal_option(state)
+            if refusal_answer:
+                print_log(f" [Router] Toxic detected, found refusal option: {refusal_answer}")
+                return {"route": "toxic", "answer": refusal_answer}
+            print_log(" [Router] Toxic detected, no refusal option found, defaulting to A")
+            return {"route": "toxic", "answer": "A"}
+        else:
+            route_type = "rag"
+
+        return {"route": route_type}
+    except Exception as e:
+        print_log(f" [Router] Error: {e}. Fallback to RAG.")
+        return {"route": "rag"}
+
+
+def route_question(state: GraphState) -> Literal["knowledge_rag", "logic_solver", "direct_answer", "__end__"]:
+    """Conditional edge function to route to appropriate node based on state route."""
+    route = state.get("route", "rag")
+    answer = state.get("answer")
+
+    if route == "toxic":
+        return "__end__"
+    if route == "direct":
+        return "direct_answer"
+    if route == "math":
+        return "logic_solver"
+    return "knowledge_rag"
src/pipeline.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ """Core pipeline execution logic for the RAG system."""
+
+ import asyncio
+ import csv
+ import sys
+ import time
+ from pathlib import Path
+
+ from src.config import BATCH_SIZE, DATA_OUTPUT_DIR
+ from src.data_processing.answer import normalize_answer
+ from src.data_processing.formatting import format_choices_display, question_to_state
+ from src.data_processing.models import InferenceLogEntry, PredictionOutput, QuestionInput
+ from src.graph import get_graph
+ from src.utils.checkpointing import (
+     append_log_entry,
+     consolidate_log_file,
+     generate_csv_from_log,
+     is_rate_limit_error,
+ )
+ from src.utils.common import sort_qids
+ from src.utils.ingestion import get_vector_store
+ from src.utils.logging import log_done, log_pipeline, log_stats, print_log
+
+
+ def sort_questions_by_qid(questions: list[QuestionInput]) -> list[QuestionInput]:
+     """Sort questions by qid using natural sorting."""
+     qid_to_question = {q.qid: q for q in questions}
+     sorted_qids = sort_qids(list(qid_to_question.keys()))
+     return [qid_to_question[qid] for qid in sorted_qids]
+
+
+ async def run_pipeline_async(
+     questions: list[QuestionInput],
+     batch_size: int = BATCH_SIZE,
+ ) -> list[PredictionOutput]:
+     """Run the inference pipeline (assumes a pre-built vector DB).
+
+     Args:
+         questions: List of questions to process.
+         batch_size: Number of questions to process concurrently.
+
+     Returns:
+         List of PredictionOutput objects sorted by qid.
+     """
+     log_pipeline("Loading pre-built vector store...")
+     get_vector_store()
+
+     questions = sort_questions_by_qid(questions)
+
+     graph = get_graph()
+     total = len(questions)
+     start_time = time.perf_counter()
+
+     # Bound concurrency: at most batch_size questions are in flight at once.
+     sem = asyncio.Semaphore(batch_size)
+     results: dict[str, PredictionOutput] = {}
+
+     async def process_single_question(q: QuestionInput) -> None:
+         async with sem:
+             print_log(f"\n[{q.qid}] {q.question}")
+             print_log(format_choices_display(q.choices))
+             state = question_to_state(q)
+             result = await graph.ainvoke(state)
+
+             answer = result.get("answer", "A")
+             route = result.get("route", "unknown")
+             num_choices = len(q.choices)
+
+             normalized_answer = normalize_answer(
+                 answer=answer,
+                 num_choices=num_choices,
+                 question_id=q.qid,
+                 default="A",
+             )
+
+             log_done(f"{q.qid}: {normalized_answer} (Route: {route})")
+             results[q.qid] = PredictionOutput(qid=q.qid, answer=normalized_answer)
+
+     tasks = [process_single_question(q) for q in questions]
+     await asyncio.gather(*tasks)
+
+     elapsed = time.perf_counter() - start_time
+     throughput = total / elapsed if elapsed > 0 else 0
+     log_stats(f"Completed {total} questions in {elapsed:.2f}s ({throughput:.2f} req/s)")
+
+     sorted_qids = sort_qids(list(results.keys()))
+     return [results[qid] for qid in sorted_qids]
+
+
+ async def run_pipeline_with_checkpointing(
+     questions: list[QuestionInput],
+     log_path: Path,
+     batch_size: int = BATCH_SIZE,
+ ) -> int:
+     """Run the pipeline with JSONL checkpointing for resume capability.
+
+     Questions are processed in qid order. Results are appended to the log
+     file immediately for fault tolerance, then consolidated at the end.
+
+     Args:
+         questions: Questions to process (already filtered to unprocessed ones).
+         log_path: Path to the JSONL log file used for checkpointing.
+         batch_size: Number of questions to process concurrently.
+
+     Returns:
+         Count of newly processed questions.
+     """
+     log_pipeline("Loading pre-built vector store...")
+     get_vector_store()
+
+     questions = sort_questions_by_qid(questions)
+     log_pipeline(f"Processing {len(questions)} questions in qid order...")
+
+     graph = get_graph()
+     total = len(questions)
+     start_time = time.perf_counter()
+     processed_count = 0
+
+     sem = asyncio.Semaphore(batch_size)
+     # Set once a rate limit is detected so in-flight workers exit early.
+     stop_event = asyncio.Event()
+
+     async def process_single_question(q: QuestionInput) -> None:
+         nonlocal processed_count
+         if stop_event.is_set():
+             return
+
+         async with sem:
+             if stop_event.is_set():
+                 return
+             print_log(f"\n[{q.qid}] {q.question}")
+             print_log(format_choices_display(q.choices))
+             state = question_to_state(q)
+
+             try:
+                 result = await graph.ainvoke(state)
+                 route = result.get("route", "unknown")
+                 raw_response = result.get("raw_response", "")
+                 context = result.get("context", "")
+
+                 answer = normalize_answer(
+                     answer=result.get("answer"),
+                     num_choices=len(q.choices),
+                     question_id=q.qid,
+                     default="A",
+                 )
+
+                 log_entry = InferenceLogEntry(
+                     qid=q.qid,
+                     question=q.question,
+                     choices=q.choices,
+                     final_answer=answer,
+                     raw_response=raw_response,
+                     route=route,
+                     retrieved_context=context,
+                 )
+                 await append_log_entry(log_path, log_entry)
+
+                 log_done(f"{q.qid}: {answer} (Route: {route})")
+                 processed_count += 1
+
+             except Exception as e:
+                 if is_rate_limit_error(e):
+                     print_log(f" [CRITICAL] Rate Limit Detected on {q.qid}: {e}")
+                     stop_event.set()
+                 else:
+                     print_log(f" [Error] Failed to process {q.qid}: {e}")
+
+     tasks = [asyncio.create_task(process_single_question(q)) for q in questions]
+     await asyncio.gather(*tasks)
+
+     if stop_event.is_set():
+         log_pipeline("!!! PIPELINE STOPPED DUE TO RATE LIMIT !!!")
+         log_pipeline("Consolidating logs and generating emergency submission...")
+         consolidate_log_file(log_path)
+
+         output_file = DATA_OUTPUT_DIR / "submission_emergency.csv"
+         total_entries = generate_csv_from_log(log_path, output_file)
+         log_pipeline(f"Saved emergency submission with {total_entries} entries to: {output_file}")
+
+         sys.exit(0)
+
+     log_pipeline("Consolidating log file...")
+     consolidate_log_file(log_path)
+
+     elapsed = time.perf_counter() - start_time
+     throughput = total / elapsed if elapsed > 0 else 0
+     log_stats(f"Processed {processed_count}/{total} questions in {elapsed:.2f}s ({throughput:.2f} req/s)")
+
+     return processed_count
+
+
+ def save_predictions(
+     predictions: list[PredictionOutput],
+     output_path: Path,
+     ensure_dir: bool = True,
+ ) -> None:
+     """Save predictions to a CSV file, sorted by qid.
+
+     Args:
+         predictions: List of prediction outputs.
+         output_path: Path to the output CSV file.
+         ensure_dir: If True, create the parent directory if it doesn't exist.
+     """
+     if ensure_dir:
+         output_path.parent.mkdir(parents=True, exist_ok=True)
+
+     sorted_qids = sort_qids([p.qid for p in predictions])
+     pred_dict = {p.qid: p for p in predictions}
+
+     with open(output_path, "w", newline="", encoding="utf-8") as f:
+         writer = csv.DictWriter(f, fieldnames=["qid", "answer"])
+         writer.writeheader()
+         for qid in sorted_qids:
+             writer.writerow({"qid": qid, "answer": pred_dict[qid].answer})
+     log_pipeline(f"Predictions saved to: {output_path}")
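For orientation, here is a minimal driver sketch for the two public entry points above. It assumes `QuestionInput` accepts `qid`/`question`/`choices` as keyword arguments (only attribute access appears in this diff) and that the vector store has already been ingested; the output path is illustrative.

```python
# Hedged sketch: drive run_pipeline_async end to end.
import asyncio
from pathlib import Path

from src.data_processing.models import QuestionInput
from src.pipeline import run_pipeline_async, save_predictions

questions = [
    QuestionInput(  # assumed keyword constructor; fields match usage in pipeline.py
        qid="q001",
        question="Thủ đô của Việt Nam là thành phố nào?",
        choices=["Hà Nội", "Đà Nẵng", "Huế", "Cần Thơ"],
    ),
]

# Process up to 4 questions concurrently, then write the submission CSV.
predictions = asyncio.run(run_pipeline_async(questions, batch_size=4))
save_predictions(predictions, Path("data/output/submission.csv"))  # illustrative path
```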
src/state.py ADDED
@@ -0,0 +1,16 @@
+ """State schema definitions for the RAG pipeline graph."""
+
+ from typing import TypedDict
+
+
+ class GraphState(TypedDict, total=False):
+     """State schema for the RAG pipeline graph."""
+
+     question_id: str
+     question: str
+     all_choices: list[str]
+     route: str
+     context: str
+     answer: str
+     raw_response: str
+
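Because `GraphState` is declared with `total=False`, every key is optional, so each graph node can return only the fields it actually produces. A small sketch:

```python
from src.state import GraphState

# Initial state built from a question (all keys optional under total=False).
initial: GraphState = {
    "question_id": "q001",
    "question": "2 + 3 * 4 = ?",
    "all_choices": ["A. 12", "B. 14", "C. 20", "D. 24"],
}

# A node may return a partial update; e.g. a router could contribute only:
router_update: GraphState = {"route": "math"}
```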
src/templates/direct_answer.j2 ADDED
@@ -0,0 +1,19 @@
+ {# Direct Answer Node Prompt Templates #}
+ {% block system %}
+ /no_think
+ Bạn là một chuyên gia trả lời câu hỏi trắc nghiệm. Nhiệm vụ của bạn là phân tích và chọn đáp án đúng nhất cho câu hỏi.
+
+ NGÔN NGỮ: Toàn bộ suy luận, giải thích PHẢI bằng TIẾNG VIỆT 100%. KHÔNG dùng tiếng Anh.
+
+ Lưu ý:
+ 1. Nếu đề bài có đoạn văn, CHỈ dựa vào đoạn văn đó để suy luận.
+ 2. Suy luận từng bước logic.
+    - Với câu hỏi về ngày tháng, con số: So sánh chính xác từng ký tự.
+    - Nếu câu hỏi yêu cầu tìm từ sai/đúng: Đối chiếu từng phương án với văn bản.
+ 3. Trả lời bằng: "Đáp án: X" (X là một trong các lựa chọn A, B, C, D, ...).
+ {% endblock %}
+
+ {% block user %}
+ Câu hỏi: {{ question }}
+ {{ choices }}
+ {% endblock %}
src/templates/logic_solver.j2 ADDED
@@ -0,0 +1,37 @@
+ {# Logic Solver (Code Agent) Prompt Templates #}
+ {% block system %}
+ /no_think
+ Bạn là chuyên gia giải toán và logic. Trả lời NGẮN GỌN, SÚC TÍCH.
+
+ NGÔN NGỮ: Toàn bộ suy luận, giải thích PHẢI bằng TIẾNG VIỆT 100%. KHÔNG dùng tiếng Anh.
+
+ QUY TẮC:
+ 1. Suy luận ngắn gọn, đi thẳng vào vấn đề
+ 2. Chỉ nêu các bước quan trọng, bỏ qua chi tiết thừa
+ 3. Dòng "Đáp án: X" PHẢI là dòng CUỐI CÙNG
+ 4. Tối đa 5-7 dòng suy luận
+
+ CẤU TRÚC:
+ 1. Phân tích ngắn gọn
+ 2. Suy luận chính
+ 3. Kết luận
+ 4. Đáp án: X
+
+ VÍ DỤ TỐT (NGẮN GỌN):
+ ```
+ Phân tích: Tính 2 + 3 * 4
+ Thứ tự: 3 * 4 = 12, sau đó 2 + 12 = 14
+ Kết luận: 14 tương ứng đáp án B
+
+ Đáp án: B
+ ```
+
+ NHẮC LẠI: NGẮN GỌN, SÚC TÍCH! Chỉ 5-7 dòng! TIẾNG VIỆT 100%!
+ {% endblock %}
+
+ {% block user %}
+ {{ question }}
+ {{ choices }}
+
+ Suy luận ngắn gọn:
+ {% endblock %}
src/templates/rag.j2 ADDED
@@ -0,0 +1,25 @@
+ {# RAG Node Prompt Templates #}
+ {% block system %}
+ /no_think
+ Bạn là một chuyên gia phân tích thông tin và đọc hiểu văn bản chính xác tuyệt đối.
+ Nhiệm vụ: Trả lời câu hỏi trắc nghiệm CHỈ dựa trên thông tin trong phần Văn bản được cung cấp bên dưới.
+
+ NGÔN NGỮ: Toàn bộ suy luận, giải thích PHẢI bằng TIẾNG VIỆT 100%. KHÔNG dùng tiếng Anh.
+
+ Văn bản:
+ {{ context }}
+
+ Quy tắc bắt buộc:
+ 1. Đọc kỹ văn bản để tìm các từ khóa liên quan đến câu hỏi.
+ 2. So sánh từng lựa chọn với thông tin tìm được trong văn bản.
+ 3. Suy luận từng bước:
+    - Nếu văn bản chứa câu trả lời trực tiếp: Trích dẫn ý đó để xác nhận.
+    - Nếu văn bản KHÔNG chứa câu trả lời trực tiếp: Sử dụng phương pháp loại trừ các đáp án sai để chọn đáp án phù hợp và
+      đúng nhất.
+ 4. Trả lời cuối cùng theo định dạng: "Đáp án: X" (trong đó X là ký tự lựa chọn). Ví dụ: "Đáp án: A"
+ {% endblock %}
+
+ {% block user %}
+ Câu hỏi: {{ question }}
+ {{ choices }}
+ {% endblock %}
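All three answer-producing templates above require the model to finish with a literal `Đáp án: X` line. The project's real parsing lives in `normalize_answer`, which is not shown in this part of the diff; a hypothetical extractor for that output contract could look like:

```python
import re

def extract_choice(raw_response: str) -> str | None:
    """Return the letter from the last 'Đáp án: X' occurrence, if any."""
    matches = re.findall(r"Đáp án:\s*([A-Z])", raw_response)
    return matches[-1] if matches else None

# Mirrors the worked example in logic_solver.j2:
assert extract_choice("Kết luận: 14 tương ứng đáp án B\n\nĐáp án: B") == "B"
```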
src/templates/router.j2 ADDED
@@ -0,0 +1,43 @@
+ {# Router Node Prompt Templates #}
+ {% block system %}
+ /no_think
+ Bạn là hệ thống phân loại câu hỏi tiếng Việt. LUÔN trả lời bằng TIẾNG VIỆT.
+
+ Nhiệm vụ: Phân loại câu hỏi vào duy nhất 1 trong 4 nhóm: "toxic", "direct", "math", hoặc "rag".
+
+ Hãy thực hiện theo quy trình kiểm tra thứ tự ưu tiên sau đây (QUAN TRỌNG):
+
+ Ưu tiên 1: Kiểm tra "toxic" (An toàn là trên hết)
+ - Nếu câu hỏi yêu cầu hướng dẫn thực hiện hành vi vi phạm pháp luật (trốn thuế, làm giả giấy tờ, tham nhũng, buôn lậu,
+   chế tạo vũ khí...).
+ - Nếu câu hỏi mang tính chất phản động, chống phá nhà nước, bôi nhọ lãnh tụ, hoặc vi phạm thuần phong mỹ tục.
+ -> Trả về: toxic
+
+ Ưu tiên 2: Kiểm tra "direct" (Đọc hiểu văn bản có sẵn)
+ - Hãy nhìn vào dữ liệu đầu vào. Nếu nó chứa các từ khóa đánh dấu văn bản như: "Đoạn thông tin:", "Văn bản:", "Document",
+   "Title:", "Nội dung:", hoặc một đoạn văn dài đi kèm trước câu hỏi.
+ - Bất kể nội dung là Lịch sử hay Khoa học, nếu ĐÃ CÓ đoạn văn bản đi kèm để trả lời -> Phải chọn nhóm này.
+ -> Trả về: direct
+
+ Ưu tiên 3: Kiểm tra "math" (Tư duy logic & Tính toán & Lập trình)
+ - Các bài tập Toán, Lý, Hóa, Sinh yêu cầu tính toán ra số liệu cụ thể (không phải lý thuyết suông).
+ - Các câu hỏi chứa công thức toán học (LaTeX, dấu $, phương trình).
+ - Các câu hỏi về Lập trình.
+ - Các câu hỏi tư duy logic, chuỗi số, xác suất thống kê.
+ -> Trả về: math
+
+ Ưu tiên 4: Kiểm tra "rag" (Tra cứu kiến thức)
+ - Các câu hỏi kiến thức về Lịch sử, Địa lý, Luật pháp, Văn hóa, Xã hội.
+ - Các câu hỏi lý thuyết khoa học (không cần tính toán).
+ - Câu hỏi mà KHÔNG CÓ đoạn văn bản đi kèm.
+ -> Trả về: rag
+
+ QUAN TRỌNG: Chỉ trả về đúng 1 từ kết quả (toxic/direct/math/rag). Không giải thích thêm.
+ {% endblock %}
+
+ {% block user %}
+ Câu hỏi: {{ question }}
+ {{ choices }}
+
+ Kết quả phân loại (chỉ 1 từ):
+ {% endblock %}
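Each template splits its prompt into `{% block system %}` and `{% block user %}`. The loader the nodes actually use is not part of this diff; one plausible way to render the two blocks separately with stock Jinja2 is:

```python
# Hedged sketch: render a single named block from a template.
from jinja2 import Environment, FileSystemLoader

env = Environment(loader=FileSystemLoader("src/templates"))

def render_block(template_name: str, block_name: str, **ctx) -> str:
    """template.blocks maps block names to generator functions that
    take a render context and yield string chunks."""
    template = env.get_template(template_name)
    block_fn = template.blocks[block_name]
    return "".join(block_fn(template.new_context(ctx))).strip()

system_prompt = render_block("router.j2", "system")
user_prompt = render_block(
    "router.j2", "user",
    question="1 + 1 bằng mấy?",
    choices="A. 1\nB. 2\nC. 3\nD. 4",
)
```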
src/utils/__init__.py ADDED
@@ -0,0 +1,47 @@
+ """Utility functions for the RAG pipeline."""
+
+ from src.utils.checkpointing import (
+     append_log_entry,
+     consolidate_log_file,
+     generate_csv_from_log,
+     is_rate_limit_error,
+     load_log_entries,
+     load_processed_qids,
+ )
+ from src.utils.common import normalize_text, remove_diacritics, sort_qids
+ from src.utils.ingestion import (
+     get_embeddings,
+     get_qdrant_client,
+     get_vector_store,
+     ingest_all_data,
+     ingest_files,
+ )
+ from src.utils.llm import get_large_model, get_small_model
+ from src.utils.web_crawler import WebCrawler, crawl_website, save_crawled_data
+
+ __all__ = [
+     # Checkpointing
+     "load_processed_qids",
+     "load_log_entries",
+     "append_log_entry",
+     "consolidate_log_file",
+     "generate_csv_from_log",
+     "is_rate_limit_error",
+     "sort_qids",
+     # Ingestion
+     "get_embeddings",
+     "get_qdrant_client",
+     "get_vector_store",
+     "ingest_all_data",
+     "ingest_files",
+     # LLM
+     "get_small_model",
+     "get_large_model",
+     # Text utilities
+     "normalize_text",
+     "remove_diacritics",
+     # Web crawler
+     "WebCrawler",
+     "crawl_website",
+     "save_crawled_data",
+ ]
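These exports make the resume flow behind `run_pipeline_with_checkpointing` easy to wire up. A sketch, assuming `load_processed_qids(log_path)` returns the set of qids already present in the JSONL log (its exact signature is not shown in this diff):

```python
import asyncio
from pathlib import Path

from src.pipeline import run_pipeline_with_checkpointing
from src.utils import load_processed_qids

log_path = Path("data/output/inference_log.jsonl")  # illustrative path

all_questions = []  # replace with the output of the project's question loader
done = load_processed_qids(log_path)  # assumed: set[str] of completed qids
pending = [q for q in all_questions if q.qid not in done]

new_count = asyncio.run(run_pipeline_with_checkpointing(pending, log_path))
print(f"Newly processed: {new_count}")
```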
src/utils/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.03 kB)
src/utils/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (1.03 kB)
src/utils/__pycache__/checkpointing.cpython-312.pyc ADDED
Binary file (5.38 kB)
src/utils/__pycache__/checkpointing.cpython-314.pyc ADDED
Binary file (6.5 kB)
src/utils/__pycache__/common.cpython-312.pyc ADDED
Binary file (3.31 kB)