# app.py — last commit 4ab67f5 by jackal79: "Update app.py" (verified)
import os, json, time, hmac, hashlib, re, requests
import gradio as gr
# Optional shim: websockets releases that predate ``websockets.asyncio.client``
# get the legacy client module aliased under that dotted name in sys.modules,
# so later ``import websockets.asyncio.client`` statements resolve.
# NOTE(review): gradio is imported above this shim, so only imports performed
# after this point (e.g. lazy ones inside gradio) can benefit — confirm that is
# where the dependency needs it.
try:
    import websockets.asyncio.client  # noqa: F401
except ImportError:
    try:
        import sys
        import types

        import websockets  # type: ignore
        import websockets.legacy.client as websockets_asyncio_client  # type: ignore
    except ImportError:
        # websockets missing, or too old to ship ``legacy``: nothing to alias.
        pass
    else:
        _ws_asyncio = types.ModuleType("websockets.asyncio")
        _ws_asyncio.client = websockets_asyncio_client
        # Register both levels so dotted imports and attribute access agree.
        sys.modules["websockets.asyncio"] = _ws_asyncio
        sys.modules["websockets.asyncio.client"] = websockets_asyncio_client
        websockets.asyncio = _ws_asyncio
# --- Configuration (read from the environment; "" means "not configured") ---
PAT_API_URL = os.environ.get("PAT_API_URL", "").rstrip("/")
PAT_API_KEY = os.environ.get("PAT_API_KEY", "")
# Optional shared secret; when non-empty, requests are HMAC-signed.
HMAC_SECRET = os.environ.get("HMAC_SECRET", "")
# Qwen options: either call a custom endpoint (QWEN_API_URL) or HF Inference (QWEN_MODEL_ID+HF_TOKEN)
QWEN_API_URL = os.environ.get("QWEN_API_URL", "").rstrip("/")
QWEN_MODEL_ID = os.environ.get("QWEN_MODEL_ID", "")
HF_TOKEN = os.environ.get("HF_TOKEN", "")
# Plain arithmetic accepted by the numeric "solve" path: digits, parentheses,
# + - * / ^, '=' and spaces; 1 to 200 characters.
_EXPR_RE = re.compile(r"^[0-9()+\-*/^= ]{1,200}$")
# System prompt for the NL parser. The example reply must itself be valid JSON:
# models tend to copy it verbatim, and the reply is decoded with json.loads.
PROMPT_SYSTEM = (
    "You are Modulus' NL parser. Return ONLY compact, valid JSON (double-quoted keys and strings) with an 'intent' and fields.\n"
    "Supported intents: simplify, expand, eval, factor, check_equal, recover, solve, logic, code, ode, pde.\n"
    "Schemas:\n"
    "- simplify|expand|eval: {intent, expr, subs?} where expr uses + - * / ^, integers, single-letter vars; subs is {var:int}.\n"
    "- factor: {intent, poly} where poly is textual polynomial in one var.\n"
    "- check_equal: {intent, left, right}.\n"
    '- recover: {intent, samples:[{x: string, y: string}], kind:"rational"|"poly", hint?}.\n'
    "- solve: {intent, expr} for numeric arithmetic.\n"
    "- logic: {intent, expr, compare?}.\n"
    "- code: {intent, expr_a, expr_b, vars?, bound?}.\n"
    "- ode: {intent, A, x0, dt, steps}.\n"
    "- pde: {intent, kind, N}.\n"
    'Do not add commentary. If unsure, return {"intent": "simplify", "expr": "..."} with your best guess.'
)
def _sign_headers(body: bytes) -> dict:
    """Return the JSON request headers for a PAT API call.

    ``x-api-key`` and ``content-type`` are always present. When
    ``HMAC_SECRET`` is non-empty, ``x-timestamp`` (integer Unix seconds) and
    ``x-signature`` are added. The signature is the hex HMAC-SHA256, keyed
    with the secret, of the timestamp, a newline byte, then ``body``.
    """
    headers = {"x-api-key": PAT_API_KEY, "content-type": "application/json"}
    if not HMAC_SECRET:
        return headers
    stamp = str(int(time.time()))
    signed = stamp.encode("utf-8") + b"\n" + body
    digest = hmac.new(HMAC_SECRET.encode("utf-8"), signed, hashlib.sha256).hexdigest()
    headers["x-timestamp"] = stamp
    headers["x-signature"] = digest
    return headers
