Spaces:
Running
Running
| import streamlit as st | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import sympy as sp | |
| from sympy.parsing.sympy_parser import parse_expr | |
| from sympy.utilities.lambdify import lambdify | |
| import io | |
| import os | |
| import base64 | |
| import re | |
| from datetime import datetime | |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
# Create required directories.
# NOTE(review): nothing else in this file reads or writes "uploads/" - confirm it is still needed.
os.makedirs("uploads", exist_ok=True)

# Page config - kept ahead of every other st.* call in this file.
st.set_page_config(page_title="Equation2Graph", page_icon="📈", layout="wide")

# Per-session state (Streamlit keeps st.session_state across reruns of this script):
# the list of plotted equations and the current equation text.
if 'history' not in st.session_state:
    st.session_state.history = []
if 'equation' not in st.session_state:
    st.session_state.equation = ""
# Initialize NLP model
@st.cache_resource
def load_nlp_model():
    """Load the t5-small tokenizer and seq2seq model used by convert_nl_to_equation.

    Streamlit re-executes this script on every widget interaction; ``st.cache_resource``
    makes the slow download/load happen once per server process instead of on every
    rerun. A failed load is cached as well, so the app keeps using the rule-based
    parser until the cache is cleared or the server restarts.

    Returns:
        ``(tokenizer, model)`` on success, or ``(None, None)`` if loading raised
        (a warning is shown in that case).
    """
    try:
        nl_tokenizer = AutoTokenizer.from_pretrained("t5-small")
        nl_model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
        return nl_tokenizer, nl_model
    except Exception as e:
        st.warning(f"⚠ AI model could not be loaded: {str(e)}. Falling back to basic equation parsing.")
        return None, None


# Load models (served from the resource cache after the first run)
tokenizer, model = load_nlp_model()
# Identifiers an AI-generated equation may contain. The stock "t5-small"
# checkpoint is a general text-to-text model and nothing here fine-tunes it or
# adds a task prefix, so its output is only used when every identifier in it is
# one of these names; anything else (e.g. echoed English words) means the model
# did not produce an equation and the rule-based parser is used instead.
_AI_EQUATION_NAMES = frozenset({
    'x', 'e', 'E', 'pi', 'π', 'oo',
    'sin', 'cos', 'tan', 'asin', 'acos', 'atan', 'arcsin', 'arccos', 'arctan',
    'sinh', 'cosh', 'tanh', 'exp', 'log', 'ln', 'log10',
    'sqrt', 'cbrt', 'abs', 'Abs', 'gamma', 'factorial',
    'Heaviside', 'piecewise', 'Piecewise',
})


def _is_plausible_equation(candidate):
    """Return True if *candidate* looks like an equation in x rather than prose."""
    if not candidate or not candidate.strip():
        return False
    # Only word characters, whitespace and math punctuation may appear.
    if re.fullmatch(r"[\w\s+\-*/^().,<>=]+", candidate) is None:
        return False
    # Identifiers start with a letter, so "2x" yields just "x".
    identifiers = re.findall(r"[^\W\d]\w*", candidate)
    return all(name in _AI_EQUATION_NAMES for name in identifiers)


def convert_nl_to_equation(nl_input):
    """Convert natural language to an equation string.

    The AI model (module-level ``tokenizer``/``model``) is tried first when it
    loaded. Its output is used only if it looks like an equation. Otherwise, if the
    model is unavailable or raises (the error is shown with ``st.error``), the
    rule-based ``basic_nl_to_equation`` result is returned.

    Args:
        nl_input: Plain-English description of the equation.

    Returns:
        The equation string (``^`` for powers).
    """
    if tokenizer is not None and model is not None:
        try:
            inputs = tokenizer.encode(nl_input, return_tensors="pt", truncation=True)
            outputs = model.generate(inputs, max_length=50)
            candidate = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
        except Exception as e:
            st.error(f"Error in AI conversion: {str(e)}")
        else:
            if _is_plausible_equation(candidate):
                return candidate
    return basic_nl_to_equation(nl_input)
