"""Modulus Chat: a Gradio front-end that routes math requests to the PAT API.

Explicit commands (``simplify:``, ``expand:``, ``factor:``, ``check_equal:``,
``recover:``) and word-free formulas are handled directly. Other messages are
turned into a JSON plan by a Qwen parser (custom endpoint or Hugging Face
Inference) and dispatched by intent.
"""

import hashlib
import hmac
import json
import os
import re
import time

import gradio as gr
import requests

# Optional shim for websockets.asyncio.client: if only the legacy client is
# installed, bind a ``websockets``-shaped stand-in exposing it as asyncio.client.
# NOTE(review): this only binds a name in this module (sys.modules is not
# patched), so it presumably does not affect other packages — confirm it is needed.
try:
    import websockets.asyncio.client  # noqa: F401
except ImportError:
    try:
        import websockets.legacy.client as websockets_asyncio_client  # type: ignore
    except ImportError:
        pass  # neither client importable; the shim is optional
    else:
        websockets = type(
            "_W", (), {"asyncio": type("_A", (), {"client": websockets_asyncio_client})}
        )

PAT_API_URL = os.environ.get("PAT_API_URL", "").rstrip("/")
PAT_API_KEY = os.environ.get("PAT_API_KEY", "")
HMAC_SECRET = os.environ.get("HMAC_SECRET", "")

# Qwen options: either call a custom endpoint (QWEN_API_URL) or HF Inference (QWEN_MODEL_ID+HF_TOKEN)
QWEN_API_URL = os.environ.get("QWEN_API_URL", "").rstrip("/")
QWEN_MODEL_ID = os.environ.get("QWEN_MODEL_ID", "")
HF_TOKEN = os.environ.get("HF_TOKEN", "")

# Purely numeric arithmetic (optionally ending in "=") accepted by the solve intent.
_EXPR_RE = re.compile(r"^[0-9()+\-*/^= ]{1,200}$")

# Limits for locally expanded polynomials. Chat input is untrusted, so inputs
# like "(x+1)^99999999" must be rejected instead of expanded.
_MAX_POLY_DEGREE = 1000
_MAX_COEFF_BITS = 1 << 14

# Token budget for the HF Inference parser (64 truncated multi-sample plans).
_QWEN_MAX_NEW_TOKENS = 200

# Sample points sent with every check_equal request.
_CERTIFY_POINTS = ("-2", "-1", "0", "1", "2", "3")

# Command / routing patterns. "(?:\s*:\s*|\s+)" accepts "cmd: x", "cmd:x", "cmd x".
_SIMPLIFY_CMD_RE = re.compile(r"^(simplify|expand)(?:\s*:\s*|\s+)(.+)$", re.I)
_FACTOR_CMD_RE = re.compile(r"^factor(?:\s*:\s*|\s+)(.+)$", re.I)
_RECOVER_CMD_RE = re.compile(r"^recover(?:\s*:\s*|\s+)(.+)$", re.I)
# \b keeps words such as "equivalent" or "proven" from matching as commands.
_EQUAL_CMD_RE = re.compile(r"^(check_equal|equal|equiv|prove)\b\s*:?\s*(.+)$", re.I)
_EQUAL_OP_RE = re.compile(r"==|=|≟|\?=|≡")
_EQUAL_SPLIT_RE = re.compile(r"\s*(?:==|=|≟|\?=|≡)\s*")
# Two or more consecutive ASCII letters: the message reads like prose, not a formula.
_WORD_RE = re.compile(r"[A-Za-z]{2,}")
_QUOTED_RE = re.compile(r'"([^"]+)"')
_SAMPLE_RE = re.compile(r"(-?\d+(?:/\d+)?)\s*(?:->|→)\s*(-?\d+(?:/\d+)?)")
_PARTIAL_FRACTION_RES = (
    re.compile(r"partial\s+fractions?:\s*(.+)$", re.I),
    re.compile(r"break\s+this\s+rational.*?:\s*(.+)$", re.I),
)
_FRACTION_RE = re.compile(r"^([^/]+)/([^/]+)$")
_SUB_VAR_RE = re.compile(r"[a-z]")
# Typographic operators normalized before the symbolic whitelist is applied.
_SYMBOL_TRANSLATION = str.maketrans(
    {"−": "-", "–": "-", "—": "-", "×": "*", "·": "*", "⁻": "-", "÷": "/"}
)

