Ashok75 commited on
Commit
8453b43
·
verified ·
1 Parent(s): 9c856ea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -74
app.py CHANGED
@@ -1,100 +1,118 @@
1
  import torch
2
- import json
3
  import re
4
- import datetime
5
  from flask import Flask, request, Response, render_template
6
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
7
- from threading import Thread
 
 
8
 
9
  app = Flask(__name__)
10
 
11
- # 1. TOOL DEFINITIONS
12
- def get_current_datetime(query: str = ""):
13
- """Returns the current date and time."""
14
- return f"Observation: The current date and time is {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}."
15
 
 
16
  def simple_calculator(expression: str):
17
- """An easy-to-construct tool for basic math (add, sub, mult, div)."""
18
  try:
19
- # Source 351: Calculators are essential tools for deterministic results.
20
- # Note: In production, use a safer math parser instead of eval.
21
- result = eval(expression, {"__builtins__": None}, {})
22
- return f"Observation: The calculation result is {result}."
23
  except Exception as e:
24
- return f"Observation: Error in calculation: {str(e)}."
25
 
26
- # Tool Registry
27
- tools = {
28
- "get_current_datetime": get_current_datetime,
29
- "simple_calculator": simple_calculator
30
- }
31
 
32
- # Load Model
 
 
33
  model_id = "AshokGakr/model-tiny"
34
  tokenizer = AutoTokenizer.from_pretrained(model_id)
35
  model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
36
 
37
- SYSTEM_PROMPT = """
38
- ROLE: You are a ReAct Agent. You solve tasks using this loop:
39
- Thought: (Reasoning about what to do)
40
- Action: (Tool name: 'get_current_datetime' or 'simple_calculator')
41
- Action Input: (Parameter for the tool)
42
- Observation: (Result from the tool - provided to you)
43
- ... (Repeat Thought/Action/Observation if needed)
44
- Final Answer: (The final response to the user)
45
-
46
- AVAILABLE TOOLS:
47
- - get_current_datetime: Use this for any questions about the current date or time. No input needed.
48
- - simple_calculator: Use this for any math calculations. Input should be a math expression (e.g., '10 + 5').
49
- """
50
 
51
- @app.route('/')
52
- def index():
53
- return render_template('index.html')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
- @app.route('/chat', methods=['POST'])
56
- def chat():
57
- data = request.json
58
- user_query = data.get("message", "")
 
 
 
59
 
60
- def generate_agent_response():
61
- # Source 13: Episodic memory maintains the conversation trajectory.
62
- history = [{"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": user_query}]
63
-
64
- for i in range(5): # Limit iterations to prevent infinite loops [5]
65
- input_ids = tokenizer.apply_chat_template(history, add_generation_prompt=True, return_tensors="pt").to(model.device)
66
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
67
-
68
- thread = Thread(target=model.generate, kwargs={"input_ids": input_ids, "streamer": streamer, "max_new_tokens": 256})
69
- thread.start()
70
 
71
- full_turn_output = ""
72
- for new_text in streamer:
73
- full_turn_output += new_text
74
- yield new_text # Stream thoughts to the UI [6]
 
 
 
75
 
76
- # Check for Action [7]
77
- action_match = re.search(r"Action:\s*(\w+)", full_turn_output)
78
- input_match = re.search(r"Action Input:\s*(.*)", full_turn_output)
79
 
80
- if action_match and input_match:
81
- tool_name = action_match.group(1).strip()
82
- tool_input = input_match.group(1).strip()
83
-
84
- if tool_name in tools:
85
- obs = tools[tool_name](tool_input)
86
- yield f"\n{obs}\n"
87
- # Feed observation back into history [8, 9]
88
- history.append({"role": "assistant", "content": full_turn_output})
89
- history.append({"role": "user", "content": obs})
90
- else:
91
- break
92
- elif "Final Answer:" in full_turn_output:
93
- break
94
- else:
95
- break
96
 
97
- return Response(generate_agent_response(), mimetype='text/plain')
98
 
99
  if __name__ == '__main__':
100
  app.run(host='0.0.0.0', port=7860)
 
1
  import torch
 
2
  import re
3
+ from typing import Annotated, TypedDict, Union
4
  from flask import Flask, request, Response, render_template
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer
6
+ from langgraph.graph import StateGraph, END
7
+ from langchain_core.tools import tool
8
+ from pydantic import BaseModel, Field
9
 
10
  app = Flask(__name__)
11
 
12
# 1. DEFINE STRUCTURED TOOLS WITH PYDANTIC
class CalcInput(BaseModel):
    """Argument schema for the calculator tool."""

    # Free-form arithmetic expression the model wants evaluated.
    expression: str = Field(description="The math expression to evaluate, e.g., '2 + 2'")
 
15
 