# Phrase -> replacement rules for the rule-based natural-language parser.
_NL_TERMS = {
    # Basic arithmetic
    'squared': '^2',
    'cubed': '^3',
    'plus': '+',
    'minus': '-',
    'times': '*',
    'multiply by': '*',
    'multiplied by': '*',
    'divided by': '/',
    'over': '/',
    'to the power of': '^',
    'to the power': '^',
    'power': '^',
    'raised to the power of': '^',
    'raised to': '^',
    # Trigonometric functions
    'sin of': 'sin(',
    'sine of': 'sin(',
    'cos of': 'cos(',
    'cosine of': 'cos(',
    'tan of': 'tan(',
    'tangent of': 'tan(',
    'arcsin of': 'asin(',
    'arccos of': 'acos(',
    'arctan of': 'atan(',
    'csc of': '1/sin(',
    'sec of': '1/cos(',
    'cot of': '1/tan(',
    # Exponential and logarithmic
    'exponential of': 'exp(',
    'exp of': 'exp(',
    'e to the power of': 'exp(',
    'e to the': 'exp(',
    'e raised to the power of': 'exp(',
    'e raised to': 'exp(',
    'ln of': 'log(',
    'log of': 'log(',
    'log base 10 of': 'log10(',
    'logarithm of': 'log(',
    'natural log of': 'log(',
    # Other mathematical functions
    'absolute value of': 'abs(',
    'modulus of': 'abs(',
    'square root of': 'sqrt(',
    'sqrt of': 'sqrt(',
    'cube root of': 'cbrt(',
    # Hyperbolic functions
    'sinh of': 'sinh(',
    'cosh of': 'cosh(',
    'tanh of': 'tanh(',
    # Step function
    'step function of': 'Heaviside(',
    'step function': 'Heaviside(',
    # Multi-part expressions (used when no "... and ..." operand follows)
    'divided by the sum of': '/(',
    'divided by the product of': '/(',
    'multiplied by the sum of': '*(',
    'multiplied by the difference of': '*(',
    'multiplied by the product of': '*(',
    'the sum of': '(',
    'the difference between': '(',
    'the difference of': '(',
    'the product of': '(',
    'the quotient of': '(',
}
# A single alternation with the longest phrase first, so e.g. "arcsin of" wins
# over "sin of" and "natural log of" over "log of". The \b anchors stop matches
# inside other words ("over" in "cover"). A single pass also means replaced
# text is never re-matched.
_NL_TERMS_RE = re.compile(
    r"\b(?:"
    + "|".join(re.escape(term) for term in sorted(_NL_TERMS, key=len, reverse=True))
    + r")\b"
)
# "the sum of A and B" -> "( A plus B": the 'and' takes the phrase's operator.
_NL_BINARY_OPS = {
    'sum of': 'plus',
    'difference between': 'minus',
    'difference of': 'minus',
    'product of': 'times',
    'quotient of': 'divided by',
}
_NL_BINARY_RE = re.compile(
    r"\b(?:the )?(" + "|".join(re.escape(p) for p in _NL_BINARY_OPS) + r") (.+?) and\b"
)
# "the square of A" -> "(A)^2": the exponent has to follow the operand.
_NL_POWER_RE = re.compile(r"\b(?:the )?(square|cube) of (\w+(?:\.\w+)?)")
_NL_NUMBERS = {
    'zero': '0', 'one': '1', 'two': '2', 'three': '3', 'four': '4',
    'five': '5', 'six': '6', 'seven': '7', 'eight': '8', 'nine': '9',
    'ten': '10', 'infinity': 'oo',
}
_NL_NUMBERS_RE = re.compile(r"\b(" + "|".join(_NL_NUMBERS) + r")\b")