PROMPT_SYSTEM = (
    "You are Modulus' NL parser. Return ONLY compact JSON with an 'intent' and fields.\n"
    "Supported intents: simplify, expand, eval, factor, check_equal, recover, solve, logic, code, ode, pde.\n"
    "Schemas:\n"
    "- simplify|expand|eval: {intent, expr, subs?} where expr uses + - * / ^, integers, single-letter vars; subs is {var:int}.\n"
    "- factor: {intent, poly} where poly is textual polynomial in one var.\n"
    "- check_equal: {intent, left, right}.\n"
    "- recover: {intent, samples:[{x: string, y: string}], kind:'rational'|'poly', hint?}.\n"
    "- solve: {intent, expr} for numeric arithmetic.\n"
    "- logic: {intent, expr, compare?}.\n"
    "- code: {intent, expr_a, expr_b, vars?, bound?}.\n"
    "- ode: {intent, A, x0, dt, steps}.\n"
    "- pde: {intent, kind, N}.\n"
    "Do not add commentary. If unsure, return {intent:'simplify', expr:'...'} with your best guess."
)


# --------------------------------------------------------------------------
# HTTP helpers
# --------------------------------------------------------------------------


def _sign_headers(body: bytes) -> dict:
    """Build PAT API headers; add an HMAC-SHA256 signature when HMAC_SECRET is set.

    The MAC covers the ASCII unix timestamp, a newline byte, then ``body``; it is
    sent as ``x-signature`` with the timestamp in ``x-timestamp``.
    """
    headers = {"x-api-key": PAT_API_KEY, "content-type": "application/json"}
    if HMAC_SECRET:
        ts = str(int(time.time()))
        mac = hmac.new(HMAC_SECRET.encode("utf-8"), digestmod=hashlib.sha256)
        mac.update(ts.encode("utf-8"))
        mac.update(b"\n")
        mac.update(body)
        headers.update({"x-timestamp": ts, "x-signature": mac.hexdigest()})
    return headers


def _http_json_post(url: str, body: dict) -> dict:
    """POST ``body`` as signed JSON to ``url`` and return the decoded response.

    Never raises for transport problems. Failures come back as ``{"error": ...}``:
    the server's JSON error body on HTTP errors when available, otherwise the
    exception text. A non-object JSON reply is wrapped as ``{"result": ...}`` so
    callers can always use dict access.
    """
    data = json.dumps(body).encode("utf-8")
    try:
        r = requests.post(url, headers=_sign_headers(data), data=data, timeout=30)
        r.raise_for_status()
    except requests.HTTPError as exc:
        try:
            return {"error": r.json()}
        except ValueError:
            return {"error": str(exc)}
    except requests.RequestException as exc:  # connection error, timeout, ...
        return {"error": str(exc)}
    try:
        payload = r.json()
    except ValueError:
        return {"error": f"non-JSON response from {url}"}
    return payload if isinstance(payload, dict) else {"result": payload}


def _error_text(err) -> str:
    """Render an API error payload as chat text (JSON for structured errors)."""
    return json.dumps(err) if isinstance(err, (dict, list)) else str(err)


def _result_text(resp: dict) -> str:
    """Render a /v1/symbolic-style response: its error, ``expr``, ``value`` or raw JSON.

    Uses an explicit None/"" check so a legitimate numeric ``0`` is shown, and
    always returns a string (Gradio expects text).
    """
    if "error" in resp:
        return _error_text(resp["error"])
    for key in ("expr", "value"):
        val = resp.get(key)
        if val is None or val == "":
            continue
        return val if isinstance(val, str) else json.dumps(val)
    return json.dumps(resp)


def _simplify_text(expr: str) -> str:
    """Simplify ``expr`` via /v1/symbolic and return the rendered result."""
    resp = _http_json_post(f"{PAT_API_URL}/v1/symbolic", {"mode": "simplify", "expr": expr})
    return _result_text(resp)


