mohantest's picture
new changes
49e52b5 verified
raw
history blame
3.45 kB
import logging
import hashlib
import json
import os
from smolagents import CodeAgent, tool
from huggingface_hub import InferenceClient
# Module-level logger; the root logger is configured for INFO output so the
# progress messages emitted below are visible when the module is run.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
# Cache for answers
CACHE_FILE = "answer_cache.json"
if os.path.exists(CACHE_FILE):
with open(CACHE_FILE) as f:
answer_cache = json.load(f)
else:
answer_cache = {}
def save_cache(path=None, cache=None):
    """Persist the answer cache to disk as JSON.

    The data is written to a temporary sibling file first and then moved into
    place with os.replace, so an interrupted write cannot leave a truncated
    cache file behind (the previous file stays intact instead).

    Args:
        path: Destination file; defaults to the module-level CACHE_FILE.
        cache: Mapping to serialise; defaults to the module-level answer_cache.

    Raises:
        OSError: If the file cannot be written or replaced.
        TypeError: If *cache* contains values JSON cannot serialise.
    """
    if path is None:
        path = CACHE_FILE
    if cache is None:
        cache = answer_cache
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(cache, f)
        os.replace(tmp_path, path)
    except BaseException:
        # Don't leave a half-written temp file lying around; re-raise the cause.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
# ---------- Tools ----------
@tool
def calculator(expression: str) -> str:
    """
    Safely evaluate a basic arithmetic expression.
    Supports numbers, parentheses, unary +/- and the operators +, -, *, /, // and **.
    Args:
        expression: A string containing a simple arithmetic expression (e.g., '2 + 2').
    Returns:
        The result as a string, or an error message if the expression is invalid.
    """
    # Imports live inside the function so the tool stays self-contained,
    # the same pattern web_search uses for its optional dependency.
    import ast
    import operator

    allowed_chars = set("0123456789+-*/(). ")
    if not all(c in allowed_chars for c in expression):
        return "Error: Expression contains disallowed characters."

    # The character whitelist alone does not stop inputs like '9**9**9**9',
    # which eval() would try to compute for an effectively unbounded time.
    # Walk the AST over a fixed operator set instead, and refuse integer
    # powers whose result could exceed this many bits.
    max_result_bits = 10_000
    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Pow: operator.pow,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}

    def evaluate(node):
        if isinstance(node, ast.Constant) and type(node.value) in (int, float):
            return node.value
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](evaluate(node.operand))
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            left = evaluate(node.left)
            right = evaluate(node.right)
            if (
                isinstance(node.op, ast.Pow)
                and isinstance(left, int)
                and isinstance(right, int)
                and abs(left) > 1
                and right > 0
                # bit_length * exponent is an upper bound on the result's size.
                and abs(left).bit_length() * right > max_result_bits
            ):
                raise ValueError("exponent too large")
            return binary_ops[type(node.op)](left, right)
        raise ValueError("unsupported expression")

    try:
        # eval() strips leading whitespace from string input; ast.parse does
        # not, so strip explicitly to accept the same inputs.
        tree = ast.parse(expression.strip(), filename="<string>", mode="eval")
        result = evaluate(tree.body)
        return str(result)
    except Exception as e:
        return f"Error: {e}"
@tool
def web_search(query: str) -> str:
    """
    Search the web for up-to-date information.
    Args:
        query: The search query string.
    Returns:
        A string containing up to three search result snippets with titles and URLs,
        or an error message if the search fails.
    """
    # The docstring above is the tool description shown to the model, so it is
    # kept verbatim. duckduckgo_search is imported lazily: a missing package is
    # reported as a message rather than raised. Any other failure (network,
    # unexpected result shape) is reported as a "Search error" string.
    try:
        from duckduckgo_search import DDGS

        with DDGS() as client:
            hits = list(client.text(query, max_results=3))
        if not hits:
            return "No results found."
        return "\n\n".join(
            f"Title: {hit['title']}\nBody: {hit['body']}\nURL: {hit['href']}"
            for hit in hits
        )
    except ImportError:
        return "Web search tool not available: install duckduckgo-search"
    except Exception as exc:
        return f"Search error: {exc}"