def basic_nl_to_equation(text):
    """Convert a plain-English description into an equation string using keyword rules.

    This is the rule-based parser used when the AI model is unavailable or its
    output is rejected. Phrases are matched as whole words, longest phrase first.
    Unclosed parentheses opened by function phrases ("sin of x") are closed at the
    end.

    Args:
        text: Description such as ``"the square of x plus two times x plus one"``.

    Returns:
        Best-effort equation string using ``^`` for powers, e.g. ``"(x)^2+2*x+1"``.
        A standalone ``e`` becomes ``E`` (Euler's number) and ``π`` becomes ``pi``.
        Input outside the supported vocabulary may still give an unparseable string.
    """
    text = text.lower().strip()

    text = _NL_BINARY_RE.sub(
        lambda m: f"( {m.group(2)} {_NL_BINARY_OPS[m.group(1)]}", text
    )
    text = _NL_POWER_RE.sub(
        lambda m: f"({m.group(2)})^{2 if m.group(1) == 'square' else 3}", text
    )
    text = _NL_TERMS_RE.sub(lambda m: _NL_TERMS[m.group(0)], text)

    # Leftover filler words, and a plain "and" meaning addition ("x and 1").
    text = re.sub(r"\bthe\b", " ", text)
    text = re.sub(r"\band\b", "+", text)

    # Convert verbal numbers to digits (whole words only).
    text = _NL_NUMBERS_RE.sub(lambda m: _NL_NUMBERS[m.group(1)], text)

    # Only a standalone 'e' is Euler's number; 'exp(' and 'Heaviside(' keep theirs.
    text = re.sub(r"\be\b", "E", text)
    text = text.replace('π', 'pi')

    # Drop whitespace and any character that cannot be part of an equation.
    text = ''.join(text.split())
    text = ''.join(c for c in text if c.isalnum() or c in '+-*/^().')

    # Close any parentheses left open by function phrases.
    text += ')' * (text.count('(') - text.count(')'))
    return text
# Identifiers parse_and_validate_equation lets through to SymPy. parse_expr()
# evaluates its input with eval(), so names outside this allowlist are rejected
# up front; this also limits equations to the single variable x.
_PARSE_ALLOWED_NAMES = frozenset({
    # Variable, constants, and booleans (for Piecewise conditions)
    'x', 'e', 'E', 'pi', 'oo', 'I', 'True', 'False',
    # Trigonometric and inverse trigonometric
    'sin', 'cos', 'tan', 'sec', 'csc', 'cot',
    'asin', 'acos', 'atan', 'arcsin', 'arccos', 'arctan',
    # Hyperbolic
    'sinh', 'cosh', 'tanh',
    # Exponential and logarithmic
    'exp', 'log', 'ln', 'log10',
    # Other functions
    'sqrt', 'cbrt', 'root', 'abs', 'Abs', 'sign', 'floor', 'ceiling',
    'gamma', 'factorial', 'Heaviside', 'piecewise', 'Piecewise',
})


def _check_equation_tokens(equation_str):
    """Return an error message if *equation_str* must not reach parse_expr, else None.

    Only numbers, operators, comments and identifiers from _PARSE_ALLOWED_NAMES are
    accepted. String literals and attribute access ("x.something") are rejected,
    because parse_expr() hands its input to eval().
    """
    import tokenize  # local import: only this safety check needs it

    allowed_types = {
        tokenize.NAME, tokenize.NUMBER, tokenize.OP, tokenize.COMMENT,
        tokenize.NEWLINE, tokenize.NL, tokenize.ENDMARKER,
    }
    unsupported = set()
    previous = None
    try:
        for tok in tokenize.generate_tokens(io.StringIO(equation_str).readline):
            # '!' is SymPy's factorial notation; older tokenizers report it as ERRORTOKEN.
            if tok.type not in allowed_types and tok.string != '!':
                return f"Could not parse the expression. Please check the syntax: unexpected {tok.string!r}"
            if tok.type == tokenize.NAME:
                if previous is not None and previous.type == tokenize.OP and previous.string == '.':
                    return "Attribute access ('.name') is not allowed in equations."
                if tok.string not in _PARSE_ALLOWED_NAMES:
                    unsupported.add(tok.string)
            previous = tok
    except (tokenize.TokenError, SyntaxError) as token_error:
        return f"Could not parse the expression. Please check the syntax: {token_error}"
    if unsupported:
        return (
            "Unsupported name(s): " + ", ".join(sorted(unsupported))
            + ". Only the variable 'x' and the supported functions/constants can be used."
        )
    return None