def _check_equal(left: str, right: str) -> str:
    """Ask the API whether ``left`` equals ``right``; return ``"equal: <bool>"``.

    If check_equal fails, fall back to simplifying ``(left)-(right)``. A result
    of ``0`` or ``0/1`` proves equality. Otherwise the original error is returned
    rather than a misleading ``equal: False``.
    """
    url = f"{PAT_API_URL}/v1/symbolic"
    resp = _http_json_post(
        url,
        {
            "mode": "check_equal",
            "expr": "",
            "left": left,
            "right": right,
            "certify_points": list(_CERTIFY_POINTS),
        },
    )
    failure = resp.get("error") or resp.get("detail")
    if not failure:
        return f"equal: {bool(resp.get('equal'))}"
    alt = _http_json_post(url, {"mode": "simplify", "expr": f"({left})-({right})"})
    if "error" not in alt:
        value = alt.get("expr") or alt.get("value")
        if str(value).strip() in ("0", "0/1"):
            return "equal: True"
    return _error_text(failure)


def _call_qwen(prompt: str) -> str:
    """Ask the configured Qwen parser for a JSON plan and return its raw text.

    Uses ``QWEN_API_URL`` (``POST /v1/parse``) when set. Otherwise it uses the
    Hugging Face Inference API when ``QWEN_MODEL_ID`` and ``HF_TOKEN`` are set.
    Returns ``""`` when no parser is configured or the request fails. This is
    best effort; the caller falls back to local heuristics.
    """
    if QWEN_API_URL:
        try:
            r = requests.post(
                f"{QWEN_API_URL}/v1/parse",
                headers={"content-type": "application/json"},
                data=json.dumps(
                    {"prompt": prompt, "system": PROMPT_SYSTEM, "format": "json"}
                ).encode("utf-8"),
                timeout=20,
            )
            r.raise_for_status()
            return json.dumps(r.json())
        except (requests.RequestException, ValueError):
            return ""
    if QWEN_MODEL_ID and HF_TOKEN:
        try:
            r = requests.post(
                f"https://api-inference.huggingface.co/models/{QWEN_MODEL_ID}",
                headers={
                    "Authorization": f"Bearer {HF_TOKEN}",
                    "content-type": "application/json",
                },
                data=json.dumps(
                    {
                        "inputs": f"System: {PROMPT_SYSTEM}\nUser: {prompt}\nAssistant:",
                        # Greedy decoding via do_sample=False: text-generation-inference
                        # rejects temperature=0.0 ("must be strictly positive").
                        "parameters": {
                            "max_new_tokens": _QWEN_MAX_NEW_TOKENS,
                            "do_sample": False,
                            "return_full_text": False,
                        },
                    }
                ).encode("utf-8"),
                timeout=30,
            )
            r.raise_for_status()
            out = r.json()
        except (requests.RequestException, ValueError):
            return ""
        text = ""
        if isinstance(out, list) and out and isinstance(out[0], dict):
            text = out[0].get("generated_text", "") or out[0].get("summary_text", "")
        return str(text) if text else json.dumps(out)
    return ""


def _parse_plan(text: str) -> dict:
    """Extract a plan dict with a non-empty string ``intent`` from parser output.

    Accepts bare JSON or JSON embedded in surrounding text (e.g. a code fence).
    Returns ``{}`` when nothing usable is found, so parser error payloads are
    not mistaken for plans and the local fallbacks still run.
    """
    if not text:
        return {}
    candidates = [text]
    start, end = text.find("{"), text.rfind("}")
    if 0 <= start < end:
        candidates.append(text[start : end + 1])
    for candidate in candidates:
        try:
            plan = json.loads(candidate)
        except ValueError:
            continue
        intent = plan.get("intent") if isinstance(plan, dict) else None
        if isinstance(intent, str) and intent.strip():
            return plan
    return {}


# --------------------------------------------------------------------------
# Text normalization
# --------------------------------------------------------------------------


def _is_ascii_digit(ch: str) -> bool:
    """True for a single ASCII digit (``str.isdigit`` also accepts e.g. '²')."""
    return "0" <= ch <= "9"


def _is_ascii_letter(ch: str) -> bool:
    """True for a single ASCII letter."""
    return "a" <= ch <= "z" or "A" <= ch <= "Z"


