# Chatbot / app.py
# Hugging Face Space source file (author: GermanySutherland, commit 4c168bb, verified)
import re
import ast
import operator
import gradio as gr
# LLM: small, CPU-friendly
from transformers import pipeline
LLM = None # lazy-load to speed up app boot
# ---------- Agent 1: Planner (decides the next action) ----------
def planner_agent(user_msg: str) -> dict:
    """Decide how the message should be handled.

    Returns a plan dict with:
      - "action": "CALCULATE" (route to the safe calculator) or
        "ANSWER" (route to the LLM)
      - "reason": short human-readable justification for the trace view

    Heuristic only: a message is routed to CALCULATE when it either contains
    something that looks like a numeric expression or uses a calculation
    keyword, unless it also asks for prose (story/poem/explanation).
    """
    text = user_msg.strip().lower()
    # A digit followed by a run of calculator characters ending in a digit
    # or closing paren — a rough "looks like arithmetic" signal.
    looks_numeric = re.search(r"[0-9][0-9\.\s\+\-\*\/\%\^\(\)]*[0-9\)]", text) is not None
    calc_keywords = ("calc", "calculate", "evaluate", "what is", "what's")
    wants_calc = any(word in text for word in calc_keywords)
    if looks_numeric or wants_calc:
        # Numbers inside a creative/explanatory request are not a calculation.
        prose_keywords = ("story", "poem", "explain why", "compare")
        if not any(word in text for word in prose_keywords):
            return {"action": "CALCULATE", "reason": "Looks like a numeric expression or calculation request."}
    return {"action": "ANSWER", "reason": "General question; better handled by the LLM."}
# ---------- Tiny safe calculator (tool used by Solver) ----------
# Safe AST-based evaluator supporting +,-,*,/,**,%,// and parentheses
OPS = {
ast.Add: operator.add,
ast.Sub: operator.sub,
ast.Mult: operator.mul,
ast.Div: operator.truediv,
ast.Pow: operator.pow,
ast.Mod: operator.mod,
ast.FloorDiv: operator.floordiv,
ast.USub: operator.neg,
ast.UAdd: operator.pos,
}
def _eval(node):
if isinstance(node, ast.Num): # py<3.8 fallback
return node.n
if isinstance(node, ast.Constant):
if isinstance(node.value, (int, float)):
return node.value
raise ValueError("Only numeric constants allowed.")
if isinstance(node, ast.BinOp):
left = _eval(node.left)
right = _eval(node.right)
op_type = type(node.op)
if op_type in OPS:
return OPS[op_type](left, right)
raise ValueError("Unsupported operator.")
if isinstance(node, ast.UnaryOp) and type(node.op) in OPS:
return OPS[type(node.op)](_eval(node.operand))
if isinstance(node, ast.Expression):
return _eval(node.body)
raise ValueError("Unsupported expression.")
def safe_calculate(expr: str) -> str:
try:
tree = ast.parse(expr, mode="eval")
val = _eval(tree)
return str(val)
except Exception as e:
return f"Sorry, I couldn't calculate that: {e}"
# ---------- Agent 2: Writer (LLM or Tool) ----------
def get_llm():
    """Return the shared text-generation pipeline, constructing it on first use.

    Lazy initialization keeps app boot fast; flan-t5-small (~80M params) is
    small enough to run on CPU Basic hardware.
    """
    global LLM
    if LLM is not None:
        return LLM
    LLM = pipeline("text2text-generation", model="google/flan-t5-small")
    return LLM
def writer_agent(user_msg: str, plan: dict) -> str:
    """Execute the planner's decision and produce the reply text.

    For action "CALCULATE": extract the most plausible arithmetic expression
    from the message and evaluate it with the safe calculator; if no
    expression can be found, fall back to the LLM. For any other action,
    answer directly with the LLM.
    """
    if plan["action"] == "CALCULATE":
        # Find contiguous runs of calculator characters and keep the longest.
        # The previous approach concatenated every qualifying character across
        # the whole message, fusing unrelated numbers (e.g. "I have 2 cats,
        # calculate 3+4" became "23+4" -> 27).
        runs = re.findall(r"[0-9(][0-9.\s+\-*/()%^]*", user_msg)
        expr = max(runs, key=len) if runs else ""
        # Normalize: ^ means power for users, and internal whitespace
        # ("3 + 4") is insignificant to the parser.
        expr = re.sub(r"\s+", "", expr.replace("^", "**"))
        if not expr:
            # No usable expression found — fall back to the LLM.
            llm = get_llm()
            prompt = f"Answer briefly:\n\nQuestion: {user_msg}\nAnswer:"
            out = llm(prompt, max_new_tokens=128, do_sample=True)[0]["generated_text"]
            return out.strip()
        return safe_calculate(expr)
    # ANSWER with LLM
    llm = get_llm()
    prompt = (
        "You are a concise, friendly assistant. "
        "Answer clearly in 1-4 sentences.\n\n"
        f"Question: {user_msg}\nAnswer:"
    )
    out = llm(prompt, max_new_tokens=192, do_sample=True, temperature=0.6, top_p=0.95)[0]["generated_text"]
    return out.strip()
# ---------- Gradio Chat glue ----------
def chat_fn(message, history, show_agent_trace=False):
    """Gradio chat callback: plan, generate the answer, optionally append the trace.

    `history` is supplied by gr.ChatInterface but unused — the agents are
    stateless per message. `show_agent_trace` comes from the extra checkbox.
    """
    plan = planner_agent(message)
    reply = writer_agent(message, plan)
    if not show_agent_trace:
        return reply
    trace = (
        f"\n\n---\n*Agent trace:* action = **{plan['action']}**"
        f", reason = _{plan['reason']}_"
    )
    return reply + trace
# Build the chat UI. gr.ChatInterface wires chat_fn to a message box and
# history display automatically.
demo = gr.ChatInterface(
fn=chat_fn,
# Extra widget rendered below the chat; its value is passed to chat_fn as
# the show_agent_trace argument.
additional_inputs=[
gr.Checkbox(label="Show agent trace", value=False),
],
theme="soft",
# Constrain width for readability on wide screens.
css="""
.gradio-container {max-width: 760px !important}
""",
)
# Launch only when executed directly (Spaces runs this module as a script).
if __name__ == "__main__":
    demo.launch()