mohantest committed on
Commit
57111d2
·
verified ·
1 Parent(s): 787d663

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +21 -11
agent.py CHANGED
@@ -3,12 +3,13 @@ import hashlib
3
  import json
4
  import logging
5
  from smolagents import CodeAgent, tool
 
6
  from huggingface_hub import InferenceClient
7
 
8
  logging.basicConfig(level=logging.INFO)
9
  logger = logging.getLogger(__name__)
10
 
11
- # Cache for answers (persists between runs)
12
  CACHE_FILE = "answer_cache.json"
13
  if os.path.exists(CACHE_FILE):
14
  with open(CACHE_FILE) as f:
@@ -36,7 +37,6 @@ def calculator(expression: str) -> str:
36
  if not all(c in allowed_chars for c in expression):
37
  return "Error: Expression contains disallowed characters."
38
  try:
39
- # Restricted eval – only math allowed
40
  result = eval(expression, {"__builtins__": {}}, {})
41
  return str(result)
42
  except Exception as e:
@@ -69,27 +69,37 @@ def web_search(query: str) -> str:
69
  except Exception as e:
70
  return f"Search error: {e}"
71
 
72
# ---------- Custom model that wraps HF InferenceClient ----------
class CustomHFModel:
    """Callable adapter that exposes an HF InferenceClient as a chat model."""

    def __init__(self, model_id="HuggingFaceH4/zephyr-7b-beta"):
        # Token is read from the environment; a missing HF_TOKEN yields None
        # and the client falls back to unauthenticated access.
        hf_token = os.getenv("HF_TOKEN")
        self.client = InferenceClient(model=model_id, token=hf_token)
        self.model_id = model_id

    def __call__(self, messages, **kwargs):
        """
        Expected by smolagents: takes a list of messages
        (e.g., [{"role": "user", "content": "..."}])
        and returns the assistant's reply as a string.
        """
        # Fixed generation defaults; extra keyword args are forwarded verbatim.
        completion = self.client.chat_completion(
            messages=messages,
            max_tokens=500,
            temperature=0.7,
            **kwargs
        )
        first_choice = completion.choices[0]
        return first_choice.message.content
91
 
92
- # ---------- Assemble the agent (once, at import) ----------
93
  tools = [calculator]
94
  try:
95
  import duckduckgo_search
@@ -104,7 +114,7 @@ agent = CodeAgent(tools=tools, model=model)
104
  # ---------- The class expected by app.py ----------
105
  class CustomAgent:
106
  def __call__(self, question: str) -> str:
107
- """This method is called for each question."""
108
  q_hash = hashlib.md5(question.encode()).hexdigest()
109
  if q_hash in answer_cache:
110
  logger.info(f"Cache hit for question: {question[:50]}...")
 
3
  import json
4
  import logging
5
  from smolagents import CodeAgent, tool
6
+ from smolagents.models import Model # <-- base class for models
7
  from huggingface_hub import InferenceClient
8
 
9
  logging.basicConfig(level=logging.INFO)
10
  logger = logging.getLogger(__name__)
11
 
12
+ # Cache for answers
13
  CACHE_FILE = "answer_cache.json"
14
  if os.path.exists(CACHE_FILE):
15
  with open(CACHE_FILE) as f:
 
37
  if not all(c in allowed_chars for c in expression):
38
  return "Error: Expression contains disallowed characters."
39
  try:
 
40
  result = eval(expression, {"__builtins__": {}}, {})
41
  return str(result)
42
  except Exception as e:
 
69
  except Exception as e:
70
  return f"Search error: {e}"
71
 
72
# ---------- Custom model that inherits from smolagents.Model ----------
class CustomHFModel(Model):
    """smolagents model backed by the HF Inference API chat endpoint."""

    def __init__(self, model_id="HuggingFaceH4/zephyr-7b-beta", **kwargs):
        super().__init__(**kwargs)
        # Token comes from the environment; None means anonymous access.
        self.client = InferenceClient(model=model_id, token=os.getenv("HF_TOKEN"))
        self.model_id = model_id

    def generate(
        self,
        messages,
        stop_sequences=None,
        grammar=None,  # NOTE(review): accepted but not forwarded to the client — confirm intended
        max_tokens=500,
        temperature=0.7,
        **kwargs
    ):
        """
        Required by smolagents: takes a list of messages
        (e.g., [{"role": "user", "content": "..."}])
        and returns the assistant's reply as a string.
        """
        # stop_sequences maps onto the client's `stop` parameter; any extra
        # keyword arguments are passed through to chat_completion unchanged.
        completion = self.client.chat_completion(
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            stop=stop_sequences,
            **kwargs
        )
        first_choice = completion.choices[0]
        return first_choice.message.content
101
 
102
+ # ---------- Assemble the agent ----------
103
  tools = [calculator]
104
  try:
105
  import duckduckgo_search
 
114
  # ---------- The class expected by app.py ----------
115
  class CustomAgent:
116
  def __call__(self, question: str) -> str:
117
+ """Called for each question."""
118
  q_hash = hashlib.md5(question.encode()).hexdigest()
119
  if q_hash in answer_cache:
120
  logger.info(f"Cache hit for question: {question[:50]}...")