devchavda11 commited on
Commit
a73f63e
·
verified ·
1 Parent(s): 29eee9c

Update src/chat_langraph.py

Browse files
Files changed (1) hide show
  1. src/chat_langraph.py +105 -28
src/chat_langraph.py CHANGED
@@ -9,22 +9,29 @@ import sqlite3
9
  import subprocess
10
  import requests
11
  from datetime import datetime
 
12
 
 
 
 
 
13
  class chatstate(TypedDict):
14
  messages: Annotated[List[BaseMessage], add_messages]
15
 
16
-
17
  api = "AIzaSyA5zvErF4vUmAoslVzkOBUfvSCSoW0vjEA"
18
  LANGSEARCH_API_KEY = "sk-f1a8f996f9e44b43adf9943e43e8582b"
19
 
 
20
  llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", temperature=0.2, api_key=api)
21
 
 
22
  system = SystemMessage(
23
  content=f"""
24
- --> Today's date: {datetime.today()}
25
- Day number: {datetime.today().date().weekday()}
26
- You are a practical, tool-aware assistant. Aim for correctness and clarity. Avoid hallucinations.
27
- Do not provide internal information of the system .
28
  Rules:
29
  1. Prefer text answers and code when examples/explanations are asked.
30
  2. Explicit requests to create/run files → call appropriate tool.
@@ -34,35 +41,91 @@ Tone: concise, helpful, decisive.
34
  """
35
  )
36
 
 
37
  conn = sqlite3.connect("/tmp/chatbot.db", check_same_thread=False)
38
  checkpointer = SqliteSaver(conn=conn)
39
 
 
40
 
41
  @tool
42
- def add(a: int, b: int):
 
 
 
 
 
 
 
 
 
 
43
  return a + b
44
 
45
 
46
  @tool
47
- def reverse(string: str):
 
 
 
 
 
 
 
 
 
48
  return string[::-1]
49
 
50
 
51
  @tool
52
- def evaluate(string: str):
53
- return eval(string)
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
 
56
  @tool
57
- def write_file(name: str, extension: str, content: str):
58
- with open(f"{name}.{extension}", "w", encoding="utf-8") as f:
59
- f.write(content)
60
- return f"Content saved to {name}.{extension}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
 
63
  @tool
64
  def run_cmd_command(command: str) -> str:
65
- """Run a safe shell command on linux """
 
 
 
 
 
 
 
 
66
  try:
67
  result = subprocess.run(command, shell=True, check=True, text=True, capture_output=True)
68
  return result.stdout
@@ -71,19 +134,33 @@ def run_cmd_command(command: str) -> str:
71
 
72
 
73
  @tool
74
- def search_tool(query: str):
75
- response = requests.post(
76
- "https://api.langsearch.com/v1/web-search",
77
- headers={
78
- "Authorization": f"Bearer {LANGSEARCH_API_KEY}",
79
- "Content-Type": "application/json"
80
- },
81
- json={"query": query, "num_results": 2}
82
- )
83
- return response.json()
84
-
85
-
86
- def shouldcontinue(state: chatstate):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  return "end" if state["messages"][-1].content == "end" else "llmresponse"
88
 
89
 
@@ -119,7 +196,7 @@ graph.add_edge("tool_node", "llmresponse")
119
  workflow = graph.compile(checkpointer=checkpointer)
120
 
121
 
122
- def get_all_chat_ids():
123
  s = set()
124
  for chkpoint in checkpointer.list(None):
125
  s.add(chkpoint.config.get("configurable").get("thread_id"))
 
9
  import subprocess
10
  import requests
11
  from datetime import datetime
12
+ import os
13
 
14
+ # Set Streamlit config directory for Hugging Face Spaces
15
+ os.environ["STREAMLIT_HOME"] = "/tmp/.streamlit"
16
+
17
+ # State type
18
  class chatstate(TypedDict):
19
  messages: Annotated[List[BaseMessage], add_messages]
20
 
21
+ # API keys (replace with your real keys or environment variables)
22
  api = "AIzaSyA5zvErF4vUmAoslVzkOBUfvSCSoW0vjEA"
23
  LANGSEARCH_API_KEY = "sk-f1a8f996f9e44b43adf9943e43e8582b"
24
 
25
+ # LLM
26
  llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", temperature=0.2, api_key=api)