def _call_qwen(prompt: str) -> str:
    """Ask the configured Qwen backend to turn ``prompt`` into a JSON plan.

    The custom endpoint (``QWEN_API_URL``) is preferred; otherwise Hugging Face
    Inference is used when both ``QWEN_MODEL_ID`` and ``HF_TOKEN`` are set.

    Returns:
        The backend's reply as text (JSON re-encoded for the custom endpoint,
        generated text for HF Inference). Returns ``""`` when no backend is
        configured or the request/decoding fails. The caller treats ``""`` as
        "no plan" and falls back to heuristics, so failures are deliberately
        not raised.
    """
    if QWEN_API_URL:
        try:
            r = requests.post(
                f"{QWEN_API_URL}/v1/parse",
                headers={"content-type": "application/json"},
                data=json.dumps({"prompt": prompt, "system": PROMPT_SYSTEM, "format": "json"}).encode("utf-8"),
                timeout=20,
            )
            r.raise_for_status()
            return json.dumps(r.json())
        except (requests.RequestException, ValueError):
            # Network/HTTP failure or a non-JSON body.
            return ""

    if not (QWEN_MODEL_ID and HF_TOKEN):
        return ""

    payload = {
        "inputs": f"System: {PROMPT_SYSTEM}\nUser: {prompt}\nAssistant:",
        # Greedy decoding is requested with do_sample=False. Text-generation-
        # inference validation rejects temperature=0.0 as not strictly positive.
        # 64 new tokens could truncate plans that carry sample lists.
        "parameters": {"max_new_tokens": 256, "do_sample": False, "return_full_text": False},
    }
    try:
        r = requests.post(
            f"https://api-inference.huggingface.co/models/{QWEN_MODEL_ID}",
            headers={"Authorization": f"Bearer {HF_TOKEN}", "content-type": "application/json"},
            data=json.dumps(payload).encode("utf-8"),
            timeout=30,
        )
        r.raise_for_status()
        out = r.json()
    except (requests.RequestException, ValueError):
        return ""

    text = ""
    if isinstance(out, list) and out and isinstance(out[0], dict):
        text = out[0].get("generated_text", "") or out[0].get("summary_text", "")
    # Unknown response shape: hand back the raw JSON and let the caller decide.
    return text or json.dumps(out)
