quachtiensinh27 commited on
Commit
dd47faf
·
1 Parent(s): ca2ba49

feat: implement core agent architecture including LLM integration, Redis-backed memory, tool definitions, and comprehensive test suite.

Browse files
Files changed (8) hide show
  1. agent.py +26 -4
  2. config.py +5 -0
  3. llm.py +31 -13
  4. redis_client.py +17 -2
  5. tools/base.py +23 -4
  6. tools/memory.py +2 -2
  7. tools/scheduler.py +27 -5
  8. tools/summarizer.py +2 -0
agent.py CHANGED
@@ -14,6 +14,7 @@ if project_root not in sys.path:
14
  import logging
15
  from typing import Any
16
  from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage
 
17
  from src.llm import llm
18
  from src.config import DEFAULT_MODEL, LOG_LEVEL
19
  from src.tools import get_tool_schemas, execute_tool
@@ -21,9 +22,26 @@ from src.tools import get_tool_schemas, execute_tool
21
  logging.basicConfig(level=LOG_LEVEL, format="%(asctime)s [%(levelname)s] %(message)s")
22
  logger = logging.getLogger(__name__)
23
 
24
- SYSTEM_PROMPT = """You are an intelligent AI assistant.
25
- You can use the provided tools to complete tasks.
26
- Think step by step and use tools when necessary."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
 
29
  def create_agent():
@@ -35,8 +53,12 @@ def run_agent_loop(client: Any, user_input: str, max_turns: int = 10) -> str:
35
  """
36
  Run the agent loop: send message -> receive response -> call tool -> repeat.
37
  """
 
 
 
 
38
  messages = [
39
- SystemMessage(content=SYSTEM_PROMPT),
40
  HumanMessage(content=user_input)
41
  ]
42
 
 
14
  import logging
15
  from typing import Any
16
  from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage
17
+ from datetime import datetime
18
  from src.llm import llm
19
  from src.config import DEFAULT_MODEL, LOG_LEVEL
20
  from src.tools import get_tool_schemas, execute_tool
 
22
  logging.basicConfig(level=LOG_LEVEL, format="%(asctime)s [%(levelname)s] %(message)s")
23
  logger = logging.getLogger(__name__)
24
 
25
+ SYSTEM_PROMPT = """You are an intelligent AI Assistant designed to be a "Second Brain" for the user.
26
+ Your primary goal is to help the user manage their life while respecting UNIQUE constraints.
27
+
28
+ CALENDAR-FIRST PRIORITY RULES:
29
+ 1. **Event > Habit**: A specific calendar event (e.g., "Family Anniversary", "OOO", "Anniversary", "Off-grid") on a specific date ALWAYS overrules a general habit (e.g., "Monday Deep Work", "17h Swimming").
30
+ 2. **No Exceptions for OOO**: If the user is OOO or "Off-grid" on a day, you MUST NOT suggest any work or meetings for that day. A "Sếp đòi phương án sáng mai" request must be resolved by proposing action **BEFORE** the OOO starts (e.g., tonight) or **DELEGATING** completely to an internal staff (e.g., Anh Hoàng).
31
+ 3. **Validate Every Date**: Always check the specific date/day for a request (e.g., "Sếp đòi sáng mai" when today is Wednesday means Thursday). Cross-reference this date with your schedule results.
32
+
33
+ SEARCH & REASONING RULES:
34
+ 1. **BROAD SEARCH**: Call `get_memories(limit=100)` with NO query for audits.
35
+ 2. **SINGLE-WORD KEYWORDS**: Only use single-word keywords for search (e.g., "azure", "ghét").
36
+ 3. **RANGE SEARCH**: Call `get_schedule(date_str="next 2 weeks")`.
37
+ 4. **Parallel Context**: Call all 3 context tools TOGETHER in your first turn.
38
+ 5. **Strict Taboos**: Strictly reject any tech/vendor the user has an aversion to (Azure) and any work practices they hate (Outsourcing).
39
+
40
+ THINK STEP-BY-STEP:
41
+ 1. Call search tools in parallel.
42
+ 2. Cross-reference all chat proposals against specific calendar events (First Priority).
43
+ 3. Apply habits and taboos (Second Priority).
44
+ 4. Synthesize the final plan."""
45
 
46
 
47
  def create_agent():
 
53
  """
