File size: 3,453 Bytes
49e52b5 10e9b7d 49e52b5 31243f4 49e52b5 3c4371f 49e52b5 e80aab9 49e52b5 31243f4 49e52b5 31243f4 49e52b5 eccf8e4 49e52b5 7d65c66 49e52b5 e80aab9 49e52b5 7d65c66 49e52b5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 | import logging
import hashlib
import json
import os
from smolagents import CodeAgent, tool
from huggingface_hub import InferenceClient
# Configure root logging once at import time (INFO and above) and create this
# module's own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
# Cache for answers
CACHE_FILE = "answer_cache.json"
if os.path.exists(CACHE_FILE):
with open(CACHE_FILE) as f:
answer_cache = json.load(f)
else:
answer_cache = {}
def save_cache(cache=None, path=None):
    """Persist the answer cache to disk as JSON without ever truncating it.

    The data is written to a temporary file in the destination directory and
    then moved over *path* with ``os.replace()`` (atomic on POSIX when source
    and destination share a filesystem). If serialisation or writing fails,
    the previous cache file is left untouched instead of half-written.

    Args:
        cache: Mapping to write. Defaults to the module-level ``answer_cache``.
        path: Destination file. Defaults to ``CACHE_FILE``.

    Raises:
        OSError: If the temporary file cannot be created, written, or moved.
        TypeError: If *cache* contains values that are not JSON-serialisable.
    """
    import tempfile

    # Resolve defaults at call time so callers see the current module state.
    if cache is None:
        cache = answer_cache
    if path is None:
        path = CACHE_FILE

    # Same directory as the target so os.replace() stays on one filesystem.
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, prefix=".answer_cache-", suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(cache, f)
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort cleanup of the orphaned temp file, then re-raise.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
# ---------- Tools ----------
@tool
def calculator(expression: str) -> str:
    """
    Safely evaluate a mathematical expression.

    Only numbers, parentheses and the operators + - * / // ** are accepted.
    The expression is parsed with the ast module and evaluated node by node
    instead of being passed to eval(), and exponentiation is size-limited so
    inputs such as '9**9**9**9' return an error instead of hanging the process.

    Args:
        expression: A string containing a simple arithmetic expression (e.g., '2 + 2').

    Returns:
        The result as a string, or an error message if the expression is invalid.
    """
    # Imports are kept inside the function so the tool stays self-contained.
    import ast
    import operator

    allowed_chars = set("0123456789+-*/(). ")
    if not all(c in allowed_chars for c in expression):
        return "Error: Expression contains disallowed characters."

    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Pow: operator.pow,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}
    # Limits for '**': the exponent itself, and the approximate bit size of an
    # integer result (bit_length(base) * exponent), which stops nested towers.
    max_exponent = 10_000
    max_result_bits = 100_000

    def evaluate(node):
        """Recursively evaluate a whitelisted arithmetic AST node."""
        if isinstance(node, ast.Constant) and type(node.value) in (int, float):
            return node.value
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](evaluate(node.operand))
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            left = evaluate(node.left)
            right = evaluate(node.right)
            if isinstance(node.op, ast.Pow):
                too_big = abs(right) > max_exponent or (
                    isinstance(left, int)
                    and abs(left) > 1
                    and left.bit_length() * abs(right) > max_result_bits
                )
                if too_big:
                    raise ValueError("exponent too large")
            return binary_ops[type(node.op)](left, right)
        raise ValueError(f"unsupported expression element: {type(node).__name__}")

    try:
        # strip(): eval() tolerated leading blanks, ast.parse() would not.
        tree = ast.parse(expression.strip(), mode="eval")
        return str(evaluate(tree.body))
    except Exception as e:
        return f"Error: {e}"
@tool
def web_search(query: str) -> str:
    """
    Search the web for up-to-date information.

    Args:
        query: The search query string.

    Returns:
        A string containing up to three search result snippets with titles and URLs,
        or an error message if the search fails.
    """
    try:
        # Optional dependency: imported lazily so its absence yields a message.
        from duckduckgo_search import DDGS

        with DDGS() as client:
            hits = list(client.text(query, max_results=3))
        if not hits:
            return "No results found."
        return "\n\n".join(
            f"Title: {hit['title']}\nBody: {hit['body']}\nURL: {hit['href']}"
            for hit in hits
        )
    except ImportError:
        return "Web search tool not available: install duckduckgo-search"
    except Exception as e:
        return f"Search error: {e}"
