# Captured from a Hugging Face Space page ("Spaces"; Space status at capture time: Sleeping).
| import os, json, time, hmac, hashlib, re, requests | |
| import gradio as gr | |
# Optional compatibility shim for ``websockets.asyncio.client``.
# On websockets releases that only ship the legacy client, alias
# ``websockets.legacy.client`` under the newer module path.  Best-effort: if
# websockets is missing or broken, nothing is changed and startup continues.
# NOTE(review): this runs after ``import gradio`` above, so it cannot affect any
# import gradio performs at its own import time -- consider moving it earlier.
import sys
import types

try:
    import websockets.asyncio.client  # noqa: F401
except Exception:
    try:
        import websockets  # type: ignore
        import websockets.legacy.client as websockets_asyncio_client  # type: ignore
    except Exception:
        websockets_asyncio_client = None  # websockets unavailable: nothing to alias
    if websockets_asyncio_client is not None:
        _ws_asyncio = types.ModuleType("websockets.asyncio")
        _ws_asyncio.client = websockets_asyncio_client
        # Rebinding a name in this module (as the previous version did) is
        # invisible to every other module; sys.modules is what a later
        # ``import websockets.asyncio.client`` elsewhere in the process consults.
        sys.modules.setdefault("websockets.asyncio", _ws_asyncio)
        sys.modules.setdefault("websockets.asyncio.client", websockets_asyncio_client)
        if not hasattr(websockets, "asyncio"):
            setattr(websockets, "asyncio", _ws_asyncio)
# Modulus ("PAT") API connection settings; an empty string means "not configured".
PAT_API_URL = os.getenv("PAT_API_URL", "").rstrip("/")
PAT_API_KEY = os.getenv("PAT_API_KEY", "")
HMAC_SECRET = os.getenv("HMAC_SECRET", "")
# Qwen parser backend: either a custom endpoint (QWEN_API_URL) or the Hugging
# Face Inference API (QWEN_MODEL_ID + HF_TOKEN).
QWEN_API_URL = os.getenv("QWEN_API_URL", "").rstrip("/")
QWEN_MODEL_ID = os.getenv("QWEN_MODEL_ID", "")
HF_TOKEN = os.getenv("HF_TOKEN", "")
# Purely numeric input: 1-200 chars of digits, parentheses, + - * / ^ = and spaces.
_EXPR_RE = re.compile(r"^[0-9()+\-*/^= ]{1,200}$")
# System prompt for the Qwen natural-language parser (one instruction per line).
PROMPT_SYSTEM = "\n".join((
    "You are Modulus' NL parser. Return ONLY compact JSON with an 'intent' and fields.",
    "Supported intents: simplify, expand, eval, factor, check_equal, recover, solve, logic, code, ode, pde.",
    "Schemas:",
    "- simplify|expand|eval: {intent, expr, subs?} where expr uses + - * / ^, integers, single-letter vars; subs is {var:int}.",
    "- factor: {intent, poly} where poly is textual polynomial in one var.",
    "- check_equal: {intent, left, right}.",
    "- recover: {intent, samples:[{x: string, y: string}], kind:'rational'|'poly', hint?}.",
    "- solve: {intent, expr} for numeric arithmetic.",
    "- logic: {intent, expr, compare?}.",
    "- code: {intent, expr_a, expr_b, vars?, bound?}.",
    "- ode: {intent, A, x0, dt, steps}.",
    "- pde: {intent, kind, N}.",
    "Do not add commentary. If unsure, return {intent:'simplify', expr:'...'} with your best guess.",
))
def _sign_headers(body: bytes) -> dict:
    """Build the request headers for a Modulus API call.

    Always carries ``x-api-key`` (``PAT_API_KEY``) and a JSON content type.
    When ``HMAC_SECRET`` is non-empty, also adds ``x-timestamp`` (current Unix
    time in whole seconds) and ``x-signature``: the hex HMAC-SHA256, keyed
    with ``HMAC_SECRET``, of the timestamp, a newline, then *body*.
    """
    result = {"x-api-key": PAT_API_KEY, "content-type": "application/json"}
    if not HMAC_SECRET:
        return result
    stamp = str(int(time.time()))
    signed_payload = stamp.encode("utf-8") + b"\n" + body
    digest = hmac.new(HMAC_SECRET.encode("utf-8"), signed_payload, hashlib.sha256).hexdigest()
    result["x-timestamp"] = stamp
    result["x-signature"] = digest
    return result