54
  Run the agent loop: send message -> receive response -> call tool -> repeat.
55
  """
56
+ # Dynamic Date Injection
57
+ today = datetime.now()
58
+ time_context = f"\n[CURRENT TIME CONTEXT]\nToday is {today.strftime('%A, %B %d, %Y')}.\n"
59
+
60
  messages = [
61
+ SystemMessage(content=SYSTEM_PROMPT + time_context),
62
  HumanMessage(content=user_input)
63
  ]
64
 
config.py CHANGED
@@ -18,6 +18,11 @@ QWEN_API_KEY = os.getenv("QWEN_API_KEY", "")
18
  QWEN_BASE_URL = os.getenv("QWEN_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1")
19
  QWEN_MODEL = os.getenv("QWEN_MODEL", "qwen-plus")
20
 
 
 
 
 
 
21
  # Local LLM config
22
  USE_LOCAL_LLM = os.getenv("USE_LOCAL_LLM", "false").lower() == "true"
23
  LOCAL_MODEL_ID = os.getenv("LOCAL_MODEL_ID", "Qwen/Qwen2.5-0.5B-Instruct")
 
18
  QWEN_BASE_URL = os.getenv("QWEN_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1")
19
  QWEN_MODEL = os.getenv("QWEN_MODEL", "qwen-plus")
20
 
21
+ # OpenRouter
22
+ OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY", "")
23
+ OPENROUTER_BASE_URL = os.getenv("OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1")
24
+ OPENROUTER_MODEL = os.getenv("OPENROUTER_MODEL", "google/gemma-4-26b-a4b-it")
25
+
26
  # Local LLM config
27
  USE_LOCAL_LLM = os.getenv("USE_LOCAL_LLM", "false").lower() == "true"
28
  LOCAL_MODEL_ID = os.getenv("LOCAL_MODEL_ID", "Qwen/Qwen2.5-0.5B-Instruct")
llm.py CHANGED
@@ -1,17 +1,35 @@
1
  from langchain_google_genai import ChatGoogleGenerativeAI
2
- from src.config import GEMINI_API_KEY, DEFAULT_MODEL
 
3
 
4
- llm = ChatGoogleGenerativeAI(
5
- model=DEFAULT_MODEL,
6
- temperature=0,
7
- top_p=1,
8
- top_k=1,
9
- max_tokens=None,
10
- timeout=None,
11
- max_retries=2,
12
- google_api_key=GEMINI_API_KEY
13
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  if __name__ == "__main__":
16
- response = llm.invoke("Hello World là gì?").content
17
- print(response)
 
 
 
 
 
1
  from langchain_google_genai import ChatGoogleGenerativeAI
2
+ from langchain_openai import ChatOpenAI
3
+ from src.config import GEMINI_API_KEY, DEFAULT_MODEL, OPENROUTER_API_KEY, OPENROUTER_BASE_URL, OPENROUTER_MODEL
4
 
5
+ def get_agent_llm():
6
+ """Returns the LLM instance based on availability: OpenRouter > Gemini."""
7
+ if OPENROUTER_API_KEY and not OPENROUTER_API_KEY.startswith("your-"):
8
+ return ChatOpenAI(
9
+ model=OPENROUTER_MODEL,
10
+ api_key=OPENROUTER_API_KEY,
11
+ base_url=OPENROUTER_BASE_URL,
12
+ temperature=0,
13
+ max_tokens=4096,
14
+ default_headers={
15
+ "HTTP-Referer": "https://github.com/a20-ai-thuc-chien", # Optional for OpenRouter
16
+ "X-Title": "A20 AI Assistant",
17
+ }
18
+ )
19
+
20
+ # Fallback to Gemini
21
+ return ChatGoogleGenerativeAI(
22
+ model=DEFAULT_MODEL,
23
+ temperature=0,
24
+ google_api_key=GEMINI_API_KEY
25
+ )
26
+
27
+ llm = get_agent_llm()
28
 
29
  if __name__ == "__main__":
30
+ # Test
31
+ try:
32
+ response = llm.invoke("Hello, who are you?").content
33
+ print(f"LLM Response: {response}")
34
+ except Exception as e:
35
+ print(f"Error testing LLM: {e}")
redis_client.py CHANGED
@@ -12,6 +12,7 @@ from typing import Optional, Any
12
 
13
  import redis
14
  import time
 
15
 
16
  # Thêm path để load config nếu cần
17
  project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
@@ -123,7 +124,6 @@ class RedisClient:
123
  return True
124
 
125
  # ISO timestamp -> unix timestamp logic
126
- from datetime import datetime
127
  ts = int(datetime.now().timestamp() * 1000)
128
  if "time" in event_data:
129
  try:
@@ -144,7 +144,22 @@ class RedisClient:
144
  try:
145
  if self._use_local:
146
  db = self._load_local()
147
- return list(db["events"].values())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
  index_key = self._key("evt", "index")
150
  event_ids = self._client.zrangebyscore(index_key, start_ts, end_ts)
 
12
 
13
  import redis
14
  import time
15
+ from datetime import datetime
16
 
17
  # Thêm path để load config nếu cần
18
  project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
 
124
  return True
125
 
126
  # ISO timestamp -> unix timestamp logic
 
127
  ts = int(datetime.now().timestamp() * 1000)
128
  if "time" in event_data:
129
  try:
 
144
  try:
145
  if self._use_local:
146
  db = self._load_local()
147
+ events = list(db["events"].values())
148
+ filtered = []
149
+ for ev in events:
150
+ try:
151
+ # Parse time to check against range
152
+ dt = datetime.fromisoformat(ev["time"])
153
+ ts = int(dt.timestamp() * 1000)
154
+ if start_ts <= ts <= end_ts:
155
+ filtered.append(ev)
156
+ else:
157
+ logger.info(f"Event {ev.get('name')} ({ev.get('time')}) excluded: ts {ts} outside {start_ts}-{end_ts}")
158
+ except (ValueError, KeyError, TypeError):
159
+ # Fallback: if no time, only include if full range
160
+ if start_ts == 0 and end_ts >= 3000000000000:
161
+ filtered.append(ev)
162
+ return filtered
163
 
164
  index_key = self._key("evt", "index")
165
  event_ids = self._client.zrangebyscore(index_key, start_ts, end_ts)
tools/base.py CHANGED
@@ -11,9 +11,17 @@ from langchain_openai import ChatOpenAI
11
  from langchain_huggingface import HuggingFacePipeline
12
  from langchain_google_genai import ChatGoogleGenerativeAI
13
  try:
14
- from ..config import QWEN_API_KEY, QWEN_BASE_URL, QWEN_MODEL, LOG_LEVEL, USE_LOCAL_LLM, LOCAL_MODEL_ID, GEMINI_API_KEY
 
 
 
 
15
  except (ImportError, ValueError):
16
- from config import QWEN_API_KEY, QWEN_BASE_URL, QWEN_MODEL, LOG_LEVEL, USE_LOCAL_LLM, LOCAL_MODEL_ID, GEMINI_API_KEY
 
 
 
 
17
 
18
  logger = logging.getLogger(__name__)
19
 
@@ -38,11 +46,22 @@ def get_llm():
38
  """