16
@tool("simple_calculator", args_schema=CalcInput)
def simple_calculator(expression: str):
    """Useful for basic math calculations."""
    # SECURITY: the expression originates from model output driven by user
    # input. eval() is unsafe even with {"__builtins__": None} (attribute
    # chains and dunder tricks can escape the sandbox), so evaluate a
    # whitelisted arithmetic AST instead.
    import ast
    import operator

    # Map AST operator node types to their arithmetic implementations.
    ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
        ast.USub: operator.neg,
        ast.UAdd: operator.pos,
    }

    def _eval(node):
        # Recursively evaluate only numeric literals and arithmetic operators.
        if isinstance(node, ast.Expression):
            return _eval(node.body)
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in ops:
            return ops[type(node.op)](_eval(node.left), _eval(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in ops:
            return ops[type(node.op)](_eval(node.operand))
        raise ValueError("unsupported expression")

    try:
        return str(_eval(ast.parse(expression, mode="eval")))
    except Exception as e:
        # Keep the original error contract: return the message, don't raise.
        return f"Error: {str(e)}"
24
 
25
@tool("get_time")
def get_time():
    """Returns the current system time."""
    from datetime import datetime

    # HH:MM:SS wall-clock time of the host running the app.
    now = datetime.now()
    return now.strftime("%H:%M:%S")
30
 
31
# Registry mapping the tool names the model may emit to their callables.
tools = dict(
    simple_calculator=simple_calculator,
    get_time=get_time,
)

# 2. LOAD REASONING ENGINE
model_id = "AshokGakr/model-tiny"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
37
 
38
# 3. DEFINE THE AGENT STATE
class AgentState(TypedDict):
    """Shared state threaded through the graph.

    Source 828: the StateGraph acts as the system's real-time workflow tracker.
    """

    # Full chat transcript as role/content dicts.
    messages: list[dict]
    # Name of the tool to run next, or "end" to finish.
    next_step: str
 
 
 
 
 
 
 
 
43
 
44
# 4. AGENT LOGIC NODES
def call_model(state: AgentState):
    """Run one reasoning step of the ReAct loop.

    Formats the conversation history, generates a reply, and inspects it for
    an ``Action: <tool>`` directive. Returns the updated state with the
    assistant message appended and ``next_step`` set to the requested tool
    name, or ``"end"`` when no action was requested.
    """
    # FIX: return_dict=True is required — without it apply_chat_template
    # returns a bare tensor, so both **inputs and inputs['input_ids'] below
    # would fail. With it we get a BatchEncoding (input_ids + attention_mask).
    inputs = tokenizer.apply_chat_template(
        state['messages'],
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    # Unpack so generate receives input_ids AND attention_mask.
    output_ids = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
    )

    # FIX: generate returns shape (batch, seq). Slice the prompt off the
    # sequence dimension of row 0 — the original sliced the batch dimension,
    # which yields an empty tensor for a single-sequence batch.
    prompt_len = inputs['input_ids'].shape[-1]
    response = tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=True)

    # A line of the form "Action: <tool_name>" signals a tool call [1, 6].
    action_match = re.search(r"Action:\s*(\w+)", response)

    return {
        "messages": state['messages'] + [{"role": "assistant", "content": response}],
        "next_step": action_match.group(1) if action_match else "end",
    }
75
 
76
def execute_tool(state: AgentState):
    """Execute the tool named in ``state['next_step']`` and record the result.

    The tool argument is parsed from the last assistant message's
    "Action Input:" line (empty string when absent, e.g. for get_time).
    Returns the state with an "Observation: ..." user message appended so the
    model sees the tool result on the next reasoning step.
    """
    tool_name = state['next_step']
    last_message = state['messages'][-1]['content']

    # Parse input (simplified for this model-tiny example).
    input_match = re.search(r"Action Input:\s*(.*)", last_message)
    arg = input_match.group(1).strip() if input_match else ""

    # FIX: a failing tool (e.g. schema validation rejecting a stray argument
    # to the zero-arg get_time) must not crash the streamed response; surface
    # the error to the model as an observation instead.
    try:
        observation = tools[tool_name].run(arg)
    except Exception as e:
        observation = f"Tool error: {e}"

    return {"messages": state['messages'] + [{"role": "user", "content": f"Observation: {observation}"}]}
86
+
87
# 5. CONSTRUCT THE GRAPH
# Source 96: Nodes represent actions; edges define the control flow.
workflow = StateGraph(AgentState)
workflow.add_node("agent", call_model)
workflow.add_node("tools", execute_tool)
workflow.set_entry_point("agent")


def _route(state: AgentState) -> str:
    """Route to the tool node when a known tool was requested, else finish."""
    return "tools" if state["next_step"] in tools else "end"


workflow.add_conditional_edges("agent", _route, {"tools": "tools", "end": END})
workflow.add_edge("tools", "agent")  # Close the ReAct Loop cycle [10]

agent_app = workflow.compile()
 
 
102
 
103
@app.route('/chat', methods=['POST'])
def chat():
    """Stream the agent's intermediate messages for one user query as plain text."""
    # FIX: request.json raises an opaque 400/415 on a wrong Content-Type, and
    # a missing "message" would reach the graph as None and blow up mid-stream.
    # Validate up front and fail with a clear client error instead.
    data = request.get_json(silent=True) or {}
    user_input = data.get("message")
    if not user_input:
        return Response("Error: 'message' is required.\n", status=400, mimetype='text/plain')

    # Execute the graph [5, 11].
    inputs = {"messages": [{"role": "user", "content": user_input}]}

    def run():
        # Each graph step yields {node_name: partial_state}; stream the newest
        # message of every step so the UI can show thoughts live [12].
        for output in agent_app.stream(inputs):
            for value in output.values():
                yield value['messages'][-1]['content'] + "\n"

    return Response(run(), mimetype='text/plain')


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)