Davit6174 commited on
Commit
82faff4
·
verified ·
1 Parent(s): d56b17d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -17
app.py CHANGED
@@ -5,11 +5,11 @@ import inspect
5
  import pandas as pd
6
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
7
  import torch
8
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
9
  from langchain_community.llms import HuggingFacePipeline
10
  from tools import tools
11
  from langchain_core.messages import HumanMessage
12
- from langgraph.prebuilt import ToolExecutor, chat_agent_executor
 
13
 
14
 
15
  # (Keep Constants as is)
@@ -27,28 +27,47 @@ class BasicAgent:
27
  print(f"Agent returning fixed answer: {fixed_answer}")
28
  return fixed_answer
29
 
30
- class ZephyrAgent:
31
  def __init__(self):
32
- print("Initializing local Zephyr model with tools...")
 
 
 
33
 
34
- tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
35
- model = AutoModelForCausalLM.from_pretrained("HuggingFaceH4/zephyr-7b-beta", torch_dtype=torch.float16, device_map="auto")
36
 
37
- pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=512)
38
- llm = HuggingFacePipeline(pipeline=pipe)
39
 
40
- # Wrap tools
41
- tool_executor = ToolExecutor(tools)
42
- self.agent_executor = chat_agent_executor.create_chat_agent_executor(llm=llm, tools=tools)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  def __call__(self, question: str) -> str:
 
45
  try:
46
- message = HumanMessage(content=question)
47
- response = self.agent_executor.invoke({"messages": [message]})
48
- return response.content
49
  except Exception as e:
50
- print(f"Tool-augmented ZephyrAgent error: {e}")
51
- return "⚠️ Agent failed to answer due to tool or model error."
52
 
53
  def run_and_submit_all( profile: gr.OAuthProfile | None):
54
  """
@@ -211,7 +230,7 @@ with gr.Blocks() as demo:
211
 
212
  def test_agent_response(question: str) -> str:
213
  # agent = BasicAgent()
214
- agent = ZephyrAgent()
215
  return agent(question)
216
 
217
  test_button.click(fn=test_agent_response, inputs=question_input, outputs=answer_output)
 
5
  import pandas as pd
6
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
7
  import torch
 
8
  from langchain_community.llms import HuggingFacePipeline
9
  from tools import tools
10
  from langchain_core.messages import HumanMessage
11
+ from langgraph.prebuilt import ToolNode, create_react_agent
12
+ from langgraph.graph import StateGraph, END
13
 
14
 
15
  # (Keep Constants as is)
 
27
  print(f"Agent returning fixed answer: {fixed_answer}")
28
  return fixed_answer
29
 
30
class LangGraphAgent:
    """Question-answering agent: a local Zephyr-7B model driven by a LangGraph graph."""

    def __init__(self):
        """Load the Zephyr model/tokenizer, wrap it for LangChain, and compile the graph."""
        print("Initializing LangGraphAgent...")
        model_id = "HuggingFaceH4/zephyr-7b-beta"
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        # fp16 halves memory; device_map="auto" places layers on available devices.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id, torch_dtype=torch.float16, device_map="auto"
        )

        pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            max_new_tokens=512,
        )
        self.llm = HuggingFacePipeline(pipeline=pipe)

        self.graph = self._build_graph()

    def _build_graph(self):
        """Build and compile the LangGraph state graph.

        Fixes vs. the previous revision:
        * ``StateGraph`` requires a state *schema class* (TypedDict/dataclass) as
          its first positional argument — ``StateGraph(input_schema={"question": str})``
          passed a plain dict literal and omitted the required schema, raising
          ``TypeError`` at startup. The schema also needs the ``messages`` key
          that the agent node writes.
        * ``create_react_agent`` already runs the complete agent/tool loop, so the
          old unconditional ``agent -> tools`` edge re-invoked ``ToolNode`` on the
          final answer (which has no tool calls) and failed on every request; the
          graph now goes straight from the agent node to ``END``.
        """
        from typing import TypedDict

        class AgentState(TypedDict):
            # Incoming question from the caller; messages holds the ReAct trace.
            question: str
            messages: list

        # NOTE(review): create_react_agent expects a chat model supporting
        # .bind_tools(); a raw HuggingFacePipeline may need wrapping in
        # ChatHuggingFace — confirm against the installed langchain version.
        react_agent = create_react_agent(model=self.llm, tools=tools)

        def run_agent_node(state: AgentState) -> dict:
            # Translate our {"question": ...} state into the chat-message list
            # the prebuilt ReAct agent consumes, and store its full trace.
            messages = [{"role": "user", "content": state["question"]}]
            result = react_agent.invoke({"messages": messages})
            return {"messages": result["messages"]}

        builder = StateGraph(AgentState)
        builder.add_node("agent", run_agent_node)
        builder.set_entry_point("agent")
        builder.add_edge("agent", END)
        return builder.compile()

    def __call__(self, question: str) -> str:
        """Answer *question*; returns a warning string instead of raising on failure."""
        print(f"LangGraphAgent processing: {question[:50]}...")
        try:
            output = self.graph.invoke({"question": question})
            # The last message in the ReAct trace is the model's final answer.
            return output["messages"][-1].content
        except Exception as e:
            # Broad catch is deliberate: the Gradio UI must always receive a string.
            print(f"LangGraphAgent error: {e}")
            return "⚠️ Error during LangGraph agent processing."
71
 
72
  def run_and_submit_all( profile: gr.OAuthProfile | None):
73
  """
 
230
 
231
  def test_agent_response(question: str) -> str:
232
  # agent = BasicAgent()
233
+ agent = LangGraphAgent()
234
  return agent(question)
235
 
236
  test_button.click(fn=test_agent_response, inputs=question_input, outputs=answer_output)