39
  Initialize and return the LLM based on configuration (Gemini > Local > Cloud Qwen).
40
  """
41
- # 1. Prioritize Gemini if API key is present
 
 
 
 
 
 
 
 
 
 
 
42
  if GEMINI_API_KEY and not GEMINI_API_KEY.startswith("your-"):
43
  logger.info("Initializing Google Gemini LLM...")
44
  return ChatGoogleGenerativeAI(
45
- model="gemini-2.0-flash", # or from config
46
  google_api_key=GEMINI_API_KEY,
47
  temperature=0.1,
48
  )
 
11
  from langchain_huggingface import HuggingFacePipeline
12
  from langchain_google_genai import ChatGoogleGenerativeAI
13
  try:
14
+ from ..config import (
15
+ QWEN_API_KEY, QWEN_BASE_URL, QWEN_MODEL,
16
+ LOG_LEVEL, USE_LOCAL_LLM, LOCAL_MODEL_ID,
17
+ GEMINI_API_KEY, OPENROUTER_API_KEY, OPENROUTER_BASE_URL, OPENROUTER_MODEL
18
+ )
19
  except (ImportError, ValueError):
20
+ from config import (
21
+ QWEN_API_KEY, QWEN_BASE_URL, QWEN_MODEL,
22
+ LOG_LEVEL, USE_LOCAL_LLM, LOCAL_MODEL_ID,
23
+ GEMINI_API_KEY, OPENROUTER_API_KEY, OPENROUTER_BASE_URL, OPENROUTER_MODEL
24
+ )
25
 
26
  logger = logging.getLogger(__name__)
27
 
 
46
  """
