File size: 3,453 Bytes
49e52b5
 
 
10e9b7d
49e52b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31243f4
49e52b5
3c4371f
49e52b5
 
e80aab9
49e52b5
 
 
 
 
 
31243f4
49e52b5
 
31243f4
49e52b5
 
 
 
 
 
 
 
 
 
 
 
 
 
eccf8e4
49e52b5
 
 
 
 
 
 
 
 
 
 
7d65c66
49e52b5
 
 
 
 
 
 
 
 
 
 
 
 
 
e80aab9
49e52b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d65c66
49e52b5
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import logging
import hashlib
import json
import os
from smolagents import CodeAgent, tool
from huggingface_hub import InferenceClient

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# On-disk cache of answers, keyed by question hash (see solve()).
CACHE_FILE = "answer_cache.json"


def load_cache(path=CACHE_FILE):
    """Load the JSON answer cache from *path*.

    A missing file yields an empty cache. An unreadable or corrupt file (for
    example one truncated by an interrupted write) is logged and also yields
    an empty cache, so a bad cache file cannot stop this module from importing.

    Args:
        path: Location of the cache file.

    Returns:
        The decoded JSON object as a dict, or {} if the file is missing,
        unreadable, not valid JSON, or not a JSON object.
    """
    try:
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        return {}
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logger.warning("Ignoring unreadable answer cache %s: %s", path, e)
        return {}
    if not isinstance(data, dict):
        logger.warning("Ignoring answer cache %s: expected a JSON object", path)
        return {}
    return data


answer_cache = load_cache(CACHE_FILE)

def save_cache(cache=None, path=None):
    """Persist the answer cache to disk as JSON.

    The data is first written to ``<path>.tmp`` and then moved over the target
    with os.replace, so an interrupted or failed write (e.g. json.dump raising
    TypeError on an unserializable value) never leaves a truncated cache file.

    Args:
        cache: Mapping to persist; defaults to the module-level answer_cache.
        path: Destination file; defaults to CACHE_FILE.

    Raises:
        OSError: If the file cannot be written or replaced.
        TypeError: If *cache* contains values json cannot serialize.
    """
    if cache is None:
        cache = answer_cache
    if path is None:
        path = CACHE_FILE
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(cache, f)
        os.replace(tmp_path, path)
    except BaseException:
        # Remove the partial temp file; the previous cache file is untouched.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise

# ---------- Tools ----------
@tool
def calculator(expression: str) -> str:
    """
    Safely evaluate a mathematical expression.

    Supports numbers, parentheses, unary +/- and the operators +, -, *, /, // and **.

    Args:
        expression: A string containing a simple arithmetic expression (e.g., '2 + 2').

    Returns:
        The result as a string, or an error message if the expression is invalid.
    """
    # Imported locally so the tool function stays self-contained.
    import ast
    import operator

    allowed_chars = set("0123456789+-*/(). ")
    if not all(c in allowed_chars for c in expression):
        return "Error: Expression contains disallowed characters."

    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Pow: operator.pow,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}
    # Upper bound on the bit length of an integer power result. Without it an
    # input such as 9**9**9 (allowed by the charset above) runs for ages and
    # exhausts memory; 2**100_000 itself is still computed instantly.
    max_pow_bits = 100_000

    def evaluate(node):
        # Only numeric literals and whitelisted operators are evaluated, so no
        # arbitrary Python (e.g. `...` or `()`) is ever executed.
        if isinstance(node, ast.Constant) and type(node.value) in (int, float):
            return node.value
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](evaluate(node.operand))
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            left = evaluate(node.left)
            right = evaluate(node.right)
            if (
                isinstance(node.op, ast.Pow)
                and isinstance(left, int)
                and isinstance(right, int)
                and abs(left) > 1
                and right > 0
                and abs(left).bit_length() * right > max_pow_bits
            ):
                raise ValueError("exponent too large")
            return binary_ops[type(node.op)](left, right)
        raise ValueError("unsupported expression")

    try:
        # ast.parse rejects leading spaces (IndentationError) whereas eval()
        # silently ignored them, so strip to keep accepting ' 2 + 2'.
        tree = ast.parse(expression.strip(), mode="eval")
        return str(evaluate(tree.body))
    except Exception as e:
        return f"Error: {e}"