def _insert_implicit_mul(expr: str) -> str:
    """Make implicit multiplication explicit: ``2x`` -> ``2*x``, ``xy`` -> ``x*y``.

    Covers digit/``)`` before ``(``, digit/``)`` before a letter, letter/``)``
    before a digit, and letter/``)`` before a letter. The last rule uses a
    lookahead so runs such as ``xyz`` become ``x*y*z`` (not ``x*yz``).
    """
    e = re.sub(r"(\d|\))\(", r"\1*(", expr)
    e = re.sub(r"(\d|\))([a-zA-Z])", r"\1*\2", e)
    e = re.sub(r"([a-zA-Z]|\))(\d)", r"\1*\2", e)
    e = re.sub(r"([a-zA-Z)])(?=[a-zA-Z])", r"\1*", e)
    return e


def _norm_expr(msg) -> str:
    """Pull the expression out of ``msg`` and make multiplication explicit.

    Preference order: the first double-quoted span, the text after a
    ``simplify``/``expand`` command, then the whole (stripped) message.
    """
    text = str(msg).strip()
    quoted = _QUOTED_RE.search(text)
    if quoted:
        return _insert_implicit_mul(quoted.group(1))
    cmd = _SIMPLIFY_CMD_RE.match(text)
    if cmd:
        return _insert_implicit_mul(cmd.group(2))
    return _insert_implicit_mul(text)


def _sanitize_symbolic(s) -> str:
    """Normalize typographic operators and keep only ``[0-9a-zA-Z()+-*/^ ]``.

    Other character runs become a single space; the result is stripped and
    passed through :func:`_insert_implicit_mul`.
    """
    text = str(s).translate(_SYMBOL_TRANSLATION)
    text = re.sub(r"[^0-9a-zA-Z()+\-*/^ ]+", " ", text)
    return _insert_implicit_mul(text.strip())


# --------------------------------------------------------------------------
# Local polynomial parsing (integer coefficients, ascending degree order)
# --------------------------------------------------------------------------


def _poly_degree(p: dict) -> int:
    """Degree of a sparse {degree: coeff} polynomial (0 for the zero polynomial)."""
    return max(p, default=0)


def _poly_add(a: dict, b: dict) -> dict:
    """Sum of two sparse polynomials, dropping zero coefficients."""
    out = dict(a)
    for d, c in b.items():
        out[d] = out.get(d, 0) + c
    return {d: c for d, c in out.items() if c}


def _poly_scale(p: dict, k: int) -> dict:
    """Multiply every coefficient of ``p`` by ``k``."""
    return {d: c * k for d, c in p.items() if c * k}


def _poly_mul(a: dict, b: dict) -> dict:
    """Product of two sparse polynomials, dropping zero coefficients."""
    out: dict[int, int] = {}
    for da, ca in a.items():
        for db, cb in b.items():
            out[da + db] = out.get(da + db, 0) + ca * cb
    return {d: c for d, c in out.items() if c}


def _poly_pow(base: dict, exponent: int) -> dict | None:
    """``base ** exponent`` by squaring; None for negative or oversized results.

    The size guard bounds both the degree (``_MAX_POLY_DEGREE``) and an upper
    estimate of the coefficient bit length (``_MAX_COEFF_BITS``), so untrusted
    input cannot trigger an enormous expansion.
    """
    if exponent < 0:
        return None
    coeff_bits = max((abs(c).bit_length() for c in base.values()), default=0)
    growth = coeff_bits + len(base).bit_length()
    if (
        exponent * max(_poly_degree(base), 1) > _MAX_POLY_DEGREE
        or exponent * growth > _MAX_COEFF_BITS
    ):
        return None
    result = {0: 1}
    while exponent:
        if exponent & 1:
            result = _poly_mul(result, base)
        exponent >>= 1
        if exponent:
            base = _poly_mul(base, base)
    return result