47
  Initialize and return the LLM based on configuration (Gemini > Local > Cloud Qwen).
48
  """
49
+ # 1. Prioritize OpenRouter
50
+ if OPENROUTER_API_KEY and not OPENROUTER_API_KEY.startswith("your-"):
51
+ logger.info(f"Initializing OpenRouter LLM ({OPENROUTER_MODEL})...")
52
+ return ChatOpenAI(
53
+ model=OPENROUTER_MODEL,
54
+ api_key=OPENROUTER_API_KEY,
55
+ base_url=OPENROUTER_BASE_URL,
56
+ temperature=0.1,
57
+ max_tokens=4096,
58
+ )
59
+
60
+ # 2. Fallback to Gemini
61
  if GEMINI_API_KEY and not GEMINI_API_KEY.startswith("your-"):
62
  logger.info("Initializing Google Gemini LLM...")
63
  return ChatGoogleGenerativeAI(
64
+ model="gemini-2.0-flash",
65
  google_api_key=GEMINI_API_KEY,
66
  temperature=0.1,
67
  )
tools/memory.py CHANGED
@@ -50,11 +50,11 @@ def tool_save_memory(content: str, category: str = "general") -> dict:
50
  name="get_memories",
51
  description="Tìm kiếm và truy xuất các thông tin đã ghi nhớ trước đây dựa trên từ khóa.",
52
  parameters=[
53
- {"name": "query", "type": "string", "description": "Từ khóa tìm kiếm thông tin trong bộ nhớ.", "required": False},
54
  {"name": "limit", "type": "integer", "description": "Số lượng kết quả tối đa.", "required": False}
55
  ]
56
  )
57
- def tool_get_memories(query: str = None, limit: int = 10) -> dict:
58
  """
59
  Retrieves memories from Redis.
60
  """
 
50
  name="get_memories",
51
  description="Tìm kiếm và truy xuất các thông tin đã ghi nhớ trước đây dựa trên từ khóa.",
52
  parameters=[
53
+ {"name": "query", "type": "string", "description": "TỪ KHÓA DUY NHẤT (VD: 'azure', 'ghét'). Để trống nếu muốn lấy toàn bộ 100 ghi nhớ mới nhất.", "required": False},
54
  {"name": "limit", "type": "integer", "description": "Số lượng kết quả tối đa.", "required": False}
55
  ]
56
  )
57
+ def tool_get_memories(query: str = None, limit: int = 100) -> dict:
58
  """
59
  Retrieves memories from Redis.
