Davit6174 commited on
Commit
25b1c2d
·
verified ·
1 Parent(s): 75ec2d0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -38
app.py CHANGED
@@ -31,55 +31,49 @@ class BasicAgent:
31
  print(f"Agent returning fixed answer: {fixed_answer}")
32
  return fixed_answer
33
 
34
- class ZephyrPipelineModel:
35
  def __init__(self):
36
- hf_token = os.getenv("HF_TOKEN")
37
- if not hf_token:
38
- raise ValueError("HF_TOKEN environment variable not set.")
39
-
40
- self.client = InferenceClient(
41
- model="HuggingFaceH4/zephyr-7b-beta",
42
- token=hf_token,
43
- )
 
 
 
 
 
 
 
 
44
 
45
- def __call__(self, prompt: str) -> str:
46
- messages = [{"role": "user", "content": prompt}]
47
  try:
48
- completion = self.client.chat.completions.create(
49
- model="HuggingFaceH4/zephyr-7b-beta",
50
- messages=messages,
51
- max_tokens=512,
52
- temperature=0.7,
53
- )
54
- return completion.choices[0].message.content
55
  except Exception as e:
56
- return f"❌ Inference failed: {str(e)}"
 
 
57
 
58
  class LangGraphAgent:
59
  def __init__(self):
60
- # Add token from environment
61
- hf_token = os.environ.get("HF_TOKEN")
62
- if not hf_token:
63
- raise ValueError("HF_TOKEN is not set.")
64
-
65
- # ✅ Restore this structure with token
66
- self.model = ChatHuggingFace.from_model_id(
67
- model_id="HuggingFaceH4/zephyr-7b-beta",
68
- task="text-generation",
69
- model_kwargs={
70
- "temperature": 0.7,
71
- "max_new_tokens": 512
72
- },
73
- huggingfacehub_api_token=hf_token,
74
- )
75
 
76
- # ✅ Simple LangGraph setup
77
  builder = StateGraph()
78
 
79
  def call_model(state):
80
  messages = state.get("messages", [])
81
- response = self.model.invoke(messages)
82
- return {"messages": messages + [response]}
 
 
 
 
83
 
84
  builder.add_node("chat", call_model)
85
  builder.set_entry_point("chat")
@@ -92,7 +86,8 @@ class LangGraphAgent:
92
  "messages": [HumanMessage(content=question)]
93
  })
94
 
95
- for msg in reversed(result.get("messages", [])):
 
96
  if isinstance(msg, AIMessage):
97
  return msg.content
98
 
 
31
  print(f"Agent returning fixed answer: {fixed_answer}")
32
  return fixed_answer
33
 
34
+ class ZephyrAPI:
35
  def __init__(self):
36
+ self.api_url = "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta"
37
+ self.headers = {
38
+ "Authorization": f"Bearer {os.getenv('HF_TOKEN')}"
39
+ }
40
+ print("ZephyrAPI initialized using Inference API.")
41
+
42
+ def __call__(self, question: str) -> str:
43
+ prompt = f"<|system|>\nYou are a helpful assistant.\n<|user|>\n{question}\n<|assistant|>\n"
44
+ payload = {
45
+ "inputs": prompt,
46
+ "parameters": {
47
+ "max_new_tokens": 256,
48
+ "temperature": 0.7,
49
+ "top_p": 0.9,
50
+ }
51
+ }
52
 
 
 
53
  try:
54
+ response = requests.post(self.api_url, headers=self.headers, json=payload, timeout=60)
55
+ response.raise_for_status()
56
+ result = response.json()
57
+ return result[0]["generated_text"].split("<|assistant|>")[-1].strip()
 
 
 
58
  except Exception as e:
59
+ print(f"Error: {e}")
60
+ return "⚠️ Model could not respond. Check API access or token."
61
+
62
 
63
  class LangGraphAgent:
64
  def __init__(self):
65
+ self.model = ZephyrAPI()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
 
67
  builder = StateGraph()
68
 
69
  def call_model(state):
70
  messages = state.get("messages", [])
71
+ user_msg = next((m for m in messages if isinstance(m, HumanMessage)), None)
72
+ if not user_msg:
73
+ return {"messages": messages + [AIMessage(content="❌ No user input found.")]}
74
+
75
+ response = self.model(user_msg.content)
76
+ return {"messages": messages + [AIMessage(content=response)]}
77
 
78
  builder.add_node("chat", call_model)
79
  builder.set_entry_point("chat")
 
86
  "messages": [HumanMessage(content=question)]
87
  })
88
 
89
+ messages = result.get("messages", [])
90
+ for msg in reversed(messages):
91
  if isinstance(msg, AIMessage):
92
  return msg.content
93