def _poly_from_str(poly: str) -> list[str] | None:
    """Parse a polynomial in ``x`` into ascending integer coefficient strings.

    Grammar (spaces ignored, implicit multiplication allowed)::

        expr    := term (('+' | '-') term)*
        term    := factor ('*' factor)*
        factor  := ('+' | '-') factor | primary ('^' int)?
        primary := int | 'x' | '(' expr ')'

    Returns None when the text is not such a polynomial, is not fully consumed,
    or exceeds the size limits. The zero polynomial yields ``["0"]``.
    """
    text = _insert_implicit_mul(str(poly).replace(" ", ""))
    n = len(text)
    pos = 0

    def peek() -> str:
        # "" at end of input; never use `peek() in "+-"` ("" is in every string).
        return text[pos] if pos < n else ""

    def read_int() -> int | None:
        nonlocal pos
        start = pos
        if peek() in ("+", "-"):
            pos += 1
        digits_start = pos
        while pos < n and _is_ascii_digit(text[pos]):
            pos += 1
        if pos == digits_start:
            pos = start
            return None
        return int(text[start:pos])

    def parse_primary() -> dict | None:
        nonlocal pos
        ch = peek()
        if _is_ascii_digit(ch):
            value = read_int()
            return {0: value} if value else {}
        if ch == "x":
            pos += 1
            return {1: 1}
        if ch == "(":
            pos += 1
            inner = parse_expr()
            if inner is None or peek() != ")":
                return None
            pos += 1
            return inner
        return None

    def parse_factor() -> dict | None:
        nonlocal pos
        ch = peek()
        if ch in ("+", "-"):  # unary sign binds looser than '^': -x^2 == -(x^2)
            pos += 1
            inner = parse_factor()
            if inner is None:
                return None
            return inner if ch == "+" else _poly_scale(inner, -1)
        base = parse_primary()
        if base is None or peek() != "^":
            return base
        pos += 1
        exponent = read_int()
        if exponent is None:
            return None
        return _poly_pow(base, exponent)

    def parse_term() -> dict | None:
        nonlocal pos
        acc = parse_factor()
        while acc is not None and peek() == "*":
            pos += 1
            rhs = parse_factor()
            if rhs is None:
                return None
            acc = _poly_mul(acc, rhs)
            if _poly_degree(acc) > _MAX_POLY_DEGREE:
                return None
        return acc

    def parse_expr() -> dict | None:
        nonlocal pos
        acc = parse_term()
        while acc is not None and peek() in ("+", "-"):
            op = peek()
            pos += 1
            rhs = parse_term()
            if rhs is None:
                return None
            acc = _poly_add(acc, rhs if op == "+" else _poly_scale(rhs, -1))
        return acc

    try:
        result = parse_expr()
    except (RecursionError, ValueError):  # absurd nesting / over-long integer literal
        return None
    if result is None or pos != n:  # trailing input means it was not a polynomial
        return None
    return [str(result.get(d, 0)) for d in range(_poly_degree(result) + 1)]


def _try_poly_coeffs(expr: str) -> list[str] | None:
    """Parse a flat sum of monomials (``3*x^2-2*x+5``) into ascending coefficients.

    Accepts terms ``[sign][int][*]var[^int]`` or ``[sign]int`` in one ASCII
    variable, without parentheses or division. Returns None for anything else,
    including empty terms such as ``x+`` or a dangling ``2*``. Trailing zero
    high-degree coefficients are dropped; an identically zero sum yields ``["0"]``.
    """
    s = str(expr).replace(" ", "")
    if not s or any(ch in s for ch in "()/"):
        return None
    n = len(s)
    pos = 0
    var = None
    by_degree: dict[int, int] = {}
    while pos < n:
        sign = 1
        if s[pos] in "+-":
            sign = -1 if s[pos] == "-" else 1
            pos += 1
        start = pos
        while pos < n and _is_ascii_digit(s[pos]):
            pos += 1
        has_coeff = pos > start
        coeff = int(s[start:pos]) if has_coeff else 1
        has_star = pos < n and s[pos] == "*"
        if has_star:
            if not has_coeff:
                return None
            pos += 1
        degree = 0
        if pos < n and _is_ascii_letter(s[pos]):
            if var is None:
                var = s[pos]
            elif s[pos] != var:
                return None  # more than one variable
            pos += 1
            degree = 1
            if pos < n and s[pos] == "^":
                pos += 1
                start = pos
                while pos < n and _is_ascii_digit(s[pos]):
                    pos += 1
                if pos == start:
                    return None
                degree = int(s[start:pos])
                if degree > _MAX_POLY_DEGREE:
                    return None
        elif has_star or not has_coeff:
            return None  # "2*" with no variable, or an empty term
        by_degree[degree] = by_degree.get(degree, 0) + sign * coeff
        if pos < n and s[pos] not in "+-":
            return None
    nonzero = {d: c for d, c in by_degree.items() if c}
    if not nonzero:
        return ["0"]
    return [str(nonzero.get(d, 0)) for d in range(max(nonzero) + 1)]