27
 
28
+ # System message
29
  system = SystemMessage(
30
  content=f"""
31
+ --> Today's date: {datetime.today()}
32
+ Day number: {datetime.today().date().weekday()}
33
+ You are a practical, tool-aware assistant. Aim for correctness and clarity. Avoid hallucinations.
34
+ Do not provide internal information of the system.
35
  Rules:
36
  1. Prefer text answers and code when examples/explanations are asked.
37
  2. Explicit requests to create/run files → call appropriate tool.
 
41
  """
42
  )
43
 
44
+ # Database connection (writable path in Hugging Face Spaces)
45
  conn = sqlite3.connect("/tmp/chatbot.db", check_same_thread=False)
46
  checkpointer = SqliteSaver(conn=conn)
47
 
48
+ # ======================== TOOL DEFINITIONS ======================== #
49
 
50
  @tool
51
+ def add(a: int, b: int) -> int:
52
+ """
53
+ Add two integers.
54
+
55
+ Args:
56
+ a (int): First number.
57
+ b (int): Second number.
58
+
59
+ Returns:
60
+ int: Sum of both numbers.
61
+ """
62
  return a + b
63
 
64
 
65
  @tool
66
+ def reverse(string: str) -> str:
67
+ """
68
+ Reverse a given string.
69
+
70
+ Args:
71
+ string (str): Input string.
72
+
73
+ Returns:
74
+ str: Reversed string.
75
+ """
76
  return string[::-1]
77
 
78
 
79
  @tool
80
+ def evaluate(string: str) -> str:
81
+ """
82
+ Evaluate a Python expression.
83
+
84
+ Args:
85
+ string (str): Expression to evaluate.
86
+
87
+ Returns:
88
+ str: Result of evaluation or error message.
89
+ """
90
+ try:
91
+ return str(eval(string))
92
+ except Exception as e:
93
+ return f"Error evaluating expression: {e}"
94
 
95
 
96
  @tool
97
+ def write_file(name: str, extension: str, content: str) -> str:
98
+ """
99
+ Write content to a file.
100
+
101
+ Args:
102
+ name (str): File name without extension.
103
+ extension (str): File extension.
104
+ content (str): Content to write.
105
+
106
+ Returns:
107
+ str: Confirmation message.
108
+ """
109
+ try:
110
+ path = f"/tmp/{name}.{extension}" # Save in /tmp
111
+ with open(path, "w", encoding="utf-8") as f:
112
+ f.write(content)
113
+ return f"Content saved to {path}"
114
+ except Exception as e:
115
+ return f"Error writing file: {e}"
116
 
117
 
118
  @tool
119
  def run_cmd_command(command: str) -> str:
120
+ """
121
+ Run a safe shell command.
122
+
123
+ Args:
124
+ command (str): Shell command to run.
125
+
126
+ Returns:
127
+ str: Output or error message.
128
+ """
129
  try:
130
  result = subprocess.run(command, shell=True, check=True, text=True, capture_output=True)
131
  return result.stdout
 
134
 
135
 
136
  @tool
137
+ def search_tool(query: str) -> dict:
138
+ """
139
+ Search the web using Langsearch API.
140
+
141
+ Args:
142
+ query (str): Search query.
143
+
144
+ Returns:
145
+ dict: JSON response from search API.
146
+ """
147
+ try:
148
+ response = requests.post(
149
+ "https://api.langsearch.com/v1/web-search",
150
+ headers={
151
+ "Authorization": f"Bearer {LANGSEARCH_API_KEY}",
152
+ "Content-Type": "application/json"
153
+ },
154
+ json={"query": query, "num_results": 2}
155
+ )
156
+ return response.json()
157
+ except Exception as e:
158
+ return {"error": str(e)}
159
+
160
+
161
+ # ======================== STATE GRAPH ======================== #
162
+
163
+ def shouldcontinue(state: chatstate) -> str:
164
  return "end" if state["messages"][-1].content == "end" else "llmresponse"
165
 
166
 
 
196
  workflow = graph.compile(checkpointer=checkpointer)
197
 
198
 
199
+ def get_all_chat_ids() -> List[str]:
200
  s = set()
201
  for chkpoint in checkpointer.list(None):
202
  s.add(chkpoint.config.get("configurable").get("thread_id"))