def parse_and_validate_equation(equation_str):
    """Parse a user equation in x and build a NumPy-vectorised function for it.

    Args:
        equation_str: Equation text such as ``"x^2 + 2*x + 1"``. ``^`` means power,
            ``π`` means pi, ``e`` is Euler's number, and whitespace/newlines are ignored.

    Returns:
        ``(f, expr)``: a callable taking a NumPy array of x values, plus the parsed
        SymPy expression. If the expression yields complex values, ``f`` returns
        only the real part. On any failure ``(None, None)`` is returned after the
        problem has been shown with ``st.error``.
    """
    try:
        # Collapse all whitespace (the text area allows newlines) and map the
        # ^ / π spellings onto what parse_expr understands.
        equation_str = " ".join(equation_str.split())
        equation_str = equation_str.replace('^', '**').replace('π', 'pi')

        problem = _check_equation_tokens(equation_str)
        if problem is not None:
            st.error(problem)
            return None, None

        # Safety check for extremely long expressions
        if len(equation_str) > 1000:
            st.warning("This is a very complex equation. Rendering may take longer than usual.")

        x = sp.Symbol('x')
        # Alternative spellings. Every other allowlisted name already exists in
        # parse_expr's default `from sympy import *` namespace.
        local_dict = {
            'x': x,
            'e': sp.E,
            'abs': sp.Abs,
            'piecewise': sp.Piecewise,
            'arcsin': sp.asin,
            'arccos': sp.acos,
            'arctan': sp.atan,
            'log10': lambda arg: sp.log(arg, 10),
        }

        # NOTE(security): parse_expr() uses eval(); _check_equation_tokens() above is
        # what keeps arbitrary Python out of it - keep the allowlist tight.
        try:
            expr = parse_expr(equation_str, local_dict=local_dict)
        except Exception as parsing_error:
            st.error(f"Could not parse the expression. Please check the syntax: {str(parsing_error)}")
            return None, None
        if not isinstance(expr, sp.Basic):
            # e.g. "x, 1" parses to a Python tuple
            st.error("Could not parse the expression. Please enter a single expression in x.")
            return None, None

        # Create a function for numerical evaluation and make sure it runs.
        try:
            f = lambdify(x, expr, modules=['numpy', 'scipy', 'sympy'])
            # Probe with a float array that includes 0: with the plain int 0,
            # 1/x raises ZeroDivisionError and rejected e.g. sin(1/x).
            with np.errstate(all='ignore'):
                test_val = f(np.linspace(-1.0, 1.0, 5))
            if np.iscomplexobj(test_val):
                st.info("This equation may produce complex values for some inputs. Only the real part will be displayed.")
                complex_f = f

                def real_part_wrapper(values):
                    return np.real(complex_f(values))

                f = real_part_wrapper
            return f, expr
        except Exception as eval_error:
            st.error(f"Error creating evaluable function: {str(eval_error)}")
            return None, None
    except Exception as e:
        st.error(f"Error processing equation: {str(e)}")
        return None, None
def _robust_ylim(y_finite, margin=0.1, outlier_ratio=10.0):
    """Return ``(bottom, top)`` y-limits for finite samples, or None if they are all equal.

    Near vertical asymptotes (tan(x), 1/x) a few huge samples would flatten the rest
    of the curve. When the full range exceeds ``outlier_ratio`` times the
    2nd-98th percentile range, the percentile range is used instead. ``margin`` pads
    both ends by that fraction of the chosen range.
    """
    low, high = float(np.min(y_finite)), float(np.max(y_finite))
    p_low, p_high = (float(v) for v in np.percentile(y_finite, [2, 98]))
    if p_high > p_low and (high - low) > outlier_ratio * (p_high - p_low):
        low, high = p_low, p_high
    span = high - low
    if span <= 0:
        return None
    return low - margin * span, high + margin * span