def _extract_json_object(text: str) -> str:
    """Return the first ``{...}`` JSON object embedded in *text*, re-serialised.

    Models often wrap the requested JSON in prose or a Markdown code fence, so
    scan for each ``{`` and try to decode an object starting there.  Returns
    ``""`` when no embedded object decodes.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _end = decoder.raw_decode(text, start)
        except ValueError:
            start = text.find("{", start + 1)
            continue
        return json.dumps(obj)
    return ""


def _call_qwen(prompt: str) -> str:
    """Ask the Qwen parser for a JSON plan for *prompt* and return it as text.

    Uses the custom endpoint ``{QWEN_API_URL}/v1/parse`` when configured,
    otherwise the Hugging Face Inference API (``QWEN_MODEL_ID`` + ``HF_TOKEN``).
    Best-effort by design: any network, HTTP or decoding failure -- or a reply
    that contains no JSON object -- yields ``""`` so the caller can fall back
    to its own heuristics.
    """
    if QWEN_API_URL:
        try:
            r = requests.post(
                f"{QWEN_API_URL}/v1/parse",
                headers={"content-type": "application/json"},
                data=json.dumps({"prompt": prompt, "system": PROMPT_SYSTEM, "format": "json"}).encode("utf-8"),
                timeout=20,
            )
            r.raise_for_status()
            return json.dumps(r.json())
        except Exception:  # best-effort: "" means "no plan"
            return ""
    if QWEN_MODEL_ID and HF_TOKEN:
        try:
            r = requests.post(
                f"https://api-inference.huggingface.co/models/{QWEN_MODEL_ID}",
                headers={"Authorization": f"Bearer {HF_TOKEN}", "content-type": "application/json"},
                data=json.dumps({
                    "inputs": f"System: {PROMPT_SYSTEM}\nUser: {prompt}\nAssistant:",
                    "parameters": {
                        # Greedy decoding.  temperature=0.0 is rejected by
                        # TGI-backed endpoints ("must be strictly positive"),
                        # which made every call fail; do_sample=False is the
                        # supported way to request deterministic output.
                        "do_sample": False,
                        # 64 tokens truncated multi-sample "recover" plans
                        # mid-JSON; leave room for the larger schemas.
                        "max_new_tokens": 256,
                        "return_full_text": False,
                    },
                }).encode("utf-8"),
                timeout=30,
            )
            r.raise_for_status()
            out = r.json()
        except Exception:  # best-effort: "" means "no plan"
            return ""
        if isinstance(out, list) and out:
            out = out[0]
        text = ""
        if isinstance(out, dict):
            text = out.get("generated_text") or out.get("summary_text") or ""
        # Never hand the raw API payload back as if it were a plan (a list
        # payload used to reach _dispatch and crash on .get()).
        return _extract_json_object(text) if isinstance(text, str) and text else ""
    return ""
| def solve_chat(message: str, history: list[list[str]]): | |
| if not PAT_API_URL or not PAT_API_KEY: | |
| return "API not configured" | |
| def _insert_implicit_mul(expr: str) -> str: | |
| e = expr | |
| e = re.sub(r'(\d|\))\(', r'\1*(', e) | |
| e = re.sub(r'(\d|\))([a-zA-Z])', r'\1*\2', e) | |
| e = re.sub(r'([a-zA-Z]|\))(\d)', r'\1*\2', e) | |
| e = re.sub(r'([a-zA-Z]|\))([a-zA-Z])', r'\1*\2', e) | |
| return e | |
| def _norm_expr(msg: str) -> str: | |
| s = msg.strip() | |
| try: | |
| m = re.search(r'"([^"]+)"', s) | |
| if m: | |
| return _insert_implicit_mul(m.group(1)) | |
| m = re.match(r'^(?:simplify|expand)\s*:??\s+(.+)$', s, re.I) | |
| if m: | |
| return _insert_implicit_mul(m.group(1)) | |
| except Exception: | |
| pass | |
| return _insert_implicit_mul(s) | |
| def _http_json_post(url: str, body: dict) -> dict: | |
| data = json.dumps(body).encode("utf-8") | |
| r = requests.post(url, headers=_sign_headers(data), data=data, timeout=30) | |
| try: | |
| r.raise_for_status() | |
| except requests.RequestException as e: | |
| try: | |
| return {"error": r.json()} | |
| except Exception: | |
| return {"error": str(e)} | |
| return r.json() | |
| def _sanitize_symbolic(s: str) -> str: | |
| s2 = s | |
| s2 = s2.replace('−','-').replace('–','-').replace('—','-') | |
| s2 = s2.replace('×','*').replace('·','*') | |
| s2 = s2.replace('⁻','-').replace('^','^') | |
| s2 = s2.replace('“','"').replace('”','"').replace("’","'") | |
| s2 = re.sub(r"[^0-9a-zA-Z()+\-*/^ ]+", " ", s2) | |
| return _insert_implicit_mul(s2.strip()) | |
| def _poly_from_str(poly: str) -> list[str] | None: | |
| s = _insert_implicit_mul(poly.replace(' ', '')) | |
| pos = 0; n = len(s) | |
| def peek(): return s[pos] if pos < n else '' | |
| def consume(ch=None): | |
| nonlocal pos | |
| if ch and peek() != ch: return False | |
| pos += 1; return True | |
| def poly_const(c: int): return {0: c} | |
| def poly_var(): return {1: 1} | |
| def poly_add(a: dict, b: dict): | |
| r = dict(a) | |
| for k,v in b.items(): r[k] = r.get(k,0) + v | |
| return {k:v for k,v in r.items() if v != 0} | |
| def poly_mul(a: dict, b: dict): | |
| r: dict[int,int] = {} | |
| for da,va in a.items(): | |
| for db,vb in b.items(): | |
| r[da+db] = r.get(da+db,0) + va*vb | |
| return {k:v for k,v in r.items() if v != 0} | |
| def poly_pow(a: dict, m: int): | |
| res = {0:1}; p = dict(a); mm = int(m) | |
| if mm < 0: return None | |
| while mm > 0: | |
| if mm & 1: res = poly_mul(res, p) | |
| p = poly_mul(p, p); mm >>= 1 | |
| return res | |
| def parse_number(): | |
| nonlocal pos | |
| start = pos | |
| if peek() in '+-': pos += 1 | |
| while pos < n and s[pos].isdigit(): pos += 1 | |
| if start == pos or (s[start] in '+-' and pos == start+1): return None | |
| return int(s[start:pos]) | |
| def parse_factor(): | |
| nonlocal pos | |
| if peek().isdigit() or peek() in '+-': | |
| c = parse_number(); return poly_const(int(c)) if c is not None else None | |
| if peek() == 'x': | |
| consume('x'); base = poly_var() | |
| if peek() == '^': | |
| consume('^'); exp = parse_number() | |
| if exp is None: return None | |
| return poly_pow(base, int(exp)) | |
| return base | |
| if peek() == '(': | |
| consume('('); a = parse_expr() | |
| if not consume(')'): return None | |
| if peek() == '^': | |
| consume('^'); exp = parse_number() | |
| if exp is None: return None | |
| return poly_pow(a, int(exp)) | |
| return a | |
| return None | |
| def parse_term(): | |
| v = parse_factor() | |
| while True: | |
| if peek() == '*': | |
| consume('*'); w = parse_factor() | |
| if w is None: return None | |
| v = poly_mul(v, w) | |
| else: | |
| break | |
| return v | |
| def parse_expr(): | |
| v = parse_term() | |
| while True: | |
| if peek() == '+': | |
| consume('+'); w = parse_term() | |
| if w is None: return None | |
| v = poly_add(v, w) | |
| elif peek() == '-': | |
| consume('-'); w = parse_term() | |
| if w is None: return None | |
| v = poly_add(v, poly_mul(w, {0:-1})) | |
| else: | |
| break | |
| return v | |
| poly_map = parse_expr() | |
| if poly_map is None: return None | |
| deg_max = max(poly_map.keys()) if poly_map else 0 | |
| coeffs = [str(poly_map.get(d, 0)) for d in range(0, deg_max+1)] | |
| return coeffs | |
| def _fmt_factor(resp: dict) -> str: | |
| try: | |
| factors = resp.get("factors") or [] | |
| parts = [] | |
| for item in factors: | |
| coeffs = item.get("factor") or [] | |
| m = int(item.get("multiplicity", 1)) | |
| if len(coeffs) == 2 and coeffs[1] == "1": | |
| c0 = coeffs[0] | |
| sgn = '+' if not str(c0).startswith('-') else '' | |
| base = f"(x{sgn}{c0})" if c0 != "0" else "x" | |
| else: | |
| terms = [] | |
| for i,c in enumerate(coeffs): | |
| cstr = str(c) | |
| if cstr == "0": continue | |
| if i == 0: terms.append(cstr) | |
| elif i == 1: terms.append(f"{cstr}*x") | |
| else: terms.append(f"{cstr}*x^{i}") | |
| base = "(" + " + ".join(terms) + ")" | |
| parts.append(base if m == 1 else f"{base}^{m}") | |
| return " * ".join(parts) if parts else "1" | |
| except Exception: | |
| return json.dumps(resp) | |
| def _fmt_partial(resp: dict) -> str: | |
| try: | |
| if not resp.get("decomposed"): return "Could not decompose." | |
| poly = resp.get("polynomial_part") or [] | |
| def poly_to_str(coeffs: list[str]) -> str: | |
| terms = [] | |
| for i,c in enumerate(coeffs): | |
| cstr = str(c) | |
| if cstr == "0": continue | |
| if i == 0: terms.append(cstr) | |
| elif i == 1: terms.append(f"{cstr}*x") | |
| else: terms.append(f"{cstr}*x^{i}") | |
| return " + ".join(terms) if terms else "0" | |
| parts = [] | |
| if poly: | |
| pstr = poly_to_str(poly) | |
| if pstr and pstr != "0": parts.append(pstr) | |
| for t in resp.get("terms") or []: | |
| A = t.get("A"); root = t.get("root") | |
| if str(A) == "0": continue | |
| parts.append(f"{A}/(x-({root}))") | |
| return " + ".join(parts) if parts else "0" | |
| except Exception: | |
| return json.dumps(resp) | |
| def _fmt_recover(resp: dict) -> str: | |
| try: | |
| expr = resp.get("expr") or {} | |
| if expr.get("type") == "poly": | |
| return _fmt_partial({"decomposed": True, "polynomial_part": expr.get("coeffs", []), "terms": []}) | |
| if expr.get("type") == "rational": | |
| num = _fmt_partial({"decomposed": True, "polynomial_part": expr.get("num", []), "terms": []}) | |
| den = _fmt_partial({"decomposed": True, "polynomial_part": expr.get("den", []), "terms": []}) | |
| return f"({num})/({den})" | |
| return json.dumps(resp) | |
| except Exception: | |
| return json.dumps(resp) | |
| def _dispatch(plan: dict) -> str: | |
| intent = (plan.get('intent') or '').lower() | |
| if intent in ('simplify','expand','eval'): | |
| expr = _norm_expr(plan.get('expr','')) | |
| mode = 'simplify' if intent in ('simplify','expand') else 'eval' | |
| body = {"mode": mode, "expr": expr} | |
| if mode == 'eval' and isinstance(plan.get('subs'), dict): | |
| body['subs'] = {k: int(str(v)) for k,v in plan['subs'].items() if re.fullmatch(r"[a-z]", k)} | |
| resp = _http_json_post(f"{PAT_API_URL}/v1/symbolic", body) | |
| return resp.get("expr") or resp.get("value") or json.dumps(resp) | |
| if intent == 'factor': | |
| poly = _insert_implicit_mul(plan.get('poly','')) | |
| coeffs = _try_poly_coeffs(poly) | |
| if coeffs: | |
| resp = _http_json_post(f"{PAT_API_URL}/v1/factor", {"coeffs": coeffs}) | |
| return _fmt_factor(resp) if isinstance(resp, dict) else str(resp) | |
| resp = _http_json_post(f"{PAT_API_URL}/v1/symbolic", {"mode":"simplify","expr": poly}) | |
| return resp.get("expr") or resp.get("value") or json.dumps(resp) | |
| if intent == 'check_equal': | |
| left = _insert_implicit_mul(plan.get('left','')) | |
| right = _insert_implicit_mul(plan.get('right','')) | |
| resp = _http_json_post(f"{PAT_API_URL}/v1/symbolic", {"mode":"check_equal","expr":"","left":left,"right":right,"certify_points":["-2","-1","0","1","2","3"]}) | |
| if isinstance(resp, dict) and resp.get('error') or (isinstance(resp, dict) and 'detail' in resp and resp.get('detail')): | |
| diff_expr = f"({left})-({right})" | |
| alt = _http_json_post(f"{PAT_API_URL}/v1/symbolic", {"mode":"simplify","expr": diff_expr}) | |
| if isinstance(alt, dict): | |
| val = str(alt.get('expr') or alt.get('value') or '') | |
| if val.strip() in ('0','0/1'): | |
| return "equal: True" | |
| ok = bool(resp.get("equal")) | |
| return f"equal: {ok}" | |
| if intent == 'recover': | |
| kind = (plan.get('kind') or 'rational').lower() | |
| samples = plan.get('samples') or [] | |
| if not samples: | |
| return "No samples provided" | |
| body = {"mode": 'poly' if kind=='poly' else 'rational', "var": "x", "samples": samples, "certify": True} | |
| resp = _http_json_post(f"{PAT_API_URL}/v1/recover", body) | |
| return _fmt_recover(resp) if isinstance(resp, dict) else str(resp) | |
| if intent == 'logic': | |
| resp = _http_json_post(f"{PAT_API_URL}/v1/logic_simplify", {"expr": plan.get('expr',''), "compare": plan.get('compare')}) | |
| return json.dumps(resp) | |
| if intent == 'code': | |
| resp = _http_json_post(f"{PAT_API_URL}/v1/code_equiv", { | |
| "expr_a": plan.get('expr_a',''), | |
| "expr_b": plan.get('expr_b',''), | |
| "vars": plan.get('vars') or ['x'], | |
| "bound": int(plan.get('bound') or 5), | |
| }) | |
| return json.dumps(resp) | |
| if intent == 'ode': | |
| resp = _http_json_post(f"{PAT_API_URL}/v1/ode_solve", {"A": plan.get('A'), "x0": plan.get('x0'), "dt": plan.get('dt'), "steps": int(plan.get('steps') or 1)}) | |
| return json.dumps(resp) | |
| if intent == 'pde': | |
| resp = _http_json_post(f"{PAT_API_URL}/v1/pde_solve", {"kind": plan.get('kind') or 'poisson_1d', "N": int(plan.get('N') or 8)}) | |
| return json.dumps(resp) | |
| if intent == 'solve': | |
| expr = (plan.get('expr') or '').strip() | |
| if not _EXPR_RE.fullmatch(expr.rstrip('=')): | |
| return "Could not parse the question into a numeric expression." | |
| body = {"question": expr.rstrip('='), "options": {"certificates": True}} | |
| resp = _http_json_post(f"{PAT_API_URL}/v1/solve", body) | |
| return str(resp.get('answer')) if isinstance(resp, dict) else str(resp) | |
| return "Unknown intent" | |
| def _try_poly_coeffs(expr: str) -> list[str] | None: | |
| s = expr.replace(' ', '') | |
| if any(ch in s for ch in '()/'): | |
| return None | |
| var = None; pos = 0; n = len(s) | |
| deg_to_coeff: dict[int, int] = {} | |
| while pos < n: | |
| sign = 1 | |
| if s[pos] in '+-': | |
| sign = -1 if s[pos] == '-' else 1 | |
| pos += 1 | |
| start = pos | |
| while pos < n and s[pos].isdigit(): | |
| pos += 1 | |
| coeff = int(s[start:pos] or '1') | |
| if pos < n and s[pos] == '*': | |
| pos += 1 | |
| exp = 0 | |
| if pos < n and s[pos].isalpha(): | |
| v = s[pos] | |
| if var is None: var = v | |
| elif var != v: return None | |
| pos += 1; exp = 1 | |
| if pos < n and s[pos] == '^': | |
| pos += 1; start = pos | |
| while pos < n and s[pos].isdigit(): pos += 1 | |
| if start == pos: return None | |
| exp = int(s[start:pos]) | |
| deg_to_coeff[exp] = deg_to_coeff.get(exp, 0) + sign * coeff | |
| if pos < n and s[pos] not in '+-': return None | |
| if not deg_to_coeff: return None | |
| max_deg = max(deg_to_coeff.keys()) | |
| return [str(deg_to_coeff.get(d, 0)) for d in range(0, max_deg + 1)] | |
| # Command routing (string-first UX) | |
| msg = message.strip() | |
| # Simplify/Expand: explicit commands bypass NL parsing | |
| m_se = re.match(r'^(simplify|expand)\s*:??\s+(.+)$', msg, re.I) | |
| if m_se: | |
| expr = _insert_implicit_mul(m_se.group(2)) | |
| resp = _http_json_post(f"{PAT_API_URL}/v1/symbolic", {"mode": "simplify", "expr": expr}) | |
| if "error" in resp: | |
| return json.dumps(resp["error"]) if isinstance(resp["error"], dict) else str(resp["error"]) | |
| return resp.get("expr") or resp.get("value") or json.dumps(resp) | |
| # Equality: supports "check_equal: A = B" or "A = B", or single side to simplify | |
| if re.match(r'^(check_equal|equal|equiv|prove)\s*:?', msg, re.I) or re.search(r'(==|=|≟|\?=|≡)', msg): | |
| payload = msg | |
| mcmd = re.match(r'^(check_equal|equal|equiv|prove)\s*:??\s*(.+)$', msg, re.I) | |
| if mcmd: | |
| payload = mcmd.group(2) | |
| parts = re.split(r'\s*(?:==|=|≟|\?=|≡)\s*', payload, maxsplit=1) | |
| if len(parts) == 2: | |
| left = _sanitize_symbolic(parts[0]) | |
| right = _sanitize_symbolic(parts[1]) | |
| resp = _http_json_post(f"{PAT_API_URL}/v1/symbolic", {"mode": "check_equal", "expr": "", "left": left, "right": right, "certify_points": ["-2","-1","0","1","2","3"]}) | |
| if "error" in resp: | |
| return json.dumps(resp["error"]) if isinstance(resp["error"], dict) else str(resp["error"]) | |
| return f"equal: {bool(resp.get('equal'))}" | |
| expr = _sanitize_symbolic(payload) | |
| resp = _http_json_post(f"{PAT_API_URL}/v1/symbolic", {"mode": "simplify", "expr": expr}) | |
| if "error" in resp: | |
| return json.dumps(resp["error"]) if isinstance(resp["error"], dict) else str(resp["error"]) | |
| return resp.get("expr") or resp.get("value") or json.dumps(resp) | |
| # Factor | |
| m = re.match(r'^(factor)\s*:??\s+(.+)$', msg, re.I) | |
| if m: | |
| expr = _sanitize_symbolic(m.group(2)) | |
| coeffs = _try_poly_coeffs(expr) | |
| if coeffs: | |
| resp = _http_json_post(f"{PAT_API_URL}/v1/factor", {"coeffs": coeffs}) | |
| return _fmt_factor(resp) if isinstance(resp, dict) else str(resp) | |
| resp = _http_json_post(f"{PAT_API_URL}/v1/symbolic", {"mode": "simplify", "expr": expr}) | |
| if "error" in resp: | |
| return json.dumps(resp["error"]) if isinstance(resp["error"], dict) else str(resp["error"]) | |
| return resp.get("expr") or resp.get("value") or json.dumps(resp) | |
| # Recover | |
| m = re.match(r'^recover\s*:??\s+(.+)$', msg, re.I) | |
| if m: | |
| payload = m.group(1) | |
| samples = [] | |
| for mm in re.finditer(r'[-\d]+\s*(?:->|→)\s*[-\d]+(?:/\d+)?', payload): | |
| pair = mm.group(0) | |
| lhs, rhs = re.split(r'\s*(?:->|→)\s*', pair) | |
| samples.append({"x": lhs.strip(), "y": rhs.strip()}) | |
| if samples: | |
| resp = _http_json_post(f"{PAT_API_URL}/v1/recover", {"mode": "rational", "var": "x", "samples": samples, "certify": True}) | |
| return _fmt_recover(resp) if isinstance(resp, dict) else str(resp) | |
| return "Could not parse samples for recover. Use like: recover: (1->3/2), (2->4/3)" | |
| # Variables present → prefer direct symbolic simplify unless sentence-like words appear | |
| if any(c.isalpha() for c in message): | |
| if re.search(r"[A-Za-z]{2,}", message): | |
| pass | |
| else: | |
| expr = _norm_expr(message) | |
| resp = _http_json_post(f"{PAT_API_URL}/v1/symbolic", {"mode": "simplify", "expr": expr}) | |
| if "error" in resp: | |
| return json.dumps(resp["error"]) if isinstance(resp["error"], dict) else str(resp["error"]) | |
| expr_out = resp.get("expr") or resp.get("value") or resp | |
| return expr_out if isinstance(expr_out, str) else json.dumps(expr_out) | |
| # 1) Ask Qwen for a structured plan | |
| plan_text = _call_qwen(message) | |
| try: | |
| plan = json.loads(plan_text) if plan_text else {} | |
| except Exception: | |
| plan = {} | |
| # 1a) Heuristic fallback for partial fractions when no Qwen plan is available | |
| if not plan: | |
| m_pf = re.search(r"partial\s+fractions?:\s*(.+)$", message, re.I) | |
| if not m_pf: | |
| m_pf = re.search(r"break\s+this\s+rational.*?:\s*(.+)$", message, re.I) | |
| if m_pf: | |
| frac = m_pf.group(1).strip() | |
| m2 = re.match(r"^\((.+)\)\s*/\s*\((.+)\)\s*$", frac) | |
| if not m2: | |
| m2 = re.match(r"^([^/]+)\s*/\s*([^/]+)$", frac) | |
| if m2: | |
| P_str, Q_str = m2.group(1), m2.group(2) | |
| P_coeffs = _poly_from_str(P_str) | |
| Q_coeffs = _poly_from_str(Q_str) | |
| if P_coeffs and Q_coeffs: | |
| resp = _http_json_post(f"{PAT_API_URL}/v1/partial_fractions", {"P": P_coeffs, "Q": Q_coeffs}) | |
| return _fmt_partial(resp) if isinstance(resp, dict) else str(resp) | |
| # Fallbacks: numeric-only quick path | |
| if not plan: | |
| expr = message.strip() | |
| if _EXPR_RE.fullmatch(expr.rstrip('=')): | |
| plan = {"intent": "solve", "expr": expr} | |
| if not plan: | |
| return "Could not parse. Try a shorter prompt with explicit math." | |
| return _dispatch(plan) | |
# Gradio chat front end; each user message is answered by solve_chat.
chat = gr.ChatInterface(
    solve_chat,
    title="Modulus Chat",
    description="Qwen parses → Modulus computes (certified).",
)

if __name__ == "__main__":
    # Only serve the UI when this file is executed directly.
    chat.launch()