GermanySutherland committed on
Commit
4c168bb
·
verified ·
1 Parent(s): 07c3beb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -40
app.py CHANGED
@@ -1,50 +1,132 @@
 
 
 
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
- from typing import List, Tuple
4
- import os
5
-
6
- # Initialize client with token (add HF_TOKEN in your Space secrets)
7
- client = InferenceClient(
8
- model="HuggingFaceH4/zephyr-7b-beta",
9
- token=os.getenv("HF_TOKEN")
10
- )
11
 
12
- def respond(
13
- message: str,
14
- history: List[Tuple[str, str]],
15
- system_message: str,
16
- max_tokens: int,
17
- temperature: float,
18
- top_p: float,
19
- ):
20
- # Build conversation history
21
- messages = [{"role": "system", "content": system_message}]
22
- for user_msg, bot_reply in history:
23
- if user_msg:
24
- messages.append({"role": "user", "content": user_msg})
25
- if bot_reply:
26
- messages.append({"role": "assistant", "content": bot_reply})
27
- messages.append({"role": "user", "content": message})
28
-
29
- # Call Hugging Face Inference API
30
- msg = client.chat_completion(
31
- model="HuggingFaceH4/zephyr-7b-beta",
32
- messages=messages,
33
- max_tokens=max_tokens,
34
- temperature=temperature,
35
- top_p=top_p,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  )
37
- return msg.choices[0].message.content
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- # Gradio Chat UI
40
  demo = gr.ChatInterface(
41
- respond,
42
  additional_inputs=[
43
- gr.Textbox(value="You are a friendly chatbot.", label="System message"),
44
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
45
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
46
- gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
47
  ],
 
 
 
 
48
  )
49
 
50
  if __name__ == "__main__":
 
1
+ import re
2
+ import ast
3
+ import operator
4
  import gradio as gr
 
 
 
 
 
 
 
 
 
5
 
6
+ # LLM: small, CPU-friendly
7
+ from transformers import pipeline
8
+
9
+ LLM = None # lazy-load to speed up app boot
10
+
11
+
12
+ # ---------- Agent 1: Planner (decides the next action) ----------
13
def planner_agent(user_msg: str) -> dict:
    """Decide which action the solver should take for *user_msg*.

    Returns a dict with:
        action: "CALCULATE" for arithmetic-looking requests, else "ANSWER".
        reason: short human-readable justification (surfaced in the agent trace).
    """
    text = user_msg.strip().lower()

    # A digit, followed by arithmetic characters, ending in a digit/paren,
    # is treated as a candidate expression.
    looks_numeric = re.search(r"[0-9][0-9\.\s\+\-\*\/\%\^\(\)]*[0-9\)]", text) is not None

    # Explicit calculation phrasing also routes to the calculator.
    calc_words = ("calc", "calculate", "evaluate", "what is", "what's")
    wants_calc = any(word in text for word in calc_words)

    # Creative/explanatory requests stay with the LLM even if they mention numbers.
    creative_words = ("story", "poem", "explain why", "compare")
    is_creative = any(word in text for word in creative_words)

    if (looks_numeric or wants_calc) and not is_creative:
        return {"action": "CALCULATE", "reason": "Looks like a numeric expression or calculation request."}
    return {"action": "ANSWER", "reason": "General question; better handled by the LLM."}
31
+
32
+
33
+ # ---------- Tiny safe calculator (tool used by Solver) ----------
34
+ # Safe AST-based evaluator supporting +,-,*,/,**,%,// and parentheses
35
# Mapping from AST operator node types to the corresponding arithmetic
# functions; anything not listed here is rejected by _eval.
OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.Pow: operator.pow,
    ast.Mod: operator.mod,
    ast.FloorDiv: operator.floordiv,
    ast.USub: operator.neg,
    ast.UAdd: operator.pos,
}