def solve_chat(message: str, history: list[list[str]]):
if not PAT_API_URL or not PAT_API_KEY:
return "API not configured"
def _insert_implicit_mul(expr: str) -> str:
e = expr
e = re.sub(r'(\d|\))\(', r'\1*(', e)
e = re.sub(r'(\d|\))([a-zA-Z])', r'\1*\2', e)
e = re.sub(r'([a-zA-Z]|\))(\d)', r'\1*\2', e)
e = re.sub(r'([a-zA-Z]|\))([a-zA-Z])', r'\1*\2', e)
return e
def _norm_expr(msg: str) -> str:
s = msg.strip()
try:
m = re.search(r'"([^"]+)"', s)
if m:
return _insert_implicit_mul(m.group(1))
m = re.match(r'^(?:simplify|expand)\s*:??\s+(.+)$', s, re.I)
if m:
return _insert_implicit_mul(m.group(1))
except Exception:
pass
return _insert_implicit_mul(s)
def _http_json_post(url: str, body: dict) -> dict:
data = json.dumps(body).encode("utf-8")
r = requests.post(url, headers=_sign_headers(data), data=data, timeout=30)
try:
r.raise_for_status()
except requests.RequestException as e:
try:
return {"error": r.json()}
except Exception:
return {"error": str(e)}
return r.json()
def _sanitize_symbolic(s: str) -> str:
s2 = s
s2 = s2.replace('−','-').replace('–','-').replace('—','-')
s2 = s2.replace('×','*').replace('·','*')
s2 = s2.replace('⁻','-').replace('^','^')
s2 = s2.replace('“','"').replace('”','"').replace("’","'")
s2 = re.sub(r"[^0-9a-zA-Z()+\-*/^ ]+", " ", s2)
return _insert_implicit_mul(s2.strip())
def _poly_from_str(poly: str) -> list[str] | None:
s = _insert_implicit_mul(poly.replace(' ', ''))
pos = 0; n = len(s)
def peek(): return s[pos] if pos < n else ''
def consume(ch=None):
nonlocal pos
if ch and peek() != ch: return False
pos += 1; return True
def poly_const(c: int): return {0: c}
def poly_var(): return {1: 1}
def poly_add(a: dict, b: dict):
r = dict(a)
for k,v in b.items(): r[k] = r.get(k,0) + v
return {k:v for k,v in r.items() if v != 0}
def poly_mul(a: dict, b: dict):
r: dict[int,int] = {}
for da,va in a.items():
for db,vb in b.items():
r[da+db] = r.get(da+db,0) + va*vb
return {k:v for k,v in r.items() if v != 0}
def poly_pow(a: dict, m: int):
res = {0:1}; p = dict(a); mm = int(m)
if mm < 0: return None
while mm > 0:
if mm & 1: res = poly_mul(res, p)
p = poly_mul(p, p); mm >>= 1
return res
def parse_number():
nonlocal pos
start = pos
if peek() in '+-': pos += 1
while pos < n and s[pos].isdigit(): pos += 1
if start == pos or (s[start] in '+-' and pos == start+1): return None
return int(s[start:pos])
def parse_factor():
nonlocal pos
if peek().isdigit() or peek() in '+-':
c = parse_number(); return poly_const(int(c)) if c is not None else None
if peek() == 'x':
consume('x'); base = poly_var()
if peek() == '^':
consume('^'); exp = parse_number()
if exp is None: return None
return poly_pow(base, int(exp))
return base
if peek() == '(':
consume('('); a = parse_expr()
if not consume(')'): return None
if peek() == '^':
consume('^'); exp = parse_number()
if exp is None: return None
return poly_pow(a, int(exp))
return a
return None
def parse_term():
v = parse_factor()
while True:
if peek() == '*':
consume('*'); w = parse_factor()
if w is None: return None
v = poly_mul(v, w)
else:
break
return v
def parse_expr():
v = parse_term()
while True:
if peek() == '+':
consume('+'); w = parse_term()
if w is None: return None
v = poly_add(v, w)
elif peek() == '-':
consume('-'); w = parse_term()
if w is None: return None
v = poly_add(v, poly_mul(w, {0:-1}))
else:
break
return v
poly_map = parse_expr()
if poly_map is None: return None
deg_max = max(poly_map.keys()) if poly_map else 0
coeffs = [str(poly_map.get(d, 0)) for d in range(0, deg_max+1)]
return coeffs
def _fmt_factor(resp: dict) -> str:
try:
factors = resp.get("factors") or []
parts = []
for item in factors:
coeffs = item.get("factor") or []
m = int(item.get("multiplicity", 1))
if len(coeffs) == 2 and coeffs[1] == "1":
c0 = coeffs[0]
sgn = '+' if not str(c0).startswith('-') else ''
base = f"(x{sgn}{c0})" if c0 != "0" else "x"
else:
terms = []
for i,c in enumerate(coeffs):
cstr = str(c)
if cstr == "0": continue
if i == 0: terms.append(cstr)
elif i == 1: terms.append(f"{cstr}*x")
else: terms.append(f"{cstr}*x^{i}")
base = "(" + " + ".join(terms) + ")"
parts.append(base if m == 1 else f"{base}^{m}")
return " * ".join(parts) if parts else "1"
except Exception:
return json.dumps(resp)
def _fmt_partial(resp: dict) -> str:
try:
if not resp.get("decomposed"): return "Could not decompose."
poly = resp.get("polynomial_part") or []
def poly_to_str(coeffs: list[str]) -> str:
terms = []
for i,c in enumerate(coeffs):
cstr = str(c)
if cstr == "0": continue
if i == 0: terms.append(cstr)
elif i == 1: terms.append(f"{cstr}*x")
else: terms.append(f"{cstr}*x^{i}")
return " + ".join(terms) if terms else "0"
parts = []
if poly:
pstr = poly_to_str(poly)
if pstr and pstr != "0": parts.append(pstr)
for t in resp.get("terms") or []:
A = t.get("A"); root = t.get("root")
if str(A) == "0": continue
parts.append(f"{A}/(x-({root}))")
return " + ".join(parts) if parts else "0"
except Exception:
return json.dumps(resp)
def _fmt_recover(resp: dict) -> str:
try:
expr = resp.get("expr") or {}
if expr.get("type") == "poly":
return _fmt_partial({"decomposed": True, "polynomial_part": expr.get("coeffs", []), "terms": []})
if expr.get("type") == "rational":
num = _fmt_partial({"decomposed": True, "polynomial_part": expr.get("num", []), "terms": []})
den = _fmt_partial({"decomposed": True, "polynomial_part": expr.get("den", []), "terms": []})
return f"({num})/({den})"
return json.dumps(resp)
except Exception:
return json.dumps(resp)
def _dispatch(plan: dict) -> str:
intent = (plan.get('intent') or '').lower()
if intent in ('simplify','expand','eval'):
expr = _norm_expr(plan.get('expr',''))
mode = 'simplify' if intent in ('simplify','expand') else 'eval'
body = {"mode": mode, "expr": expr}
if mode == 'eval' and isinstance(plan.get('subs'), dict):
body['subs'] = {k: int(str(v)) for k,v in plan['subs'].items() if re.fullmatch(r"[a-z]", k)}
resp = _http_json_post(f"{PAT_API_URL}/v1/symbolic", body)
return resp.get("expr") or resp.get("value") or json.dumps(resp)
if intent == 'factor':
poly = _insert_implicit_mul(plan.get('poly',''))
coeffs = _try_poly_coeffs(poly)
if coeffs:
resp = _http_json_post(f"{PAT_API_URL}/v1/factor", {"coeffs": coeffs})
return _fmt_factor(resp) if isinstance(resp, dict) else str(resp)
resp = _http_json_post(f"{PAT_API_URL}/v1/symbolic", {"mode":"simplify","expr": poly})
return resp.get("expr") or resp.get("value") or json.dumps(resp)
if intent == 'check_equal':
left = _insert_implicit_mul(plan.get('left',''))
right = _insert_implicit_mul(plan.get('right',''))
resp = _http_json_post(f"{PAT_API_URL}/v1/symbolic", {"mode":"check_equal","expr":"","left":left,"right":right,"certify_points":["-2","-1","0","1","2","3"]})
if isinstance(resp, dict) and resp.get('error') or (isinstance(resp, dict) and 'detail' in resp and resp.get('detail')):
diff_expr = f"({left})-({right})"
alt = _http_json_post(f"{PAT_API_URL}/v1/symbolic", {"mode":"simplify","expr": diff_expr})
if isinstance(alt, dict):
val = str(alt.get('expr') or alt.get('value') or '')
if val.strip() in ('0','0/1'):
return "equal: True"
ok = bool(resp.get("equal"))
return f"equal: {ok}"
if intent == 'recover':
kind = (plan.get('kind') or 'rational').lower()
samples = plan.get('samples') or []
if not samples:
return "No samples provided"
body = {"mode": 'poly' if kind=='poly' else 'rational', "var": "x", "samples": samples, "certify": True}
resp = _http_json_post(f"{PAT_API_URL}/v1/recover", body)
return _fmt_recover(resp) if isinstance(resp, dict) else str(resp)
if intent == 'logic':
resp = _http_json_post(f"{PAT_API_URL}/v1/logic_simplify", {"expr": plan.get('expr',''), "compare": plan.get('compare')})
return json.dumps(resp)
if intent == 'code':
resp = _http_json_post(f"{PAT_API_URL}/v1/code_equiv", {
"expr_a": plan.get('expr_a',''),
"expr_b": plan.get('expr_b',''),
"vars": plan.get('vars') or ['x'],
"bound": int(plan.get('bound') or 5),
})
return json.dumps(resp)
if intent == 'ode':
resp = _http_json_post(f"{PAT_API_URL}/v1/ode_solve", {"A": plan.get('A'), "x0": plan.get('x0'), "dt": plan.get('dt'), "steps": int(plan.get('steps') or 1)})
return json.dumps(resp)
if intent == 'pde':
resp = _http_json_post(f"{PAT_API_URL}/v1/pde_solve", {"kind": plan.get('kind') or 'poisson_1d', "N": int(plan.get('N') or 8)})
return json.dumps(resp)
if intent == 'solve':
expr = (plan.get('expr') or '').strip()
if not _EXPR_RE.fullmatch(expr.rstrip('=')):
return "Could not parse the question into a numeric expression."
body = {"question": expr.rstrip('='), "options": {"certificates": True}}
resp = _http_json_post(f"{PAT_API_URL}/v1/solve", body)
return str(resp.get('answer')) if isinstance(resp, dict) else str(resp)
return "Unknown intent"
def _try_poly_coeffs(expr: str) -> list[str] | None:
s = expr.replace(' ', '')
if any(ch in s for ch in '()/'):
return None
var = None; pos = 0; n = len(s)
deg_to_coeff: dict[int, int] = {}
while pos < n:
sign = 1
if s[pos] in '+-':
sign = -1 if s[pos] == '-' else 1
pos += 1
start = pos
while pos < n and s[pos].isdigit():
pos += 1
coeff = int(s[start:pos] or '1')
if pos < n and s[pos] == '*':
pos += 1
exp = 0
if pos < n and s[pos].isalpha():
v = s[pos]
if var is None: var = v
elif var != v: return None
pos += 1; exp = 1
if pos < n and s[pos] == '^':
pos += 1; start = pos
while pos < n and s[pos].isdigit(): pos += 1
if start == pos: return None
exp = int(s[start:pos])
deg_to_coeff[exp] = deg_to_coeff.get(exp, 0) + sign * coeff
if pos < n and s[pos] not in '+-': return None
if not deg_to_coeff: return None
max_deg = max(deg_to_coeff.keys())
return [str(deg_to_coeff.get(d, 0)) for d in range(0, max_deg + 1)]
# Command routing (string-first UX)
msg = message.strip()
# Simplify/Expand: explicit commands bypass NL parsing
m_se = re.match(r'^(simplify|expand)\s*:??\s+(.+)$', msg, re.I)
if m_se:
expr = _insert_implicit_mul(m_se.group(2))
resp = _http_json_post(f"{PAT_API_URL}/v1/symbolic", {"mode": "simplify", "expr": expr})
if "error" in resp:
return json.dumps(resp["error"]) if isinstance(resp["error"], dict) else str(resp["error"])
return resp.get("expr") or resp.get("value") or json.dumps(resp)
# Equality: supports "check_equal: A = B" or "A = B", or single side to simplify
if re.match(r'^(check_equal|equal|equiv|prove)\s*:?', msg, re.I) or re.search(r'(==|=|≟|\?=|≡)', msg):
payload = msg
mcmd = re.match(r'^(check_equal|equal|equiv|prove)\s*:??\s*(.+)$', msg, re.I)
if mcmd:
payload = mcmd.group(2)
parts = re.split(r'\s*(?:==|=|≟|\?=|≡)\s*', payload, maxsplit=1)
if len(parts) == 2:
left = _sanitize_symbolic(parts[0])
right = _sanitize_symbolic(parts[1])
resp = _http_json_post(f"{PAT_API_URL}/v1/symbolic", {"mode": "check_equal", "expr": "", "left": left, "right": right, "certify_points": ["-2","-1","0","1","2","3"]})
if "error" in resp:
return json.dumps(resp["error"]) if isinstance(resp["error"], dict) else str(resp["error"])
return f"equal: {bool(resp.get('equal'))}"
expr = _sanitize_symbolic(payload)
resp = _http_json_post(f"{PAT_API_URL}/v1/symbolic", {"mode": "simplify", "expr": expr})
if "error" in resp:
return json.dumps(resp["error"]) if isinstance(resp["error"], dict) else str(resp["error"])
return resp.get("expr") or resp.get("value") or json.dumps(resp)
# Factor
m = re.match(r'^(factor)\s*:??\s+(.+)$', msg, re.I)
if m:
expr = _sanitize_symbolic(m.group(2))
coeffs = _try_poly_coeffs(expr)
if coeffs:
resp = _http_json_post(f"{PAT_API_URL}/v1/factor", {"coeffs": coeffs})
return _fmt_factor(resp) if isinstance(resp, dict) else str(resp)
resp = _http_json_post(f"{PAT_API_URL}/v1/symbolic", {"mode": "simplify", "expr": expr})
if "error" in resp:
return json.dumps(resp["error"]) if isinstance(resp["error"], dict) else str(resp["error"])
return resp.get("expr") or resp.get("value") or json.dumps(resp)
# Recover
m = re.match(r'^recover\s*:??\s+(.+)$', msg, re.I)
if m:
payload = m.group(1)
samples = []
for mm in re.finditer(r'[-\d]+\s*(?:->|→)\s*[-\d]+(?:/\d+)?', payload):
pair = mm.group(0)
lhs, rhs = re.split(r'\s*(?:->|→)\s*', pair)
samples.append({"x": lhs.strip(), "y": rhs.strip()})
if samples:
resp = _http_json_post(f"{PAT_API_URL}/v1/recover", {"mode": "rational", "var": "x", "samples": samples, "certify": True})
return _fmt_recover(resp) if isinstance(resp, dict) else str(resp)
return "Could not parse samples for recover. Use like: recover: (1->3/2), (2->4/3)"
# Variables present → prefer direct symbolic simplify unless sentence-like words appear
if any(c.isalpha() for c in message):
if re.search(r"[A-Za-z]{2,}", message):
pass
else:
expr = _norm_expr(message)
resp = _http_json_post(f"{PAT_API_URL}/v1/symbolic", {"mode": "simplify", "expr": expr})
if "error" in resp:
return json.dumps(resp["error"]) if isinstance(resp["error"], dict) else str(resp["error"])
expr_out = resp.get("expr") or resp.get("value") or resp
return expr_out if isinstance(expr_out, str) else json.dumps(expr_out)
# 1) Ask Qwen for a structured plan
plan_text = _call_qwen(message)
try:
plan = json.loads(plan_text) if plan_text else {}
except Exception:
plan = {}
# 1a) Heuristic fallback for partial fractions when no Qwen plan is available
if not plan:
m_pf = re.search(r"partial\s+fractions?:\s*(.+)$", message, re.I)
if not m_pf:
m_pf = re.search(r"break\s+this\s+rational.*?:\s*(.+)$", message, re.I)
if m_pf:
frac = m_pf.group(1).strip()
m2 = re.match(r"^\((.+)\)\s*/\s*\((.+)\)\s*$", frac)
if not m2:
m2 = re.match(r"^([^/]+)\s*/\s*([^/]+)$", frac)
if m2:
P_str, Q_str = m2.group(1), m2.group(2)
P_coeffs = _poly_from_str(P_str)
Q_coeffs = _poly_from_str(Q_str)
if P_coeffs and Q_coeffs:
resp = _http_json_post(f"{PAT_API_URL}/v1/partial_fractions", {"P": P_coeffs, "Q": Q_coeffs})
return _fmt_partial(resp) if isinstance(resp, dict) else str(resp)
# Fallbacks: numeric-only quick path
if not plan:
expr = message.strip()
if _EXPR_RE.fullmatch(expr.rstrip('=')):
plan = {"intent": "solve", "expr": expr}
if not plan:
return "Could not parse. Try a shorter prompt with explicit math."
return _dispatch(plan)
# Chat UI: every message typed into the Gradio chat is answered by solve_chat.
chat = gr.ChatInterface(
    fn=solve_chat,
    description="Qwen parses → Modulus computes (certified).",
    title="Modulus Chat",
)

if __name__ == "__main__":
    # Launch the app only when this file is executed directly.
    chat.launch()