mohantest commited on
Commit
0aa9224
·
verified ·
1 Parent(s): a2f1bed

Create agent.py

Browse files
Files changed (1) hide show
  1. agent.py +122 -0
agent.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import hashlib
3
+ import json
4
+ import logging
5
+ from smolagents import CodeAgent, tool
6
+ from huggingface_hub import InferenceClient
7
+
8
+ logging.basicConfig(level=logging.INFO)
9
+ logger = logging.getLogger(__name__)
10
+
11
+ # Cache for answers (persists between runs)
12
+ CACHE_FILE = "answer_cache.json"
13
+ if os.path.exists(CACHE_FILE):
14
+ with open(CACHE_FILE) as f:
15
+ answer_cache = json.load(f)
16
+ else:
17
+ answer_cache = {}
18
+
19
+ def save_cache():
20
+ with open(CACHE_FILE, "w") as f:
21
+ json.dump(answer_cache, f)
22
+
23
+ # ---------- Tools ----------
24
+ @tool
25
+ def calculator(expression: str) -> str:
26
+ """
27
+ Safely evaluate a mathematical expression.
28
+
29
+ Args:
30
+ expression: A string containing a simple arithmetic expression (e.g., '2 + 2').
31
+
32
+ Returns:
33
+ The result as a string, or an error message if the expression is invalid.
34
+ """
35
+ allowed_chars = set("0123456789+-*/(). ")
36
+ if not all(c in allowed_chars for c in expression):
37
+ return "Error: Expression contains disallowed characters."
38
+ try:
39
+ # Restricted eval – only math allowed
40
+ result = eval(expression, {"__builtins__": {}}, {})
41
+ return str(result)
42
+ except Exception as e:
43
+ return f"Error: {e}"
44
+
45
+ @tool
46
+ def web_search(query: str) -> str:
47
+ """
48
+ Search the web for up-to-date information.
49
+
50
+ Args:
51
+ query: The search query string.
52
+
53
+ Returns:
54
+ A string containing up to three search result snippets with titles and URLs,
55
+ or an error message if the search fails.
56
+ """
57
+ try:
58
+ from duckduckgo_search import DDGS
59
+ with DDGS() as ddgs:
60
+ results = list(ddgs.text(query, max_results=3))
61
+ if not results:
62
+ return "No results found."
63
+ snippets = []
64
+ for r in results:
65
+ snippets.append(f"Title: {r['title']}\nBody: {r['body']}\nURL: {r['href']}")
66
+ return "\n\n".join(snippets)
67
+ except ImportError:
68
+ return "Web search tool not available: install duckduckgo-search"
69
+ except Exception as e:
70
+ return f"Search error: {e}"
71
+
72
+ # ---------- Custom model that wraps HF InferenceClient ----------
73
+ class CustomHFModel:
74
+ def __init__(self, model_id="HuggingFaceH4/zephyr-7b-beta"):
75
+ self.client = InferenceClient(model=model_id, token=os.getenv("HF_TOKEN"))
76
+ self.model_id = model_id
77
+
78
+ def __call__(self, messages, **kwargs):
79
+ """
80
+ Expected by smolagents: takes a list of messages
81
+ (e.g., [{"role": "user", "content": "..."}])
82
+ and returns the assistant's reply as a string.
83
+ """
84
+ response = self.client.chat_completion(
85
+ messages=messages,
86
+ max_tokens=500,
87
+ temperature=0.7,
88
+ **kwargs
89
+ )
90
+ return response.choices[0].message.content
91
+
92
+ # ---------- Assemble the agent (once, at import) ----------
93
+ tools = [calculator]
94
+ try:
95
+ import duckduckgo_search
96
+ tools.append(web_search)
97
+ logger.info("Web search tool enabled.")
98
+ except ImportError:
99
+ logger.warning("duckduckgo-search not installed, web_search disabled.")
100
+
101
+ model = CustomHFModel() # you can change the model_id if desired
102
+ agent = CodeAgent(tools=tools, model=model)
103
+
104
+ # ---------- The class expected by app.py ----------
105
+ class CustomAgent:
106
+ def __call__(self, question: str) -> str:
107
+ """This method is called for each question."""
108
+ q_hash = hashlib.md5(question.encode()).hexdigest()
109
+ if q_hash in answer_cache:
110
+ logger.info(f"Cache hit for question: {question[:50]}...")
111
+ return answer_cache[q_hash]
112
+
113
+ logger.info(f"Processing question: {question[:50]}...")
114
+ try:
115
+ answer = agent.run(question)
116
+ except Exception as e:
117
+ logger.error(f"Agent error: {e}")
118
+ answer = f"Error: {e}"
119
+
120
+ answer_cache[q_hash] = answer
121
+ save_cache()
122
+ return answer