# --------------------------------------------------------------------------
# Response formatting
# --------------------------------------------------------------------------


def _poly_to_str(coeffs) -> str:
    """Render ascending coefficients as ``c0 + c1*x + c2*x^2``; ``"0"`` if empty."""
    terms = []
    for power, coeff in enumerate(coeffs):
        text = str(coeff)
        if text == "0":
            continue
        if power == 0:
            terms.append(text)
        elif power == 1:
            terms.append(f"{text}*x")
        else:
            terms.append(f"{text}*x^{power}")
    return " + ".join(terms) if terms else "0"


def _fmt_factor(resp: dict) -> str:
    """Render a /v1/factor response as ``(x+a) * (poly)^m``; API errors as text."""
    if "error" in resp:
        return _error_text(resp["error"])
    try:
        parts = []
        for item in resp.get("factors") or []:
            coeffs = item.get("factor") or []
            mult = int(item.get("multiplicity", 1))
            if len(coeffs) == 2 and str(coeffs[1]) == "1":
                c0 = str(coeffs[0])
                if c0 == "0":
                    base = "x"
                elif c0.startswith("-"):
                    base = f"(x{c0})"
                else:
                    base = f"(x+{c0})"
            else:
                base = f"({_poly_to_str(coeffs)})"
            parts.append(base if mult == 1 else f"{base}^{mult}")
        return " * ".join(parts) if parts else "1"
    except (AttributeError, TypeError, ValueError):
        return json.dumps(resp)


def _fmt_partial(resp: dict) -> str:
    """Render a /v1/partial_fractions response as ``poly + A/(x-(root)) + ...``."""
    if "error" in resp:
        return _error_text(resp["error"])
    try:
        if not resp.get("decomposed"):
            return "Could not decompose."
        parts = []
        poly_text = _poly_to_str(resp.get("polynomial_part") or [])
        if poly_text != "0":
            parts.append(poly_text)
        for term in resp.get("terms") or []:
            numerator, root = term.get("A"), term.get("root")
            if str(numerator) == "0":
                continue
            parts.append(f"{numerator}/(x-({root}))")
        return " + ".join(parts) if parts else "0"
    except (AttributeError, TypeError, ValueError):
        return json.dumps(resp)


def _fmt_recover(resp: dict) -> str:
    """Render a /v1/recover response (``poly`` or ``rational`` expr) as text."""
    if "error" in resp:
        return _error_text(resp["error"])
    try:
        expr = resp.get("expr") or {}
        kind = expr.get("type")
        if kind == "poly":
            return _poly_to_str(expr.get("coeffs") or [])
        if kind == "rational":
            num = _poly_to_str(expr.get("num") or [])
            den = _poly_to_str(expr.get("den") or [])
            return f"({num})/({den})"
        return json.dumps(resp)
    except (AttributeError, TypeError, ValueError):
        return json.dumps(resp)


def _factor_text(expr: str) -> str:
    """Factor a flat polynomial via /v1/factor, else fall back to simplifying it."""
    coeffs = _try_poly_coeffs(expr)
    if coeffs:
        return _fmt_factor(_http_json_post(f"{PAT_API_URL}/v1/factor", {"coeffs": coeffs}))
    return _simplify_text(expr)


# --------------------------------------------------------------------------
# Plan dispatch (plans come from the Qwen parser or the numeric fallback)
# --------------------------------------------------------------------------


def _plan_str(plan: dict, key: str) -> str:
    """``plan[key]`` as text; ``""`` when missing or null (model output is untyped)."""
    value = plan.get(key)
    return "" if value is None else str(value)


def _handle_symbolic(plan: dict, intent: str) -> str:
    """simplify/expand (mode "simplify") or eval (mode "eval", optional int subs)."""
    mode = "simplify" if intent in ("simplify", "expand") else "eval"
    body = {"mode": mode, "expr": _norm_expr(_plan_str(plan, "expr"))}
    subs = plan.get("subs")
    if mode == "eval" and isinstance(subs, dict):
        # int(str(v)) raises ValueError on non-integers; solve_chat reports it.
        body["subs"] = {
            k: int(str(v)) for k, v in subs.items() if _SUB_VAR_RE.fullmatch(str(k))
        }
    return _result_text(_http_json_post(f"{PAT_API_URL}/v1/symbolic", body))