def _eval(node):
    """Recursively evaluate a whitelisted arithmetic AST node.

    Supports numeric literals, unary +/-, and the binary operators in OPS.
    Raises ValueError for any other node, so untrusted input can never
    execute arbitrary code.
    """
    # ast.Constant is checked first: on Python >= 3.8 numeric literals are
    # Constant nodes, and the old ast.Num / .n API is deprecated (emits
    # DeprecationWarning on modern interpreters).
    if isinstance(node, ast.Constant):
        if isinstance(node.value, (int, float)):
            return node.value
        raise ValueError("Only numeric constants allowed.")
    if isinstance(node, ast.Expression):
        return _eval(node.body)
    if isinstance(node, ast.BinOp):
        left = _eval(node.left)
        right = _eval(node.right)
        op_type = type(node.op)
        if op_type in OPS:
            return OPS[op_type](left, right)
        raise ValueError("Unsupported operator.")
    if isinstance(node, ast.UnaryOp) and type(node.op) in OPS:
        return OPS[type(node.op)](_eval(node.operand))
    if isinstance(node, ast.Num):  # legacy fallback for Python < 3.8 parsers
        return node.n
    raise ValueError("Unsupported expression.")


def safe_calculate(expr: str) -> str:
    """Evaluate an arithmetic expression string and return the result as text.

    Never raises: malformed or unsupported input yields an apologetic
    error message instead.
    """
    try:
        tree = ast.parse(expr, mode="eval")
        val = _eval(tree)
        return str(val)
    except Exception as e:
        return f"Sorry, I couldn't calculate that: {e}"
74
+
75
+
76
+ # ---------- Agent 2: Writer (LLM or Tool) ----------
77
def get_llm():
    """Return the shared text2text pipeline, creating it on first use.

    The model is loaded lazily so the app boots fast; flan-t5-small
    (~80M params) is small enough for a basic CPU instance.
    """
    global LLM
    if LLM is not None:
        return LLM
    LLM = pipeline("text2text-generation", model="google/flan-t5-small")
    return LLM
83
+
84
+ def writer_agent(user_msg: str, plan: dict) -> str:
85
+ if plan["action"] == "CALCULATE":
86
+ # Extract the most likely expression from the message
87
+ # Keep digits, ops, dots, spaces, and parentheses
88
+ expr = "".join(ch for ch in user_msg if ch in "0123456789.+-*/()%^ //")
89
+ # Clean up accidental double spaces
90
+ expr = re.sub(r"\s+", "", expr.replace("^", "**"))
91
+ if not expr:
92
+ # fallback to LLM if no expression found
93
+ llm = get_llm()
94
+ prompt = f"Answer briefly:\n\nQuestion: {user_msg}\nAnswer:"
95
+ out = llm(prompt, max_new_tokens=128, do_sample=True)[0]["generated_text"]
96
+ return out.strip()
97
+ return safe_calculate(expr)
98
+
99
+ # ANSWER with LLM
100
+ llm = get_llm()
101
+ prompt = (
102
+ "You are a concise, friendly assistant. "
103
+ "Answer clearly in 1-4 sentences.\n\n"
104
+ f"Question: {user_msg}\nAnswer:"
105
  )
106
+ out = llm(prompt, max_new_tokens=192, do_sample=True, temperature=0.6, top_p=0.95)[0]["generated_text"]
107
+ return out.strip()
108
+
109
+
110
+ # ---------- Gradio Chat glue ----------
111
def chat_fn(message, history, show_agent_trace=False):
    """Gradio chat callback: plan the action, write the answer, and
    optionally append a markdown trace of the planner's decision."""
    plan = planner_agent(message)
    answer = writer_agent(message, plan)

    if not show_agent_trace:
        return answer

    trace = f"\n\n---\n*Agent trace:* action = **{plan['action']}**, reason = _{plan['reason']}_"
    return answer + trace
119
+
120
 
 
121
# ---------- Gradio Chat UI ----------
# The checkbox is forwarded to chat_fn as its show_agent_trace argument.
_trace_toggle = gr.Checkbox(label="Show agent trace", value=False)

demo = gr.ChatInterface(
    fn=chat_fn,
    additional_inputs=[_trace_toggle],
    theme="soft",
    css="""
    .gradio-container {max-width: 760px !important}
    """,
)
131
 
132
  if __name__ == "__main__":