def plot_equation(f, expr, x_range=(-10, 10), points=1000):
    """Plot ``f`` over ``x_range`` and return the Matplotlib figure.

    Args:
        f: Vectorised callable (as returned by ``parse_and_validate_equation``) that
            maps a NumPy array of x values to y values; a scalar result is drawn as
            a constant.
        expr: The SymPy expression, used for the legend and title.
        x_range: ``(x_min, x_max)`` interval to sample.
        points: Number of evenly spaced samples.

    Returns:
        The figure, or ``None`` if nothing could be plotted (a Streamlit warning or
        error has been shown and the figure closed). The caller owns a returned
        figure and should close it once rendered.
    """
    fig = None
    try:
        fig, ax = plt.subplots(figsize=(10, 6))
        x = np.linspace(x_range[0], x_range[1], points)
        # Division by zero, log of negatives, overflow, etc. give inf/NaN, not warnings.
        with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
            # np.array copies and coerces int/bool/object results to float.
            y = np.array(f(x), dtype=float)
        if y.ndim == 0:
            # Constant expression: lambdify returned a scalar.
            y = np.full_like(x, float(y))
        # NaN (rather than inf) makes Matplotlib break the line at discontinuities.
        y[~np.isfinite(y)] = np.nan
        y_finite = y[np.isfinite(y)]
        if y_finite.size == 0:
            st.warning("This equation doesn't produce valid results in the specified range.")
            plt.close(fig)
            return None
        ax.plot(x, y, '-b', label=str(expr))
        ylim = _robust_ylim(y_finite)
        if ylim is not None:
            ax.set_ylim(ylim)
        ax.grid(True, alpha=0.3)
        ax.axhline(y=0, color='k', linestyle='-', alpha=0.3)
        ax.axvline(x=0, color='k', linestyle='-', alpha=0.3)
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_title(f'Graph of {str(expr)}')
        ax.legend()
        return fig
    except Exception as e:
        if fig is not None:
            plt.close(fig)
        st.error(f"Error plotting equation: {str(e)}")
        return None
def get_download_link(fig):
    """Return an HTML ``<a>`` tag that downloads *fig* as ``equation_plot.png``.

    The PNG is embedded as a base64 ``data:`` URI, so no server-side file is
    written. The caller renders it with ``st.markdown(..., unsafe_allow_html=True)``.
    """
    with io.BytesIO() as buf:
        fig.savefig(buf, format='png', bbox_inches='tight')
        b64 = base64.b64encode(buf.getvalue()).decode()
    return f'<a href="data:image/png;base64,{b64}" download="equation_plot.png">📥 Download Plot</a>'
# Header
st.title("📈 Equation2Graph")
st.markdown("Visualize mathematical equations instantly, from symbolic or natural language input!")


def _select_example(example):
    """Button callback: store *example* as the equation to pre-fill on the next rerun."""
    st.session_state.equation = example


# Columns
col1, col2 = st.columns([2, 1])

