|
|
| import os, re, time |
| import gradio as gr |
| import sympy as sp |
|
|
# Heading text; used for the Blocks title and the Markdown header of the UI.
TITLE = "LanguageBridge — Math Hybrid (Phi + SymPy)"

# Model identifier and generation budget for the optional LLM rewrite step.
# Both can be overridden through environment variables of the same name
# (MODEL_ID / MAX_NEW_TOKENS).
DEFAULT_MODEL_ID = os.getenv("MODEL_ID", "microsoft/phi-2")
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "96"))

# Model and tokenizer singletons; they stay None until _try_load_llm()
# manages to load them.
_llm = _tok = None
|
|
def _try_load_llm():
    """Load the tokenizer and causal LM on first use and cache them globally.

    The heavy dependencies are imported lazily so the app can still start,
    and solve plain SymPy input, when they are not installed.

    Returns:
        bool: True when ``_llm`` and ``_tok`` are ready to use, False when
        the dependencies or the model could not be loaded (both globals are
        then reset to None).
    """
    global _llm, _tok
    if _llm is not None:
        return True
    try:
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL_ID, use_fast=True)
        model = AutoModelForCausalLM.from_pretrained(
            DEFAULT_MODEL_ID,
            torch_dtype=torch.float32,
            device_map="auto",
        )
    except Exception:
        # Loading is best-effort: callers fall back to the raw text. Log the
        # cause so a missing package or failed download can be diagnosed
        # instead of disappearing silently.
        import logging

        logging.getLogger(__name__).warning(
            "Could not load LLM %r; continuing without it.",
            DEFAULT_MODEL_ID,
            exc_info=True,
        )
        _llm = None
        _tok = None
        return False
    # Publish the pair only after both loads succeeded, so the globals are
    # never left half-initialised.
    _tok = tokenizer
    _llm = model
    return True
|
|
def _llm_parse_to_math(text):
    """Ask the LLM to rewrite a math question as SymPy-ready input.

    Args:
        text: The user's question or expression.

    Returns:
        str: The model's rewrite with ``^`` replaced by ``**``, or ``text``
        unchanged when the model is unavailable or produced no output.
    """
    if not _try_load_llm():
        return text
    prompt = (
        "Rewrite the following math question as a pure algebraic expression "
        "or semicolon-separated equations suitable for SymPy. "
        "Return only the expression/equations without explanations.\n\n"
        f"Question:\n{text}\n\nExpression:\n"
    )
    encoded = _tok(prompt, return_tensors="pt").to(_llm.device)
    input_ids = encoded["input_ids"]
    generated = _llm.generate(
        input_ids=input_ids,
        # Pass the mask explicitly; pad and EOS share an id here, so it
        # cannot be inferred reliably from the ids alone.
        attention_mask=encoded.get("attention_mask"),
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
        pad_token_id=_tok.eos_token_id,
    )
    # The returned sequence starts with the prompt tokens. Decode only the
    # continuation rather than searching the full text for "Expression:",
    # which breaks when the user's question itself contains that label.
    continuation = generated[0][input_ids.shape[-1]:]
    answer = _tok.decode(continuation, skip_special_tokens=True)
    answer = answer.replace("^", "**").strip()
    return answer or text
|
|
def _solve_with_sympy(q):
    """Solve equations, or analyse a single expression, with SymPy.

    Input containing ``=`` is treated as one or more equations separated by
    semicolons or newlines and solved for every free symbol; a segment
    without ``=`` is taken as ``expr = 0``. Any other input is parsed as one
    expression and reported simplified, factored, differentiated and
    integrated.

    Args:
        q: Expression or equation text; ``None`` is treated as empty.

    Returns:
        str: A human-readable result, or a message explaining why the input
        could not be processed.
    """
    q = (q or "").strip()
    if not q:
        return "Please enter an expression or equations. Example: 2*x + 5 = 11; or: sin(x)**2 + cos(x)**2"

    # NOTE(security): sp.sympify() evaluates its input with eval(); SymPy's
    # documentation warns against passing unsanitized input. Review before
    # exposing this app to untrusted users.
    if "=" in q:
        return _solve_equation_system(q)
    return _describe_expression(q)