60
  """
tools/scheduler.py CHANGED
@@ -15,7 +15,7 @@ logger = logging.getLogger(__name__)
15
  {
16
  "name": "query",
17
  "type": "string",
18
- "description": "Từ khóa tìm kiếm trong tên hoặc tả sự kiện. dụ: 'họp nhóm', 'báo cáo'",
19
  "required": False
20
  },
21
  {
@@ -49,10 +49,17 @@ def tool_get_schedule(query: str = "", date_str: str = "", room_id: str = None)
49
  start_ts = int(day_start.timestamp() * 1000)
50
  end_ts = int(day_end.timestamp() * 1000)
51
  else:
52
- return {
53
- "status": "error",
54
- "message": f"Không thể hiểu được khoảng thời gian: '{date_str}'."
55
- }
 
 
 
 
 
 
 
56
 
57
  # Retrieve from Redis
58
  events = redis_client.list_events(start_ts, end_ts)
@@ -68,6 +75,21 @@ def tool_get_schedule(query: str = "", date_str: str = "", room_id: str = None)
68
  if match:
69
  results.append(event)
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  # 2. Fetch Chat Context (Hybrid Memory)
72
  chat_context = []
73
  if room_id:
 
15
  {
16
  "name": "query",
17
  "type": "string",
18
+ "description": "TỪ KHÓA DUY NHẤT (VD: 'họp'). Để trống nếu muốn xem toàn bộ lịch trình.",
19
  "required": False
20
  },
21
  {
 
49
  start_ts = int(day_start.timestamp() * 1000)
50
  end_ts = int(day_end.timestamp() * 1000)
51
  else:
52
+ # Fallback for range-like strings (e.g., "next 2 weeks", "tuần tới")
53
+ range_keywords = ["tuần", "tháng", "next", "week", "month", "khoảng", "tới"]
54
+ if any(k in date_str.lower() for k in range_keywords):
55
+ logger.info(f"Date string '{date_str}' looks like a range. Returning all future events (30 days).")
56
+ start_ts = int(datetime.now().timestamp() * 1000)
57
+ end_ts = start_ts + (30 * 24 * 60 * 60 * 1000) # 30 days
58
+ else:
59
+ return {
60
+ "status": "error",
61
+ "message": f"Không thể hiểu được khoảng thời gian: '{date_str}'."
62
+ }
63
 
64
  # Retrieve from Redis
65
  events = redis_client.list_events(start_ts, end_ts)
 
75
  if match:
76
  results.append(event)
77
 
78
+ # Robustness Fallback:
79
+ # If results are empty and there was a query but no date_str,
80
+ # check if the query was meant to be a date (e.g., "tối nay").
81
+ # We only fallback if the query contains common time-related keywords to avoid false positives (e.g., "ăn").
82
+ time_keywords = ["nay", "mai", "mốt", "hôm", "tối", "sáng", "chiều", "trưa", "ngày", "lịch", "tuần", "tháng"]
83
+ is_time_query = any(k in query.lower() for k in time_keywords) or any(char.isdigit() for char in query)
84
+
85
+ if not results and query and not date_str and is_time_query:
86
+ fallback_date = dateparser.parse(query, settings={'PREFER_DATES_FROM': 'future'})
87
+ if fallback_date:
88
+ # Check if parsing was actually meaningful (not just a random number or word parsed as current year)
89
+ logger.info(f"Query '{query}' looks like a date. Retrying search with date filtering.")
90
+ # Recursive call with query moved to date_str
91
+ return tool_get_schedule(query="", date_str=query, room_id=room_id)
92
+
93
  # 2. Fetch Chat Context (Hybrid Memory)
94
  chat_context = []
95
  if room_id:
tools/summarizer.py CHANGED
@@ -16,6 +16,8 @@ try:
16
  except (ImportError, ValueError):
17
  from redis_client import redis_client
18
 
 
 
19
  # --- Pydantic Schemas ---
20
  class ThreadSummary(BaseModel):
21
  """Schema cho tóm tắt của một thread."""
 
16
  except (ImportError, ValueError):
17
  from redis_client import redis_client
18
 
19
+ logger = logging.getLogger(__name__)
20
+
21
  # --- Pydantic Schemas ---
22
  class ThreadSummary(BaseModel):
23
  """Schema cho tóm tắt của một thread."""