@tool
def web_search(query: str) -> str:
    """
    Search the web for up-to-date information.

    Args:
        query: The search query string.

    Returns:
        A string containing up to three search result snippets with titles and URLs,
        or an error message if the search fails.
    """
    # The docstring above doubles as the tool description shown to the agent,
    # so it is kept verbatim. The import is deferred so the tool degrades to an
    # error string (rather than failing at import) when the package is absent.
    try:
        from duckduckgo_search import DDGS

        with DDGS() as client:
            hits = list(client.text(query, max_results=3))
        if hits:
            return "\n\n".join(
                f"Title: {hit['title']}\nBody: {hit['body']}\nURL: {hit['href']}"
                for hit in hits
            )
        return "No results found."
    except ImportError:
        return "Web search tool not available: install duckduckgo-search"
    except Exception as e:
        return f"Search error: {e}"

# ---------- Custom model ----------
class CustomHFModel:
    """Callable chat model backed by huggingface_hub.InferenceClient.chat_completion.

    Args:
        model_id: Model repository id passed to InferenceClient.
        max_tokens: Default completion length, overridable per call.
        temperature: Default sampling temperature, overridable per call.
    """

    def __init__(self, model_id="HuggingFaceH4/zephyr-7b-beta", max_tokens=500, temperature=0.7):
        # The API token is read from the HF_TOKEN environment variable.
        self.client = InferenceClient(model=model_id, token=os.getenv("HF_TOKEN"))
        self.model_id = model_id
        self.max_tokens = max_tokens
        self.temperature = temperature

    def _build_params(self, overrides):
        """Merge per-call options over the instance defaults.

        Args:
            overrides: Keyword arguments given to __call__ (not modified).

        Returns:
            A new dict of keyword arguments for chat_completion.
        """
        params = {"max_tokens": self.max_tokens, "temperature": self.temperature}
        extra = dict(overrides)
        # Agent frameworks (e.g. smolagents) may pass ``stop_sequences``;
        # chat_completion names this parameter ``stop``. An explicit ``stop`` wins.
        stop_sequences = extra.pop("stop_sequences", None)
        if stop_sequences is not None:
            extra.setdefault("stop", stop_sequences)
        # Caller options override the defaults instead of colliding with them
        # (passing max_tokens/temperature twice used to raise TypeError).
        params.update(extra)
        return params

    def __call__(self, messages, **kwargs):
        """Send *messages* to the chat endpoint and return the reply text.

        Args:
            messages: Chat messages forwarded unchanged to chat_completion.
            **kwargs: Extra chat_completion options; these override the defaults.

        Returns:
            The content of the first choice's message.
        """
        response = self.client.chat_completion(messages=messages, **self._build_params(kwargs))
        return response.choices[0].message.content

# ---------- Assemble agent ----------
# The calculator is always available; web search only when its backend is installed.
tools = [calculator]
try:
    import duckduckgo_search  # noqa: F401 -- imported only to probe availability
except ImportError:
    logger.warning("duckduckgo-search not installed, web_search disabled.")
else:
    tools.append(web_search)
    logger.info("Web search tool enabled.")

# Module-level singletons shared by every solve() call.
model = CustomHFModel()
agent = CodeAgent(tools=tools, model=model)

# ---------- Main entry point (called by app.py) ----------
def solve(question: str) -> str:
    """Answer *question* with the agent, memoising successful answers on disk.

    This function must be named 'solve' because app.py imports it.

    Answers are cached under the MD5 hex digest of the question text and
    persisted with save_cache(). Agent failures are returned as an
    "Error: ..." string but are NOT cached, so a transient failure is retried
    on the next call instead of being replayed forever.

    Args:
        question: The question text.

    Returns:
        The cached or freshly computed answer as a string, or
        "Error: <exception>" if the agent raised.
    """
    q_hash = hashlib.md5(question.encode()).hexdigest()
    if q_hash in answer_cache:
        logger.info("Cache hit for question: %s...", question[:50])
        return answer_cache[q_hash]

    logger.info("Processing question: %s...", question[:50])
    try:
        # agent.run is not guaranteed to return a str (e.g. a numeric final
        # answer); normalise so the return type matches the annotation and the
        # cache stays JSON-serialisable.
        answer = str(agent.run(question))
    except Exception as e:
        logger.exception("Agent error: %s", e)
        return f"Error: {e}"

    answer_cache[q_hash] = answer
    try:
        save_cache()
    except OSError as e:
        # Persisting is best-effort; the answer is computed and kept in memory.
        logger.warning("Could not save answer cache: %s", e)
    return answer