# Hugging Face Space: minimal two-agent chat demo (heuristic planner + calculator/LLM writer)
| import re | |
| import ast | |
| import operator | |
| import gradio as gr | |
| # LLM: small, CPU-friendly | |
| from transformers import pipeline | |
LLM = None  # lazy-load to speed up app boot; populated by get_llm() on first use
# ---------- Agent 1: Planner (decides the next action) ----------
def planner_agent(user_msg: str) -> dict:
    """Decide how to handle *user_msg*.

    Very small heuristic planner:
    - If the message contains an actual numeric expression => CALCULATE
    - Else => ANSWER (LLM)

    Returns:
        dict with keys:
            action: "CALCULATE" or "ANSWER"
            reason: short human-readable justification (shown in the trace)
    """
    text = user_msg.strip().lower()
    # A digit, followed by arithmetic characters, ending in a digit or ')'
    # — i.e. something the calculator could actually evaluate.
    math_like = bool(re.search(r"[0-9][0-9.\s+\-*/%^()]*[0-9)]", text))
    # Fix: calculation *keywords* alone ("what is the capital of France",
    # "calculate 2 plus 2") are not enough — without a literal expression the
    # calculator can only mangle the message (e.g. "2 plus 2" -> "22"), so
    # those go to the LLM instead.
    if math_like:
        # Avoid "story" or "explain" with numbers
        if not any(k in text for k in ["story", "poem", "explain why", "compare"]):
            return {"action": "CALCULATE", "reason": "Looks like a numeric expression or calculation request."}
    return {"action": "ANSWER", "reason": "General question; better handled by the LLM."}
# ---------- Tiny safe calculator (tool used by Solver) ----------
# Safe AST-based evaluator supporting +,-,*,/,**,%,// and parentheses.
# Maps whitelisted AST operator node types to the corresponding functions;
# anything outside this table is rejected, which is what makes the
# evaluator safe against arbitrary code execution (no names, no calls).
OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.Pow: operator.pow,
    ast.Mod: operator.mod,
    ast.FloorDiv: operator.floordiv,
    ast.USub: operator.neg,
    ast.UAdd: operator.pos,
}


def _eval(node):
    """Recursively evaluate a whitelisted arithmetic AST *node*.

    Raises:
        ValueError: for any construct outside numeric literals and the
            operators in OPS (names, calls, strings, ...).
    """
    # NOTE: the old `ast.Num` branch was removed — it has been deprecated
    # since Python 3.8 and `ast.parse` emits `ast.Constant` nodes instead,
    # so the branch was dead code on every supported interpreter.
    if isinstance(node, ast.Constant):
        if isinstance(node.value, (int, float)):
            return node.value
        raise ValueError("Only numeric constants allowed.")
    if isinstance(node, ast.BinOp):
        op = OPS.get(type(node.op))
        if op is None:
            raise ValueError("Unsupported operator.")
        return op(_eval(node.left), _eval(node.right))
    if isinstance(node, ast.UnaryOp) and type(node.op) in OPS:
        return OPS[type(node.op)](_eval(node.operand))
    if isinstance(node, ast.Expression):
        return _eval(node.body)
    raise ValueError("Unsupported expression.")


def safe_calculate(expr: str) -> str:
    """Evaluate the arithmetic expression *expr* and return it as a string.

    Any failure (syntax error, division by zero, disallowed construct)
    is reported as a friendly message rather than raised, so the chat
    loop never crashes on bad input.
    """
    try:
        tree = ast.parse(expr, mode="eval")
        val = _eval(tree)
        return str(val)
    except Exception as e:
        return f"Sorry, I couldn't calculate that: {e}"
# ---------- Agent 2: Writer (LLM or Tool) ----------
def get_llm():
    """Return the shared text-generation pipeline, creating it on first use.

    Lazy instantiation keeps app startup fast; the pipeline is cached in
    the module-level LLM global after the first call.
    """
    global LLM
    if LLM is not None:
        return LLM
    # flan-t5-small is ~80M params, OK on CPU Basic
    LLM = pipeline("text2text-generation", model="google/flan-t5-small")
    return LLM
| def writer_agent(user_msg: str, plan: dict) -> str: | |
| if plan["action"] == "CALCULATE": | |
| # Extract the most likely expression from the message | |
| # Keep digits, ops, dots, spaces, and parentheses | |
| expr = "".join(ch for ch in user_msg if ch in "0123456789.+-*/()%^ //") | |
| # Clean up accidental double spaces | |
| expr = re.sub(r"\s+", "", expr.replace("^", "**")) | |
| if not expr: | |
| # fallback to LLM if no expression found | |
| llm = get_llm() | |
| prompt = f"Answer briefly:\n\nQuestion: {user_msg}\nAnswer:" | |
| out = llm(prompt, max_new_tokens=128, do_sample=True)[0]["generated_text"] | |
| return out.strip() | |
| return safe_calculate(expr) | |
| # ANSWER with LLM | |
| llm = get_llm() | |
| prompt = ( | |
| "You are a concise, friendly assistant. " | |
| "Answer clearly in 1-4 sentences.\n\n" | |
| f"Question: {user_msg}\nAnswer:" | |
| ) | |
| out = llm(prompt, max_new_tokens=192, do_sample=True, temperature=0.6, top_p=0.95)[0]["generated_text"] | |
| return out.strip() | |
# ---------- Gradio Chat glue ----------
def chat_fn(message, history, show_agent_trace=False):
    """Gradio chat callback: plan the action, produce the answer, and
    optionally append the planner's decision as a trace footer."""
    plan = planner_agent(message)
    reply = writer_agent(message, plan)
    if not show_agent_trace:
        return reply
    trace = (
        f"\n\n---\n*Agent trace:* action = **{plan['action']}**, "
        f"reason = _{plan['reason']}_"
    )
    return reply + trace
# Chat UI: a single chat box plus an opt-in checkbox that surfaces the
# planner's decision beneath each answer (passed to chat_fn as the third arg).
demo = gr.ChatInterface(
    fn=chat_fn,
    additional_inputs=[
        gr.Checkbox(label="Show agent trace", value=False),
    ],
    theme="soft",
    # Narrow the page for readability on wide screens.
    css="""
.gradio-container {max-width: 760px !important}
""",
)

if __name__ == "__main__":
    demo.launch()