# ---------- Custom model ----------
class CustomHFModel:
    """Callable chat model backed by huggingface_hub's InferenceClient.

    An instance is called with a list of chat messages and returns the text
    content of the first completion choice.
    """

    def __init__(self, model_id="HuggingFaceH4/zephyr-7b-beta", max_tokens=500, temperature=0.7):
        """
        Args:
            model_id: Hugging Face model repository id to query.
            max_tokens: Default completion length limit for each call.
            temperature: Default sampling temperature for each call.
        """
        # The access token is read from the HF_TOKEN environment variable.
        self.client = InferenceClient(model=model_id, token=os.getenv("HF_TOKEN"))
        self.model_id = model_id
        self.max_tokens = max_tokens
        self.temperature = temperature

    def _request_params(self, **kwargs):
        """Merge per-call keyword arguments over the instance defaults.

        Caller-supplied ``max_tokens``/``temperature`` override the defaults
        instead of colliding with them. ``stop_sequences`` (the name agent
        frameworks such as smolagents use) is translated to
        ``chat_completion``'s ``stop`` parameter unless ``stop`` is given
        explicitly, in which case ``stop`` wins.
        """
        params = {"max_tokens": self.max_tokens, "temperature": self.temperature}
        stop_sequences = kwargs.pop("stop_sequences", None)
        if stop_sequences is not None and "stop" not in kwargs:
            kwargs["stop"] = stop_sequences
        params.update(kwargs)
        return params

    def __call__(self, messages, **kwargs):
        """Send *messages* to the chat-completion endpoint and return the reply text.

        Args:
            messages: Chat messages forwarded unchanged to ``chat_completion``.
            **kwargs: Extra ``chat_completion`` options; see ``_request_params``.

        Returns:
            The ``content`` of the first choice's message.
        """
        response = self.client.chat_completion(
            messages=messages,
            **self._request_params(**kwargs),
        )
        return response.choices[0].message.content
# ---------- Assemble agent ----------
# The calculator is always offered to the agent. web_search is offered only
# when its optional dependency can be imported; the import is just a probe.
tools = [calculator]
try:
    import duckduckgo_search  # noqa: F401
except ImportError:
    logger.warning("duckduckgo-search not installed, web_search disabled.")
else:
    tools.append(web_search)
    logger.info("Web search tool enabled.")

model = CustomHFModel()
agent = CodeAgent(tools=tools, model=model)
# ---------- Main entry point (called by app.py) ----------
def solve(question: str) -> str:
    """Answer *question* with the agent, reusing cached answers when available.

    This function must be named 'solve' because app.py imports it.

    Answers are keyed by the MD5 hex digest of the UTF-8 encoded question
    (used only as a cache key, not for security). Only successful answers are
    cached: when the agent raises, an "Error: ..." string is returned but not
    stored, so the question is retried on the next call.

    Args:
        question: The question text to answer.

    Returns:
        The agent's final answer converted to str, or "Error: <exception>"
        when the agent raised.
    """
    q_hash = hashlib.md5(question.encode("utf-8")).hexdigest()
    if q_hash in answer_cache:
        logger.info("Cache hit for question: %s...", question[:50])
        return answer_cache[q_hash]

    logger.info("Processing question: %s...", question[:50])
    try:
        answer = agent.run(question)
    except Exception as e:
        # Report the failure, but don't cache it: it may be transient
        # (network, rate limit), and caching would make it permanent.
        logger.exception("Agent error: %s", e)
        return f"Error: {e}"

    # The agent's final answer is not guaranteed to be a str; normalise it so
    # the declared return type holds and the cache stays JSON-serialisable.
    answer = str(answer)
    answer_cache[q_hash] = answer
    try:
        save_cache()
    except OSError as e:
        # Persisting is best-effort; the computed answer is still returned.
        logger.warning("Could not save answer cache: %s", e)
    return answer