Davit6174 commited on
Commit
31be34e
·
verified ·
1 Parent(s): 91a6778

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -8
app.py CHANGED
@@ -33,23 +33,32 @@ class BasicAgent:
33
  print(f"Agent returning fixed answer: {fixed_answer}")
34
  return fixed_answer
35
 
 
 
 
 
 
 
 
 
 
 
36
  class LangGraphAgent:
37
  def __init__(self):
38
  # Load Zephyr with correct config
39
- self.model = ChatHuggingFace.from_model_id(
40
- model_id="HuggingFaceH4/zephyr-7b-beta",
41
- task="text-generation",
42
- model_kwargs={"temperature": 0.7, "max_new_tokens": 512},
43
- huggingfacehub_api_token=hf_token
44
- )
45
 
46
  # Define a simple graph with just one node
47
  builder = StateGraph()
48
 
49
  def call_model(state):
50
  messages = state.get("messages", [])
51
- response = self.model.invoke(messages)
52
- return {"messages": messages + [response]}
 
 
 
 
53
 
54
  builder.add_node("chat", call_model)
55
  builder.set_entry_point("chat")
 
33
  print(f"Agent returning fixed answer: {fixed_answer}")
34
  return fixed_answer
35
 
36
+ class ZephyrPipelineModel:
37
+ def __init__(self):
38
+ self.pipe = pipeline("text-generation", model="HuggingFaceH4/zephyr-7b-beta")
39
+
40
+ def __call__(self, prompt: str) -> str:
41
+ # Zephyr wants a list of messages
42
+ messages = [{"role": "user", "content": prompt}]
43
+ result = self.pipe(messages, max_new_tokens=512)
44
+ return result[0]["generated_text"]
45
+
46
  class LangGraphAgent:
47
  def __init__(self):
48
  # Load Zephyr with correct config
49
+ self.model = ZephyrPipelineModel()
 
 
 
 
 
50
 
51
  # Define a simple graph with just one node
52
  builder = StateGraph()
53
 
54
  def call_model(state):
55
  messages = state.get("messages", [])
56
+ user_msg = next((m for m in messages if isinstance(m, HumanMessage)), None)
57
+ if not user_msg:
58
+ return {"messages": messages + [AIMessage(content="❌ No user input found.")]}
59
+
60
+ response = self.model(user_msg.content)
61
+ return {"messages": messages + [AIMessage(content=response)]}
62
 
63
  builder.add_node("chat", call_model)
64
  builder.set_entry_point("chat")