def _handle_factor(plan: dict) -> str:
    return _factor_text(_insert_implicit_mul(_plan_str(plan, "poly")))


def _handle_check_equal(plan: dict) -> str:
    return _check_equal(
        _insert_implicit_mul(_plan_str(plan, "left")),
        _insert_implicit_mul(_plan_str(plan, "right")),
    )


def _handle_recover(plan: dict) -> str:
    kind = (_plan_str(plan, "kind") or "rational").lower()
    samples = plan.get("samples") or []
    if not samples:
        return "No samples provided"
    body = {
        "mode": "poly" if kind == "poly" else "rational",
        "var": "x",
        "samples": samples,
        "certify": True,
    }
    return _fmt_recover(_http_json_post(f"{PAT_API_URL}/v1/recover", body))


def _handle_logic(plan: dict) -> str:
    resp = _http_json_post(
        f"{PAT_API_URL}/v1/logic_simplify",
        {"expr": plan.get("expr", ""), "compare": plan.get("compare")},
    )
    return json.dumps(resp)


def _handle_code(plan: dict) -> str:
    resp = _http_json_post(
        f"{PAT_API_URL}/v1/code_equiv",
        {
            "expr_a": plan.get("expr_a", ""),
            "expr_b": plan.get("expr_b", ""),
            "vars": plan.get("vars") or ["x"],
            "bound": int(plan.get("bound") or 5),
        },
    )
    return json.dumps(resp)


def _handle_ode(plan: dict) -> str:
    resp = _http_json_post(
        f"{PAT_API_URL}/v1/ode_solve",
        {
            "A": plan.get("A"),
            "x0": plan.get("x0"),
            "dt": plan.get("dt"),
            "steps": int(plan.get("steps") or 1),
        },
    )
    return json.dumps(resp)


def _handle_pde(plan: dict) -> str:
    resp = _http_json_post(
        f"{PAT_API_URL}/v1/pde_solve",
        {"kind": plan.get("kind") or "poisson_1d", "N": int(plan.get("N") or 8)},
    )
    return json.dumps(resp)


def _handle_solve(plan: dict) -> str:
    """Numeric arithmetic via /v1/solve; a trailing '=' is ignored."""
    question = _plan_str(plan, "expr").strip().rstrip("=")
    if not _EXPR_RE.fullmatch(question):
        return "Could not parse the question into a numeric expression."
    body = {"question": question, "options": {"certificates": True}}
    resp = _http_json_post(f"{PAT_API_URL}/v1/solve", body)
    if "error" in resp:
        return _error_text(resp["error"])
    answer = resp.get("answer")
    return json.dumps(resp) if answer is None else str(answer)


_INTENT_HANDLERS = {
    "factor": _handle_factor,
    "check_equal": _handle_check_equal,
    "recover": _handle_recover,
    "logic": _handle_logic,
    "code": _handle_code,
    "ode": _handle_ode,
    "pde": _handle_pde,
    "solve": _handle_solve,
}


def _dispatch(plan: dict) -> str:
    """Execute a parsed plan by its ``intent``; ``"Unknown intent"`` otherwise."""
    intent = _plan_str(plan, "intent").strip().lower()
    if intent in ("simplify", "expand", "eval"):
        return _handle_symbolic(plan, intent)
    handler = _INTENT_HANDLERS.get(intent)
    return handler(plan) if handler else "Unknown intent"


# --------------------------------------------------------------------------
# Command routing (string-first UX). Each route returns None to decline.
# --------------------------------------------------------------------------


def _route_simplify_command(msg: str) -> str | None:
    """``simplify: expr`` / ``expand: expr`` bypass NL parsing."""
    m = _SIMPLIFY_CMD_RE.match(msg)
    if m is None:
        return None
    return _simplify_text(_insert_implicit_mul(m.group(2)))


