Davit6174 commited on
Commit
8b94f0b
·
verified ·
1 Parent(s): f591129

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -28
app.py CHANGED
@@ -10,12 +10,19 @@ from tools import tools
10
  from langchain_core.messages import HumanMessage
11
  from langgraph.prebuilt import ToolNode, create_react_agent
12
  from langgraph.graph import StateGraph, END
 
 
 
13
 
14
 
15
  # (Keep Constants as is)
16
  # --- Constants ---
17
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
18
 
 
 
 
 
19
  # --- Basic Agent Definition ---
20
  # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
21
  class BasicAgent:
@@ -29,44 +36,49 @@ class BasicAgent:
29
 
30
  class LangGraphAgent:
31
  def __init__(self):
32
- print("Initializing LangGraphAgent...")
33
  model_id = "HuggingFaceH4/zephyr-7b-beta"
 
34
  self.tokenizer = AutoTokenizer.from_pretrained(model_id)
35
- self.model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
 
 
 
36
 
37
- pipe = pipeline("text-generation", model=self.model, tokenizer=self.tokenizer, max_new_tokens=512)
38
- self.llm = HuggingFacePipeline(pipeline=pipe)
 
 
 
 
39
 
40
- self.graph = self._build_graph()
 
41
 
42
- def _build_graph(self):
43
- agent_node = create_react_agent(model=self.llm, tools=tools)
44
- tool_node = ToolNode(tools)
45
 
46
- def run_agent_node(state):
47
- return {"messages": agent_node.invoke(state)["messages"]}
 
 
 
48
 
49
- def run_tool_node(state):
50
- return tool_node.invoke({"messages": state["messages"]})
51
 
52
- builder = StateGraph(input_schema={"messages": list})
53
- builder.add_node("agent", run_agent_node)
54
- builder.add_node("tools", run_tool_node)
55
- builder.set_entry_point("agent")
56
- builder.add_edge("agent", "tools")
57
- builder.add_edge("tools", END)
58
- return builder.compile()
 
 
59
 
60
  def __call__(self, question: str) -> str:
61
- print(f"LangGraphAgent processing: {question[:50]}...")
62
- try:
63
- messages = [{"role": "user", "content": question}]
64
- output = self.graph.invoke({"messages": messages})
65
- print("LangGraphAgent result:", output)
66
- return output["messages"][-1]["content"]
67
- except Exception as e:
68
- print(f"LangGraphAgent error: {e}")
69
- return "⚠️ Error during LangGraph agent processing."
70
 
71
  def run_and_submit_all( profile: gr.OAuthProfile | None):
72
  """
 
10
  from langchain_core.messages import HumanMessage
11
  from langgraph.prebuilt import ToolNode, create_react_agent
12
  from langgraph.graph import StateGraph, END
13
+ from langchain.agents import tool
14
+ from langchain_core.runnables import Runnable
15
+ from langchain_core.tools import Tool
16
 
17
 
18
  # (Keep Constants as is)
19
  # --- Constants ---
20
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
21
 
22
@tool
def dummy_tool(query: str) -> str:
    """Placeholder lookup tool: echoes back the query it was asked to look up.

    NOTE: the ``@tool`` decorator requires a docstring (it becomes the tool's
    description shown to the model); without one the decorator raises at
    import time.
    """
    return f"You asked me to look up: {query}"
25
+
26
  # --- Basic Agent Definition ---
27
  # ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
28
  class BasicAgent:
 
36
 
37
class LangGraphAgent:
    """Answer questions with a local Zephyr-7B chat model via a LangGraph graph.

    The graph has two nodes: ``invoke_model`` (generates one assistant turn
    with a Hugging Face text-generation pipeline) and ``tools`` (a ToolNode
    over ``self.tools``). Graph state is a dict with a single ``messages``
    key holding a list of chat-format message dicts.
    """

    def __init__(self):
        # StateGraph requires a state schema as its first argument; calling
        # StateGraph() with no schema raises TypeError. Define a minimal
        # TypedDict schema for the {"messages": [...]} state.
        from typing import TypedDict

        class _AgentState(TypedDict):
            messages: list

        model_id = "HuggingFaceH4/zephyr-7b-beta"

        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype="auto",
        )

        # return_full_text=False so the pipeline yields only the newly
        # generated completion, not prompt + completion.
        self.pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            return_full_text=False,
        )

        self.tools = [dummy_tool]  # Add more tools later if needed
        self.tool_node = ToolNode(tools=self.tools)

        builder = StateGraph(_AgentState)
        builder.add_node("invoke_model", self.invoke_model)
        builder.add_node("tools", self.tool_node)
        builder.set_entry_point("invoke_model")
        # Route to the tool node only when the model actually requested a
        # tool call. An unconditional invoke_model -> tools edge would feed
        # a plain assistant message into ToolNode, which expects an AI
        # message carrying tool_calls and errors out.
        builder.add_conditional_edges("invoke_model", self._route)
        builder.add_edge("tools", END)

        self.app = builder.compile()

    @staticmethod
    def _route(state: dict):
        """Return "tools" if the last message carries tool calls, else END."""
        last = state["messages"][-1]
        # Dict-format messages have no tool_calls attribute, so this
        # currently always ends the run; kept for forward compatibility
        # with message objects that do expose tool_calls.
        if getattr(last, "tool_calls", None):
            return "tools"
        return END

    def invoke_model(self, state: dict) -> dict:
        """Generate one assistant turn and append it to the message list."""
        messages = state["messages"]
        if isinstance(messages, str):
            # Tolerate a bare question string by wrapping it as a user turn.
            messages = [{"role": "user", "content": messages}]
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        response = self.pipe(prompt, max_new_tokens=256, temperature=0.7)[0]["generated_text"]
        return {"messages": messages + [{"role": "assistant", "content": response.strip()}]}

    def __call__(self, question: str) -> str:
        """Run the graph on a single question and return the final assistant reply."""
        result = self.app.invoke({"messages": [{"role": "user", "content": question}]})
        messages = result["messages"]
        return messages[-1]["content"] if messages else " No response generated."
 
 
 
 
 
 
82
 
83
  def run_and_submit_all( profile: gr.OAuthProfile | None):
84
  """