File size: 4,365 Bytes
4c168bb
 
 
b8811b7
 
4c168bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f161551
4c168bb
 
 
 
 
 
 
 
 
 
 
 
 
 
b8811b7
 
4c168bb
b8811b7
4c168bb
b8811b7
4c168bb
 
 
 
b8811b7
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import re
import ast
import operator
import gradio as gr

# LLM: small, CPU-friendly
from transformers import pipeline

LLM = None  # lazy-load to speed up app boot


# ---------- Agent 1: Planner (decides the next action) ----------
def planner_agent(user_msg: str) -> dict:
    """
    Very small heuristic planner:
    - If it's a simple math question/expression => CALCULATE
    - Else => ANSWER (LLM)
    """
    text = user_msg.strip().lower()

    # Heuristic signals: a digit-delimited arithmetic-looking span, or an
    # explicit request to compute something.
    has_expression = re.search(r"[0-9][0-9\.\s\+\-\*\/\%\^\(\)]*[0-9\)]", text) is not None
    calc_keywords = ("calc", "calculate", "evaluate", "what is", "what's")
    narrative_keywords = ("story", "poem", "explain why", "compare")

    wants_calc = has_expression or any(kw in text for kw in calc_keywords)
    # Creative/explanatory requests that merely mention numbers go to the LLM.
    is_narrative = any(kw in text for kw in narrative_keywords)

    if wants_calc and not is_narrative:
        return {"action": "CALCULATE", "reason": "Looks like a numeric expression or calculation request."}
    return {"action": "ANSWER", "reason": "General question; better handled by the LLM."}


# ---------- Tiny safe calculator (tool used by Solver) ----------
# Safe AST-based evaluator supporting +,-,*,/,**,%,// and parentheses
# Whitelist mapping AST operator node types -> plain arithmetic functions.
# Anything not in this table is rejected by _eval, which is what makes the
# evaluator safe on untrusted chat input (no names, no calls, no attributes).
OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.Pow: operator.pow,
    ast.Mod: operator.mod,
    ast.FloorDiv: operator.floordiv,
    ast.USub: operator.neg,  # unary minus
    ast.UAdd: operator.pos,  # unary plus
}

# Cap on exponent magnitude so untrusted input like "9**9**9" cannot pin the
# CPU / exhaust memory; 256 is far beyond any sensible chat calculation.
_MAX_EXPONENT = 256

def _eval(node):
    """Recursively evaluate a whitelisted arithmetic AST node.

    Accepts numeric constants, binary ops listed in OPS, and unary +/-.
    Raises ValueError for anything outside that whitelist (or an absurd
    exponent) so safe_calculate can report a friendly error.

    NOTE: the deprecated ast.Num branch was dropped — ast.parse has produced
    ast.Constant since Python 3.8, which this app already requires.
    """
    if isinstance(node, ast.Expression):
        return _eval(node.body)
    if isinstance(node, ast.Constant):
        if isinstance(node.value, (int, float)):
            return node.value
        raise ValueError("Only numeric constants allowed.")
    if isinstance(node, ast.BinOp):
        op_type = type(node.op)
        if op_type not in OPS:
            raise ValueError("Unsupported operator.")
        left = _eval(node.left)
        right = _eval(node.right)
        # DoS guard: bound '**' before computing it.
        if op_type is ast.Pow and abs(right) > _MAX_EXPONENT:
            raise ValueError("Exponent too large.")
        return OPS[op_type](left, right)
    if isinstance(node, ast.UnaryOp) and type(node.op) in OPS:
        return OPS[type(node.op)](_eval(node.operand))
    raise ValueError("Unsupported expression.")

def safe_calculate(expr: str) -> str:
    """Evaluate an arithmetic expression string; return the result (or a
    friendly error message) as a string. Never raises."""
    try:
        tree = ast.parse(expr, mode="eval")
        val = _eval(tree)
        return str(val)
    except Exception as e:
        # Covers syntax errors, whitelist rejections, and ZeroDivisionError.
        return f"Sorry, I couldn't calculate that: {e}"


# ---------- Agent 2: Writer (LLM or Tool) ----------
def get_llm():
    """Return the shared text2text pipeline, creating it on first use."""
    global LLM
    if LLM is not None:
        return LLM
    # flan-t5-small is ~80M params, OK on CPU Basic
    LLM = pipeline("text2text-generation", model="google/flan-t5-small")
    return LLM

def writer_agent(user_msg: str, plan: dict) -> str:
    """Produce the reply for user_msg according to the planner's decision.

    CALCULATE: extract an arithmetic expression from the message and evaluate
    it with the safe calculator, falling back to the LLM when no usable
    expression is present. ANSWER: generate a short reply with the LLM.
    """
    if plan["action"] == "CALCULATE":
        # Keep only characters that can appear in an arithmetic expression.
        expr = "".join(ch for ch in user_msg if ch in "0123456789.+-*/()%^ ")
        # Normalize: '^' means power here, and whitespace is irrelevant.
        expr = re.sub(r"\s+", "", expr.replace("^", "**"))
        # A bare number with no operator ("I have 3 cats" -> "3") is not a
        # calculation; only evaluate when at least one operator survived.
        if expr and re.search(r"[+\-*/%]", expr):
            return safe_calculate(expr)
        # fallback to LLM if no real expression was found
        llm = get_llm()
        prompt = f"Answer briefly:\n\nQuestion: {user_msg}\nAnswer:"
        out = llm(prompt, max_new_tokens=128, do_sample=True)[0]["generated_text"]
        return out.strip()

    # ANSWER with LLM
    llm = get_llm()
    prompt = (
        "You are a concise, friendly assistant. "
        "Answer clearly in 1-4 sentences.\n\n"
        f"Question: {user_msg}\nAnswer:"
    )
    out = llm(prompt, max_new_tokens=192, do_sample=True, temperature=0.6, top_p=0.95)[0]["generated_text"]
    return out.strip()


# ---------- Gradio Chat glue ----------
def chat_fn(message, history, show_agent_trace=False):
    """Gradio chat callback: plan an action, write the answer, and optionally
    append a small trace of the planner's decision."""
    plan = planner_agent(message)
    answer = writer_agent(message, plan)

    if not show_agent_trace:
        return answer
    trace = f"\n\n---\n*Agent trace:* action = **{plan['action']}**, reason = _{plan['reason']}_"
    return answer + trace


# Wire the chat UI. The extra checkbox is forwarded to chat_fn as its
# show_agent_trace argument by Gradio's additional_inputs mechanism.
demo = gr.ChatInterface(
    fn=chat_fn,
    additional_inputs=[
        gr.Checkbox(label="Show agent trace", value=False),
    ],
    theme="soft",
    # Narrow the page so the chat column stays readable.
    css="""
    .gradio-container {max-width: 760px !important}
    """,
)

if __name__ == "__main__":
    demo.launch()