# ---------- Custom model ----------
class CustomHFModel:
    """Callable chat model backed by ``huggingface_hub.InferenceClient``.

    Calling an instance with a list of chat messages forwards them to
    ``InferenceClient.chat_completion`` and returns the text content of the
    first choice.

    NOTE(review): this object is handed to ``CodeAgent`` as its model.
    Depending on the installed smolagents version, the agent may expect a
    ChatMessage-like object (with ``.content``) rather than a plain ``str``
    — confirm against the version in use.
    """

    def __init__(self, model_id="HuggingFaceH4/zephyr-7b-beta", max_tokens=500, temperature=0.7):
        """
        Args:
            model_id: Hugging Face model id used for inference.
            max_tokens: Default generation budget per call (overridable per call).
            temperature: Default sampling temperature (overridable per call).
        """
        # Token comes from the HF_TOKEN environment variable (None if unset).
        self.client = InferenceClient(model=model_id, token=os.getenv("HF_TOKEN"))
        self.model_id = model_id
        self.max_tokens = max_tokens
        self.temperature = temperature

    def _completion_kwargs(self, stop_sequences, kwargs):
        """Build the keyword arguments passed to ``chat_completion``.

        Instance defaults come first and are overridden by explicit *kwargs*,
        so a caller-supplied ``max_tokens``/``temperature`` no longer collides
        with the defaults. Agent-style ``stop_sequences`` is translated to the
        ``stop`` parameter that ``chat_completion`` accepts, unless the caller
        already passed ``stop`` explicitly.
        """
        params = {"max_tokens": self.max_tokens, "temperature": self.temperature}
        params.update(kwargs)
        if stop_sequences is not None:
            params.setdefault("stop", list(stop_sequences))
        return params

    def __call__(self, messages, stop_sequences=None, **kwargs):
        """Run one chat completion and return the generated text.

        Args:
            messages: Chat messages forwarded unchanged to ``chat_completion``.
            stop_sequences: Optional strings at which generation stops.
            **kwargs: Extra ``chat_completion`` keyword arguments.

        Returns:
            The ``content`` of the first choice's message.
        """
        response = self.client.chat_completion(
            messages=messages,
            **self._completion_kwargs(stop_sequences, kwargs),
        )
        return response.choices[0].message.content
# ---------- Assemble agent ----------
# The calculator is always registered; web search is added only when its
# optional dependency can be imported.
tools = [calculator]
try:
    import duckduckgo_search  # noqa: F401 -- imported only to probe availability
except ImportError:
    logger.warning("duckduckgo-search not installed, web_search disabled.")
else:
    tools.append(web_search)
    logger.info("Web search tool enabled.")

model = CustomHFModel()
agent = CodeAgent(tools=tools, model=model)
# ---------- Main entry point (called by app.py) ----------
def solve(question: str) -> str:
    """Answer a question with the agent, reusing cached answers when possible.

    This function must be named 'solve' because app.py imports it.

    Answers are stored in ``answer_cache`` under the MD5 hex digest of the
    question and persisted with ``save_cache()``. A failed agent run returns
    ``"Error: <exception>"`` but is deliberately NOT cached, so transient
    failures (network errors, rate limits, ...) are retried on the next call
    instead of being replayed forever.

    Args:
        question: The question to answer.

    Returns:
        The agent's answer as a string, or an ``"Error: ..."`` message if the
        agent raised.
    """
    # MD5 serves only as a compact cache key here, not for security.
    q_hash = hashlib.md5(question.encode()).hexdigest()
    if q_hash in answer_cache:
        logger.info("Cache hit for question: %s...", question[:50])
        return answer_cache[q_hash]

    logger.info("Processing question: %s...", question[:50])
    try:
        answer = agent.run(question)
    except Exception as e:
        # Keep the traceback in the log; do not cache the failure.
        logger.exception("Agent error: %s", e)
        return f"Error: {e}"

    # agent.run() may hand back a non-str final answer (assumption: CodeAgent
    # final answers can be arbitrary objects — confirm). Normalise so the
    # result honours the annotation and stays JSON-serialisable for the cache.
    answer = str(answer)
    answer_cache[q_hash] = answer
    try:
        save_cache()
    except OSError as e:
        # Persisting is best-effort: a read-only or full disk must not
        # discard an answer that has already been computed.
        logger.warning("Could not persist answer cache: %s", e)
    return answer