# Minimal two-agent demo (planner + writer) for a Hugging Face Space.
# NOTE(review): scraped page chrome (commit hashes, line-number gutter)
# removed — it was not valid Python and broke the module at import time.
import re
import ast
import operator
import gradio as gr
# LLM: small, CPU-friendly
from transformers import pipeline
LLM = None # lazy-load to speed up app boot
# ---------- Agent 1: Planner (decides the next action) ----------
def planner_agent(user_msg: str) -> dict:
    """
    Tiny heuristic planner: decide the next action for a user message.

    Returns a dict with:
      - "action": "CALCULATE" (use the safe calculator tool) or
                  "ANSWER"    (let the LLM answer directly)
      - "reason": short human-readable justification (shown in the trace).
    """
    text = user_msg.strip().lower()
    # An arithmetic-looking span: a digit, then digits/operators/whitespace,
    # ending in a digit or ')' (so a lone digit in prose does not match).
    math_like = bool(re.search(r"[0-9][0-9\.\s\+\-\*\/\%\^\(\)]*[0-9\)]", text))
    asks_calc = any(k in text for k in ["calc", "calculate", "evaluate", "what is", "what's"])
    # Fix: "what is ..." alone used to route every such question to the
    # calculator; only pick CALCULATE when the message actually contains a digit.
    has_digit = any(ch.isdigit() for ch in text)
    if (math_like or asks_calc) and has_digit:
        # Avoid "story" or "explain" prompts that merely mention numbers.
        if not any(k in text for k in ["story", "poem", "explain why", "compare"]):
            return {"action": "CALCULATE", "reason": "Looks like a numeric expression or calculation request."}
    return {"action": "ANSWER", "reason": "General question; better handled by the LLM."}
# ---------- Tiny safe calculator (tool used by Solver) ----------
# Safe AST-based evaluator supporting +,-,*,/,**,%,// and parentheses
OPS = {
ast.Add: operator.add,
ast.Sub: operator.sub,
ast.Mult: operator.mul,
ast.Div: operator.truediv,
ast.Pow: operator.pow,
ast.Mod: operator.mod,
ast.FloorDiv: operator.floordiv,
ast.USub: operator.neg,
ast.UAdd: operator.pos,
}
def _eval(node):
if isinstance(node, ast.Num): # py<3.8 fallback
return node.n
if isinstance(node, ast.Constant):
if isinstance(node.value, (int, float)):
return node.value
raise ValueError("Only numeric constants allowed.")
if isinstance(node, ast.BinOp):
left = _eval(node.left)
right = _eval(node.right)
op_type = type(node.op)
if op_type in OPS:
return OPS[op_type](left, right)
raise ValueError("Unsupported operator.")
if isinstance(node, ast.UnaryOp) and type(node.op) in OPS:
return OPS[type(node.op)](_eval(node.operand))
if isinstance(node, ast.Expression):
return _eval(node.body)
raise ValueError("Unsupported expression.")
def safe_calculate(expr: str) -> str:
    """
    Parse and evaluate an arithmetic expression string.

    Returns the numeric result rendered as text, or an apologetic message
    when the expression cannot be parsed or uses anything outside the
    whitelisted operators.
    """
    try:
        parsed = ast.parse(expr, mode="eval")
        result = _eval(parsed)
    except Exception as e:
        return f"Sorry, I couldn't calculate that: {e}"
    return str(result)
# ---------- Agent 2: Writer (LLM or Tool) ----------
def get_llm():
    """Return the shared generation pipeline, constructing it on first call."""
    global LLM
    if LLM is not None:
        return LLM
    # flan-t5-small is ~80M params, OK on CPU Basic
    LLM = pipeline("text2text-generation", model="google/flan-t5-small")
    return LLM
def writer_agent(user_msg: str, plan: dict) -> str:
    """
    Execute the planner's chosen action and produce the reply text.

    CALCULATE: extract an arithmetic expression from the message and evaluate
    it with the safe calculator; if no expression is found, fall back to the
    LLM. ANSWER: ask the LLM directly.
    """
    if plan["action"] == "CALCULATE":
        # Fix: the old per-character filter kept every digit in the message,
        # so "what is 2 plus 2" collapsed to "22". Instead, find one
        # contiguous arithmetic-looking span (same shape the planner matches).
        m = re.search(r"[\-0-9\(][0-9\.\s\+\-\*\/\%\^\(\)]*[0-9\)]", user_msg)
        if m:
            # Normalize: '^' is common user notation for power; strip whitespace.
            expr = re.sub(r"\s+", "", m.group(0).replace("^", "**"))
            return safe_calculate(expr)
        # No usable expression found: fall back to the LLM.
        llm = get_llm()
        prompt = f"Answer briefly:\n\nQuestion: {user_msg}\nAnswer:"
        out = llm(prompt, max_new_tokens=128, do_sample=True)[0]["generated_text"]
        return out.strip()
    # ANSWER with LLM
    llm = get_llm()
    prompt = (
        "You are a concise, friendly assistant. "
        "Answer clearly in 1-4 sentences.\n\n"
        f"Question: {user_msg}\nAnswer:"
    )
    out = llm(prompt, max_new_tokens=192, do_sample=True, temperature=0.6, top_p=0.95)[0]["generated_text"]
    return out.strip()
# ---------- Gradio Chat glue ----------
def chat_fn(message, history, show_agent_trace=False):
    """
    Gradio chat callback: plan the action, produce the answer, and
    optionally append a small markdown trace of the planner's decision.
    `history` is supplied by ChatInterface but unused here.
    """
    plan = planner_agent(message)
    answer = writer_agent(message, plan)
    if not show_agent_trace:
        return answer
    return answer + f"\n\n---\n*Agent trace:* action = **{plan['action']}**, reason = _{plan['reason']}_"
# Build the chat UI; `additional_inputs` surfaces an extra control under the
# chat box whose value Gradio forwards as chat_fn's `show_agent_trace` argument.
demo = gr.ChatInterface(
    fn=chat_fn,
    additional_inputs=[
        gr.Checkbox(label="Show agent trace", value=False),
    ],
    theme="soft",
    # Narrow the page so the chat column stays readable on wide screens.
    css="""
.gradio-container {max-width: 760px !important}
""",
)
if __name__ == "__main__":
    demo.launch()