def _route_equality(msg: str) -> str | None:
    """``check_equal: A = B`` (or equal/equiv/prove), or a word-free ``A = B``.

    With only one non-empty side (e.g. ``2+3=``) that side is simplified.
    Messages with prose words are left to the NL parser.
    """
    cmd = _EQUAL_CMD_RE.match(msg)
    if cmd is None and (not _EQUAL_OP_RE.search(msg) or _WORD_RE.search(msg)):
        return None
    payload = cmd.group(2) if cmd else msg
    parts = _EQUAL_SPLIT_RE.split(payload, maxsplit=1)
    if len(parts) == 2:
        left, right = _sanitize_symbolic(parts[0]), _sanitize_symbolic(parts[1])
        if left and right:
            return _check_equal(left, right)
    expr = _sanitize_symbolic(payload)
    return _simplify_text(expr) if expr else None


def _route_factor_command(msg: str) -> str | None:
    """``factor: poly``: factor via coefficients, else simplify."""
    m = _FACTOR_CMD_RE.match(msg)
    if m is None:
        return None
    return _factor_text(_sanitize_symbolic(m.group(1)))


def _route_recover_command(msg: str) -> str | None:
    """``recover: (1->3/2), (2->4/3)``: rational recovery from x->y samples."""
    m = _RECOVER_CMD_RE.match(msg)
    if m is None:
        return None
    samples = [{"x": x, "y": y} for x, y in _SAMPLE_RE.findall(m.group(1))]
    if not samples:
        return "Could not parse samples for recover. Use like: recover: (1->3/2), (2->4/3)"
    resp = _http_json_post(
        f"{PAT_API_URL}/v1/recover",
        {"mode": "rational", "var": "x", "samples": samples, "certify": True},
    )
    return _fmt_recover(resp)


def _route_bare_formula(msg: str) -> str | None:
    """A formula with single-letter variables and no prose words is simplified."""
    if not any(ch.isalpha() for ch in msg) or _WORD_RE.search(msg):
        return None
    return _simplify_text(_norm_expr(msg))


def _route_partial_fractions(message: str) -> str | None:
    """Heuristic ``partial fractions: P/Q`` path used when no Qwen plan exists."""
    match = next(
        (m for m in (p.search(message) for p in _PARTIAL_FRACTION_RES) if m), None
    )
    if match is None:
        return None
    frac = _FRACTION_RE.match(match.group(1).strip())
    if frac is None:
        return None
    numer = _poly_from_str(frac.group(1))
    denom = _poly_from_str(frac.group(2))
    if not numer or not denom:
        return None
    resp = _http_json_post(
        f"{PAT_API_URL}/v1/partial_fractions", {"P": numer, "Q": denom}
    )
    return _fmt_partial(resp)


def solve_chat(message: str, history: list[list[str]]):
    """Answer one chat message (Gradio ``ChatInterface`` callback).

    Routing order:
    1. An explicit ``simplify``/``expand`` command.
    2. Equality: a ``check_equal:`` prefix, or a word-free ``A = B``.
    3. ``factor:``.
    4. ``recover:``.
    5. A word-free formula with variables.
    6. A Qwen plan.
    7. Without a usable plan: the partial-fractions heuristic, then a
       numeric-only ``solve`` fallback.

    ``history`` is accepted for the ChatInterface signature but not used.
    Returns the reply text.
    """
    if not PAT_API_URL or not PAT_API_KEY:
        return "API not configured"

    msg = message.strip()
    for route in (
        _route_simplify_command,
        _route_equality,
        _route_factor_command,
        _route_recover_command,
        _route_bare_formula,
    ):
        reply = route(msg)
        if reply is not None:
            return reply

    # 1) Ask Qwen for a structured plan.
    plan = _parse_plan(_call_qwen(message))

    if not plan:
        # 1a) Heuristic fallback for partial fractions when no Qwen plan is available.
        reply = _route_partial_fractions(message)
        if reply is not None:
            return reply
        # Fallback: numeric-only quick path.
        expr = message.strip()
        if _EXPR_RE.fullmatch(expr.rstrip("=")):
            plan = {"intent": "solve", "expr": expr}
    if not plan:
        return "Could not parse. Try a shorter prompt with explicit math."
    try:
        return _dispatch(plan)
    except (ValueError, TypeError) as exc:
        # Plan fields are model output, e.g. int("1/2") for subs, steps or N.
        return f"Could not interpret the parsed request: {exc}"


chat = gr.ChatInterface(
    fn=solve_chat,
    title="Modulus Chat",
    description="Qwen parses → Modulus computes (certified).",
)

if __name__ == "__main__":
    chat.launch()