def _solve_equation_system(q):
    """Parse ``;``/newline-separated equations and solve them for all symbols."""
    equations = []
    symbols = set()
    try:
        for segment in re.split(r"[;\n]", q):
            segment = segment.strip()
            if not segment:
                continue
            # Accept Python-style "==" as plain equality.
            lhs_text, sep, rhs_text = segment.replace("==", "=").partition("=")
            lhs = sp.sympify(lhs_text)
            # A segment without "=" means "expression equals zero".
            rhs = sp.sympify(rhs_text) if sep else sp.Integer(0)
            equations.append(sp.Eq(lhs, rhs))
            symbols |= lhs.free_symbols
            symbols |= rhs.free_symbols
    except Exception as e:
        # Same message as the single-expression path uses for bad input.
        return f"SymPy parse failed: {e}"

    if not symbols:
        symbols = {sp.symbols("x")}
    # Sort the unknowns: set order is arbitrary, and for under-determined
    # systems it decides which symbols get solved for.
    unknowns = sorted(symbols, key=str)
    try:
        solutions = sp.solve(equations, unknowns, dict=True)
    except Exception as e:
        # solve() raises (e.g. NotImplementedError) on systems it cannot handle.
        return f"SymPy could not solve this: {e}"
    if not solutions:
        return "No solution or more conditions are required."

    lines = []
    for index, solution in enumerate(solutions, 1):
        pretty = ", ".join(f"{k} = {sp.simplify(v)}" for k, v in solution.items())
        lines.append(f"Solution {index}: {pretty}")
    return "\n".join(lines)


def _describe_expression(q):
    """Parse one expression and report simplify/factor/derivative/integral."""
    try:
        expr = sp.sympify(q)
    except Exception as e:
        return f"SymPy parse failed: {e}"

    # Each step below is best-effort: a failing step is skipped so the
    # remaining results are still shown.
    tips = []
    try:
        tips.append(f"Simplify: {sp.simplify(expr)}")
    except Exception:
        pass
    try:
        factored = sp.factor(expr)
        if factored != expr:
            tips.append(f"Factor: {factored}")
    except Exception:
        pass
    try:
        # Pick the variable deterministically (first by name) instead of
        # relying on arbitrary set iteration order.
        free = expr.free_symbols
        var = min(free, key=str) if free else sp.symbols("x")
        tips.append(f"d/d{var}: {sp.diff(expr, var)}")
        tips.append(f"integrate wrt {var}: {sp.integrate(expr, var)}")
    except Exception:
        pass
    return "\n".join(tips) if tips else f"Result: {expr}"
|
|
def hybrid_solve(user_text, use_llm):
    """Solve the user's problem, optionally normalising it with the LLM first.

    Args:
        user_text: Raw text from the problem box; ``None`` counts as empty.
        use_llm: When truthy, rewrite the text via ``_llm_parse_to_math`` and
            prefix the answer with the rewritten form.

    Returns:
        str: The solver output, or a prompt asking for input when the text
        is blank.
    """
    problem = user_text.strip() if user_text else ""
    if problem == "":
        return "Please enter an expression or problem statement."

    if not use_llm:
        return _solve_with_sympy(problem)

    # Show what the model produced above the answer so the user can judge it.
    rewritten = _llm_parse_to_math(problem)
    banner = f"LLM normalized -> {rewritten}\n---\n"
    return banner + _solve_with_sympy(rewritten)
|
|
# Single-page UI: a problem box, an LLM toggle, an output box and a Solve
# button that sends the two inputs through hybrid_solve.
with gr.Blocks(title=TITLE) as demo:
    gr.Markdown(
        f"## {TITLE}\n"
        "Paste text or math: LLM helps rewrite -> SymPy solves (optional)."
    )
    q = gr.Textbox(
        label="Problem / Expression (semicolon or newline for systems)",
        lines=6,
    )
    use_llm = gr.Checkbox(
        label="Use Phi-2 to normalize text first (optional)",
        value=False,
    )
    out = gr.Textbox(label="Output", lines=12)
    btn = gr.Button("Solve", variant="primary")
    # concurrency_limit=1 asks Gradio to run one solve at a time.
    btn.click(
        fn=hybrid_solve,
        inputs=[q, use_llm],
        outputs=out,
        concurrency_limit=1,
    )


# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()
|
|