Davit6174 commited on
Commit
43440dd
·
verified ·
1 Parent(s): b720dc6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -36
app.py CHANGED
@@ -7,12 +7,13 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
7
  import torch
8
  from langchain_community.llms import HuggingFacePipeline
9
  from tools import tools
10
- from langchain_core.messages import HumanMessage
11
  from langgraph.prebuilt import ToolNode, create_react_agent
12
  from langgraph.graph import StateGraph, END
13
  from langchain.agents import tool
14
  from langchain_core.runnables import Runnable
15
  from langchain_core.tools import Tool
 
16
 
17
 
18
  # (Keep Constants as is)
@@ -32,49 +33,41 @@ class BasicAgent:
32
 
33
  class LangGraphAgent:
34
  def __init__(self):
35
- model_id = "HuggingFaceH4/zephyr-7b-beta"
36
-
37
- self.tokenizer = AutoTokenizer.from_pretrained(model_id)
38
- self.model = AutoModelForCausalLM.from_pretrained(
39
- model_id,
40
- torch_dtype="auto",
41
- )
42
-
43
- self.pipe = pipeline(
44
- "text-generation",
45
- model=self.model,
46
- tokenizer=self.tokenizer,
47
- return_full_text=False
48
  )
49
 
50
- self.tools = [tools] # Add more tools later if needed
51
- self.tool_node = ToolNode(tools=self.tools)
52
-
53
- # LangGraph states are dicts with a "messages" key
54
  builder = StateGraph()
55
 
56
- builder.add_node("invoke_model", self.invoke_model)
57
- builder.add_node("tools", self.tool_node)
58
- builder.set_entry_point("invoke_model")
59
- builder.add_edge("invoke_model", "tools")
60
- builder.add_edge("tools", END)
61
 
62
- self.app = builder.compile()
 
 
63
 
64
- def invoke_model(self, state: dict) -> dict:
65
- messages = state["messages"]
66
- if isinstance(messages, str):
67
- messages = [{"role": "user", "content": messages}]
68
- prompt = self.tokenizer.apply_chat_template(
69
- messages, tokenize=False, add_generation_prompt=True
70
- )
71
- response = self.pipe(prompt, max_new_tokens=256, temperature=0.7)[0]["generated_text"]
72
- return {"messages": messages + [{"role": "assistant", "content": response.strip()}]}
73
 
74
  def __call__(self, question: str) -> str:
75
- result = self.app.invoke({"messages": [{"role": "user", "content": question}]})
76
- messages = result["messages"]
77
- return messages[-1]["content"] if messages else "❌ No response generated."
 
 
 
 
 
 
 
 
 
78
 
79
  def run_and_submit_all( profile: gr.OAuthProfile | None):
80
  """
 
7
  import torch
8
  from langchain_community.llms import HuggingFacePipeline
9
  from tools import tools
10
+ from langchain_core.messages import HumanMessage, AIMessage
11
  from langgraph.prebuilt import ToolNode, create_react_agent
12
  from langgraph.graph import StateGraph, END
13
  from langchain.agents import tool
14
  from langchain_core.runnables import Runnable
15
  from langchain_core.tools import Tool
16
+ from langchain_community.chat_models import ChatHuggingFace
17
 
18
 
19
  # (Keep Constants as is)
 
class LangGraphAgent:
    """Answer a question with Zephyr-7B via a minimal one-node LangGraph graph.

    The graph has a single "chat" node that feeds the running message list to
    the chat model and appends the model's reply; ``__call__`` runs the graph
    on one user question and returns the assistant's text.
    """

    def __init__(self):
        # BUG FIX: ChatHuggingFace has no `from_model_id` classmethod — the
        # original call raised AttributeError at startup. `from_model_id`
        # belongs to HuggingFacePipeline; build the pipeline LLM first, then
        # wrap it in the chat interface. Generation parameters go in
        # `pipeline_kwargs` (per-call pipeline args), not `model_kwargs`
        # (passed to the model constructor).
        llm = HuggingFacePipeline.from_model_id(
            model_id="HuggingFaceH4/zephyr-7b-beta",
            task="text-generation",
            pipeline_kwargs={"temperature": 0.7, "max_new_tokens": 512},
        )
        self.model = ChatHuggingFace(llm=llm)

        # BUG FIX: StateGraph requires a state schema; StateGraph() with no
        # argument raises TypeError. A plain (reducer-free) schema is used so
        # the node's returned `messages` list replaces the old one, matching
        # the original node logic which already appends the reply itself.
        from typing import TypedDict

        class _AgentState(TypedDict):
            # Running conversation: HumanMessage / AIMessage objects.
            messages: list

        builder = StateGraph(_AgentState)

        def call_model(state):
            # Single node: invoke the chat model on the full history and
            # append its reply to the state.
            messages = state.get("messages", [])
            response = self.model.invoke(messages)
            return {"messages": messages + [response]}

        builder.add_node("chat", call_model)
        builder.set_entry_point("chat")
        builder.add_edge("chat", END)

        # Compile the graph once; `invoke` is then reusable per question.
        self.graph = builder.compile()

    def __call__(self, question: str) -> str:
        """Run the graph on *question* and return the assistant's reply text.

        Returns a fallback string when no AIMessage is produced.
        """
        # Wrap the raw string in the message format the graph state expects.
        result = self.graph.invoke({
            "messages": [HumanMessage(content=question)]
        })

        # Walk newest-first so the latest assistant reply wins.
        for msg in reversed(result.get("messages", [])):
            if isinstance(msg, AIMessage):
                return msg.content

        return "❌ No response generated."
71
 
72
  def run_and_submit_all( profile: gr.OAuthProfile | None):
73
  """