with col1:
    # Natural language input
    st.subheader("🗣 Describe your equation (optional)")
    nl_input = st.text_input(
        "Enter your equation in plain English:",
        placeholder="e.g., the square of x plus two times x plus one",
    )

    # Pre-fill for the equation box: the converted natural-language input if any,
    # otherwise the example last picked with a button (stored by _select_example).
    equation = st.session_state.get("equation", "")
    if nl_input:
        equation = convert_nl_to_equation(nl_input)
        if equation:
            st.success(f"Converted to equation: {equation}")
        else:
            st.error("Failed to convert input")

    # Direct equation input (fallback or primary)
    equation = st.text_area(
        "Or enter your equation directly:",
        value=equation,
        placeholder="e.g., x^2 + 2*x + 1",
        height=100,
    )

    # Plot settings
    with st.expander("⚙ Plot Settings"):
        col_a, col_b = st.columns(2)
        with col_a:
            x_min = st.number_input("X-axis minimum", value=-10.0)
        with col_b:
            x_max = st.number_input("X-axis maximum", value=10.0)
        points = st.slider("Number of points", 100, 2000, 1000)

    # Process and plot
    if equation and equation.strip():
        if x_min >= x_max:
            st.error("X-axis minimum must be smaller than X-axis maximum.")
        else:
            f, expr = parse_and_validate_equation(equation)
            # `is not None`: a SymPy zero (the equation "0") is falsy but valid.
            if f is not None and expr is not None:
                fig = plot_equation(f, expr, (x_min, x_max), points)
                if fig is not None:
                    st.pyplot(fig)
                    st.markdown(get_download_link(fig), unsafe_allow_html=True)
                    # The script reruns on every interaction; close the figure so
                    # Matplotlib does not keep every plot ever drawn in memory.
                    plt.close(fig)
                    if equation not in [h['equation'] for h in st.session_state.history]:
                        st.session_state.history.append({
                            'equation': equation,
                            'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                        })

with col2:
    st.subheader("📝 Example Equations")
    examples = [
        "x^2 + 2*x + 1",
        "sin(x) + cos(x)",
        "exp(-x^2)",
        "x^3 - 4*x",
        "tan(x)",
        "log(abs(x))",
        "sin(x^2)*exp(-0.1*x^2)",
        "(x^4 - 5*x^2 + 4)/(x^2 + 1)",
        "sqrt(abs(x))*sin(10*x)",
        "1/(1+exp(-x))",
        "sin(1/x)",
        "(x^5-5*x^3+4*x)/(x^2+1)",
        "piecewise((x^2, x<0), (x^3, x>=0))",
        "sinh(x)*cosh(x)",
    ]
    for ex in examples:
        # on_click runs before the rerun starts, so col1 (rendered first) sees the choice.
        st.button(ex, on_click=_select_example, args=(ex,))

    st.subheader("🕒 Recent Equations")
    for item in reversed(st.session_state.history[-5:]):
        st.text(f"{item['equation']}")
        st.caption(f"Plotted on: {item['timestamp']}")

    with st.expander("ℹ How to use Equation2Graph"):
        st.markdown("""
### Instructions:
1. Describe or type your equation
2. The graph updates automatically
3. Adjust plot settings if needed
4. Download the graph as PNG
### Natural Language Support:
- Try: "the square of x plus two times x"
- You can still type equations like: x^2 + 2*x
### Supported Math Functions:
- Basic: +, -, *, /, ^
- Trig: sin(x), cos(x), tan(x), arcsin(x), arccos(x), arctan(x)
- Exponential/log: exp(x), log(x), log10(x)
- Hyperbolic: sinh(x), cosh(x), tanh(x)
- Piecewise: piecewise((expr1, cond1), (expr2, cond2), ...)
- Advanced: sqrt(x), abs(x), combinations of functions
### Complex Equations:
You can enter lengthy and complex equations such as:
- Combinations of multiple functions: `sin(x^2)*exp(-0.1*x^2)`
- Rational functions: `(x^4 - 5*x^2 + 4)/(x^2 + 1)`
- Piecewise functions: `piecewise((x^2, x<0), (x^3, x>=0))`
- Highly oscillatory functions: `sin(1/x)`
- Special functions: `gamma(x)`, `factorial(x)`
- Any valid mathematical expression with variable x
### Tips for Complex Equations:
- For equations with discontinuities, try adjusting the plot range
- For highly oscillatory functions, increase the number of points in Plot Settings
- Use the piecewise function for functions defined differently across domains
- If your equation produces complex values, only the real part will be displayed
- Extremely long expressions (>1000 characters) are supported but may take longer to render
""")

# Footer
st.markdown("---")
st.markdown("Equation2Graph | Now with 🧠 AI-powered equation parser | Made with ❤ by MathVizMinds")