# Equation2Graph / app_enhanced.py
# Last commit by KarthikGarlapati: "Upload 10 files" (434dce9, verified)
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
import sympy as sp
from sympy.parsing.sympy_parser import parse_expr
from sympy.utilities.lambdify import lambdify
import io
import os
import base64
import re
from datetime import datetime
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Create required directories.
# NOTE(review): nothing else in this file appears to read "uploads"; confirm it is still needed.
os.makedirs("uploads", exist_ok=True)

# Page config
st.set_page_config(page_title="Equation2Graph", page_icon="📈", layout="wide")

# Session state:
#   history  - list of {'equation', 'timestamp'} dicts for successfully plotted equations
#   equation - text of the current equation
if 'history' not in st.session_state:
    st.session_state.history = []
if 'equation' not in st.session_state:
    st.session_state.equation = ""
# Initialize NLP model
@st.cache_resource
def load_nlp_model():
    """Load the "t5-small" tokenizer and seq2seq model.

    Wrapped in ``st.cache_resource`` so the load happens once and the same
    objects are reused across Streamlit reruns.

    Returns:
        ``(tokenizer, model)`` on success, or ``(None, None)`` after showing a
        Streamlit warning when either of them could not be loaded.
    """
    checkpoint = "t5-small"
    try:
        nlp_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        nlp_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    except Exception as e:
        st.warning(f"⚠ AI model could not be loaded: {str(e)}. Falling back to basic equation parsing.")
        return None, None
    return nlp_tokenizer, nlp_model


# Load models; both are None when loading failed, in which case the
# rule-based parser is used instead.
tokenizer, model = load_nlp_model()
# Identifiers an AI-produced equation may contain. Any other word (English
# text, other variables) means the model did not return a usable expression.
_EQUATION_WORDS = frozenset({
    'x', 'e', 'E', 'pi', 'oo',
    'sin', 'cos', 'tan', 'sec', 'csc', 'cot',
    'asin', 'acos', 'atan', 'arcsin', 'arccos', 'arctan',
    'sinh', 'cosh', 'tanh', 'exp', 'log', 'ln', 'sqrt', 'cbrt',
    'abs', 'Abs', 'gamma', 'factorial', 'piecewise', 'Piecewise', 'Heaviside',
})
_EQUATION_CHARS_RE = re.compile(r"[A-Za-z0-9_\s+\-*/^().,<>=]+")
_EQUATION_IDENTIFIER_RE = re.compile(r"[A-Za-z_]+")


def _looks_like_equation(candidate):
    """Return True if *candidate* is plausibly an equation in x (lexical check, no eval)."""
    if not candidate or not _EQUATION_CHARS_RE.fullmatch(candidate):
        return False
    words = _EQUATION_IDENTIFIER_RE.findall(candidate)
    return 'x' in words and all(word in _EQUATION_WORDS for word in words)


def convert_nl_to_equation(nl_input):
    """Convert natural language to an equation string.

    Uses the seq2seq model from ``load_nlp_model`` when it is available, but
    only returns its output if that output looks like an equation in x.
    Otherwise, and whenever the model is unavailable or raises, the rule-based
    ``basic_nl_to_equation`` result is returned.

    Args:
        nl_input: Plain-English description, e.g. "x squared plus one".

    Returns:
        The equation string (``^`` for powers), e.g. "x^2+1".
    """
    if tokenizer is not None and model is not None:
        try:
            inputs = tokenizer.encode(nl_input, return_tensors="pt", truncation=True)
            outputs = model.generate(inputs, max_length=50)
            candidate = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
        except Exception as e:
            st.error(f"Error in AI conversion: {str(e)}")
        else:
            # load_nlp_model loads the stock "t5-small" checkpoint, which is not
            # fine-tuned for this task here, so its text may not be an equation.
            if _looks_like_equation(candidate):
                return candidate
    return basic_nl_to_equation(nl_input)
# Phrase -> symbol rules for the rule-based natural-language parser.
_NL_TERMS = {
    # Basic arithmetic
    'squared': '^2',
    'cubed': '^3',
    'plus': '+',
    'minus': '-',
    'times': '*',
    'multiply by': '*',
    'multiplied by': '*',
    'divided by': '/',
    'over': '/',
    'to the power of': '^',
    'power': '^',
    'raised to': '^',
    # Trigonometric functions
    'sin of': 'sin(',
    'sine of': 'sin(',
    'cos of': 'cos(',
    'cosine of': 'cos(',
    'tan of': 'tan(',
    'tangent of': 'tan(',
    'arcsin of': 'asin(',
    'arccos of': 'acos(',
    'arctan of': 'atan(',
    'csc of': '1/sin(',
    'sec of': '1/cos(',
    'cot of': '1/tan(',
    # Exponential and logarithmic
    'exponential of': 'exp(',
    'exp of': 'exp(',
    'e to the': 'exp(',
    'e raised to': 'exp(',
    'ln of': 'log(',
    'log of': 'log(',
    'log base 10 of': 'log10(',
    'logarithm of': 'log(',
    'natural log of': 'log(',
    # Other mathematical functions
    'absolute value of': 'abs(',
    'modulus of': 'abs(',
    'square root of': 'sqrt(',
    'sqrt of': 'sqrt(',
    'cube root of': 'cbrt(',
    # Hyperbolic functions
    'sinh of': 'sinh(',
    'cosh of': 'cosh(',
    'tanh of': 'tanh(',
    # Piecewise / step functions
    'piecewise': 'piecewise',
    'step function': 'Heaviside(',
}

# Longest phrases first, matched on word boundaries, so that e.g. "arcsin of"
# and "natural log of" are consumed before "sin of" / "log of" can match
# inside them.
_NL_TERM_PATTERNS = [
    (re.compile(r'\b' + re.escape(term) + r'\b'), symbol)
    for term, symbol in sorted(_NL_TERMS.items(), key=lambda item: len(item[0]), reverse=True)
]

_NL_NUMBER_PATTERNS = [
    (re.compile(r'\b' + word + r'\b'), digit)
    for word, digit in {
        'zero': '0', 'one': '1', 'two': '2', 'three': '3', 'four': '4',
        'five': '5', 'six': '6', 'seven': '7', 'eight': '8', 'nine': '9',
        'ten': '10', 'infinity': 'oo',
    }.items()
]

# "the <aggregate> of A and B" -> "(A <op> B". Without this, the generic
# "and" -> "+" rule would turn products, quotients and differences into sums.
_NL_AGGREGATE_OPS = {'sum': '+', 'difference': '-', 'product': '*', 'quotient': '/'}
_NL_AGGREGATE_RE = re.compile(
    r'\b(?:the\s+)?(sum|difference|product|quotient)\s+(?:of|between)\s+(.+?)\s+and\s+'
)
# Aggregates with no "and" still open a parenthesised group.
_NL_AGGREGATE_OPEN_RE = re.compile(
    r'\bthe\s+(?:sum\s+of|difference\s+between|product\s+of|quotient\s+of)\b'
)
# Prefix powers: "the square of x" -> "(x)^2", "the cube of x" -> "(x)^3".
_NL_PREFIX_POWERS = {'square': 2, 'cube': 3}
_NL_PREFIX_POWER_RE = re.compile(r'\b(?:the\s+)?(square|cube)\s+of\s+(\S+)')


def basic_nl_to_equation(text):
    """Convert a plain-English description of a function of x into an equation.

    Rule-based fallback for when the NLP model is unavailable or unusable. The
    result uses ``^`` for powers and bare function names (``sin(``, ``exp(``,
    ...), for ``parse_and_validate_equation``.

    Args:
        text: Natural-language input, e.g. "x squared plus two times x".

    Returns:
        The converted expression with whitespace removed, characters other than
        letters/digits/``+-*/^().`` dropped and unclosed parentheses closed,
        e.g. "x^2+2*x". An empty input yields "".
    """
    text = text.lower().strip()

    # Multi-word constructs first, before their pieces are rewritten.
    text = _NL_AGGREGATE_RE.sub(
        lambda m: f"({m.group(2)} {_NL_AGGREGATE_OPS[m.group(1)]} ", text)
    text = _NL_PREFIX_POWER_RE.sub(
        lambda m: f"({m.group(2)})^{_NL_PREFIX_POWERS[m.group(1)]}", text)

    for pattern, symbol in _NL_TERM_PATTERNS:
        text = pattern.sub(lambda _match, s=symbol: s, text)
    text = _NL_AGGREGATE_OPEN_RE.sub('(', text)

    # Verbal numbers to digits (word boundaries avoid partial matches).
    for pattern, digit in _NL_NUMBER_PATTERNS:
        text = pattern.sub(digit, text)

    # Constants. Only a standalone "e" is Euler's number; rewriting every "e"
    # would corrupt names such as exp( and Heaviside(.
    text = text.replace('π', 'pi')
    text = re.sub(r'\be\b', 'E', text)
    # Remaining "and"s join terms; done on word boundaries before spaces vanish.
    text = re.sub(r'\band\b', '+', text)

    # Clean up: drop whitespace and invalid characters, close open parentheses.
    text = re.sub(r'\s+', '', text)
    text = ''.join(c for c in text if c.isalnum() or c in '+-*/^().')
    text += ')' * max(0, text.count('(') - text.count(')'))
    return text
# Identifiers an equation may contain. sympy's parse_expr evaluates its input
# with eval() and makes Python builtins (exec, open, ...) reachable, so every
# other name is rejected before parsing.
_ALLOWED_EQUATION_NAMES = frozenset({
    # Variable and constants
    'x', 'e', 'E', 'I', 'pi', 'oo', 'True', 'False',
    # Trigonometric and inverse trigonometric
    'sin', 'cos', 'tan', 'sec', 'csc', 'cot',
    'asin', 'acos', 'atan', 'arcsin', 'arccos', 'arctan',
    # Hyperbolic
    'sinh', 'cosh', 'tanh', 'asinh', 'acosh', 'atanh',
    # Exponential, logarithmic and roots
    'exp', 'log', 'ln', 'log10', 'sqrt', 'cbrt', 'root',
    # Other functions
    'abs', 'Abs', 'sign', 'floor', 'ceiling', 'Min', 'Max',
    'gamma', 'factorial', 'erf',
    'piecewise', 'Piecewise', 'Heaviside',
})


def _equation_syntax_problem(equation_str):
    """Return a reason why *equation_str* must not reach parse_expr, or None if it is acceptable.

    Only whitelisted names, numbers and operators are allowed. String tokens
    are rejected too, because sympy functions re-parse string arguments.
    """
    import tokenize  # local import: only needed for this pre-parse check

    try:
        tokens = list(tokenize.generate_tokens(io.StringIO(equation_str).readline))
    except (tokenize.TokenError, SyntaxError) as err:
        return f"invalid syntax ({err})"

    structural = {tokenize.NUMBER, tokenize.OP, tokenize.NEWLINE, tokenize.NL,
                  tokenize.ENDMARKER, tokenize.INDENT, tokenize.DEDENT}
    unknown_names = set()
    for tok in tokens:
        if tok.type == tokenize.NAME:
            if tok.string not in _ALLOWED_EQUATION_NAMES:
                unknown_names.add(tok.string)
        elif tok.type == tokenize.ERRORTOKEN and tok.string.strip() in ('', '!'):
            # Older Pythons tokenize the factorial "!" as an error token.
            continue
        elif tok.type not in structural:
            return f"unsupported token {tok.string!r}"
    if unknown_names:
        return ("unsupported name(s) " + ", ".join(sorted(unknown_names))
                + ". Only the variable 'x' and standard math functions are supported")
    return None


def parse_and_validate_equation(equation_str):
    """Parse an equation in x and build a vectorised numeric function for it.

    ``^`` is accepted for powers. ``log10``, ``ln``, ``arcsin``/``arccos``/
    ``arctan``, ``abs``, ``piecewise`` and a bare ``e`` are mapped to their
    sympy equivalents. Other sympy names come from parse_expr's default
    namespace, restricted to ``_ALLOWED_EQUATION_NAMES``. Problems are reported
    through Streamlit messages rather than raised.

    Args:
        equation_str: Equation text, e.g. "x^2 + 2*x + 1".

    Returns:
        ``(f, expr)``, where ``expr`` is the sympy expression and ``f`` accepts a
        NumPy array of x values. If complex values were detected, ``f`` returns
        only the real part. Returns ``(None, None)`` if the equation is rejected,
        cannot be parsed, or cannot be evaluated.
    """
    try:
        equation_str = equation_str.strip().replace('^', '**')

        problem = _equation_syntax_problem(equation_str)
        if problem:
            st.error(f"Could not parse the expression. Please check the syntax: {problem}")
            return None, None

        # Safety check for extremely long expressions
        if len(equation_str) > 1000:
            st.warning("This is a very complex equation. Rendering may take longer than usual.")

        x = sp.Symbol('x')
        # Names that parse_expr's default sympy namespace lacks or maps differently.
        local_names = {
            'x': x,
            'e': sp.E,
            'ln': sp.log,
            'log10': lambda arg: sp.log(arg, 10),
            'arcsin': sp.asin,
            'arccos': sp.acos,
            'arctan': sp.atan,
            'abs': sp.Abs,
            'piecewise': sp.Piecewise,
        }
        try:
            expr = parse_expr(equation_str, local_dict=local_names)
        except Exception as parsing_error:
            st.error(f"Could not parse the expression. Please check the syntax: {str(parsing_error)}")
            return None, None

        if not isinstance(expr, sp.Basic):
            st.error("The input does not describe a single expression in x.")
            return None, None
        other_symbols = expr.free_symbols - {x}
        if other_symbols:
            names = ", ".join(sorted(str(sym) for sym in other_symbols))
            st.error(f"Only the variable 'x' is supported (found: {names}).")
            return None, None

        try:
            f = lambdify(x, expr, modules=['numpy', 'scipy', 'sympy'])
            # Probe with a float array, the way plot_equation calls f, so that
            # poles such as 1/x at 0 give inf instead of raising ZeroDivisionError.
            with np.errstate(all='ignore'):
                probe = f(np.linspace(-1.0, 1.0, 5))
        except Exception as eval_error:
            st.error(f"Error creating evaluable function: {str(eval_error)}")
            return None, None

        if np.iscomplexobj(probe):
            st.info("This equation may produce complex values for some inputs. Only the real part will be displayed.")
            complex_f = f

            def real_part(values):
                """Evaluate the complex-valued function and keep only its real part."""
                return np.real(complex_f(values))

            f = real_part
        return f, expr
    except Exception as e:
        st.error(f"Error processing equation: {str(e)}")
        return None, None
def plot_equation(f, expr, x_range=(-10, 10), points=1000):
    """Plot y = f(x) over *x_range* and return the matplotlib figure.

    Args:
        f: Vectorised callable mapping a NumPy array of x values to y values.
            For constant expressions it may return a scalar.
        expr: Expression object whose ``str()`` is used for the title and legend.
        x_range: ``(x_min, x_max)`` interval to sample; x_min must be < x_max.
        points: Number of evenly spaced sample points.

    Returns:
        The figure, or None when the range is invalid, f yields no finite value
        in the range, or an error occurs (a Streamlit message is shown). The
        caller owns the returned figure.
    """
    fig = None
    try:
        x_min, x_max = x_range
        if not x_min < x_max:
            st.error("X-axis minimum must be smaller than X-axis maximum.")
            return None

        x = np.linspace(x_min, x_max, points)
        # Poles, domain errors and overflow become inf/NaN silently; masked below.
        with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
            raw_y = f(x)
        # Broadcast scalar results onto the x grid and cast to a writable float
        # array so NaN can be stored whatever dtype f returned.
        y = np.array(np.broadcast_to(np.real(raw_y), x.shape), dtype=float)
        # NaN (rather than inf) makes matplotlib break the line at discontinuities.
        y[~np.isfinite(y)] = np.nan
        y_valid = y[np.isfinite(y)]
        if y_valid.size == 0:
            st.warning("This equation doesn't produce valid results in the specified range.")
            return None

        fig, ax = plt.subplots(figsize=(10, 6))
        ax.plot(x, y, '-b', label=str(expr))

        # 10% vertical margin. Skipped for constant functions (zero span) and
        # for spans that overflow to inf.
        y_lo, y_hi = y_valid.min(), y_valid.max()
        y_span = y_hi - y_lo
        if np.isfinite(y_span) and y_span > 0:
            ax.set_ylim([y_lo - 0.1 * y_span, y_hi + 0.1 * y_span])

        ax.grid(True, alpha=0.3)
        ax.axhline(y=0, color='k', linestyle='-', alpha=0.3)
        ax.axvline(x=0, color='k', linestyle='-', alpha=0.3)
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_title(f'Graph of {str(expr)}')
        ax.legend()
        return fig
    except Exception as e:
        if fig is not None:
            plt.close(fig)  # don't leak a half-built figure
        st.error(f"Error plotting equation: {str(e)}")
        return None
def get_download_link(fig, filename="equation_plot.png"):
    """Return an HTML anchor that downloads *fig* as a PNG image.

    The PNG bytes are embedded in the link as a base64 ``data:`` URI.

    Args:
        fig: Object with matplotlib's ``savefig(buf, format=..., bbox_inches=...)``.
        filename: File name suggested to the browser (HTML-escaped).

    Returns:
        The ``<a>`` element as a string, for ``st.markdown(..., unsafe_allow_html=True)``.
    """
    import html  # local import: only needed to escape the filename attribute

    with io.BytesIO() as buf:
        fig.savefig(buf, format='png', bbox_inches='tight')
        b64 = base64.b64encode(buf.getvalue()).decode()
    safe_name = html.escape(filename, quote=True)
    return (f'<a href="data:image/png;base64,{b64}" '
            f'download="{safe_name}">📥 Download Plot</a>')
# Header
st.title("📈 Equation2Graph")
st.markdown("Visualize mathematical equations instantly, from symbolic or natural language input!")


def _load_nl_equation():
    """on_change callback: convert the plain-English input into the equation text area.

    Callbacks run before the script reruns, so the text area (session key
    "equation") shows the converted equation on this run. The conversion result
    is kept under "nl_result" for the success/error message below the input.
    """
    nl_text = st.session_state.get("nl_input", "")
    converted = convert_nl_to_equation(nl_text) if nl_text else ""
    st.session_state.nl_result = converted
    if converted:
        st.session_state.equation = converted


def _load_example(example):
    """on_click callback: put *example* into the equation text area."""
    st.session_state.equation = example


# Columns
col1, col2 = st.columns([2, 1])
with col1:
    # Natural language input
    st.subheader("🗣 Describe your equation (optional)")
    nl_input = st.text_input(
        "Enter your equation in plain English:",
        placeholder="e.g., the square of x plus two times x plus one",
        key="nl_input",
        on_change=_load_nl_equation,
    )
    if nl_input:
        converted = st.session_state.get("nl_result", "")
        if converted:
            st.success(f"Converted to equation: {converted}")
        else:
            st.error("Failed to convert input")

    # Direct equation input (fallback or primary). Its value lives in
    # st.session_state.equation so the callbacks above can set it.
    equation = st.text_area(
        "Or enter your equation directly:",
        key="equation",
        placeholder="e.g., x^2 + 2*x + 1",
        height=100,
    )

    # Plot settings
    with st.expander("⚙ Plot Settings"):
        col_a, col_b = st.columns(2)
        with col_a:
            x_min = st.number_input("X-axis minimum", value=-10.0)
        with col_b:
            x_max = st.number_input("X-axis maximum", value=10.0)
        points = st.slider("Number of points", 100, 2000, 1000)

    # Process and plot
    if equation:
        f, expr = parse_and_validate_equation(equation)
        # Compare with None: a sympy zero is falsy and a relational cannot be
        # converted to bool at all.
        if f is not None and expr is not None:
            fig = plot_equation(f, expr, (x_min, x_max), points)
            if fig is not None:
                st.pyplot(fig)
                st.markdown(get_download_link(fig), unsafe_allow_html=True)
                plt.close(fig)  # rendered and encoded; free the figure
                if equation not in [h['equation'] for h in st.session_state.history]:
                    st.session_state.history.append({
                        'equation': equation,
                        'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                    })

with col2:
    st.subheader("📚 Example Equations")
    examples = [
        "x^2 + 2*x + 1",
        "sin(x) + cos(x)",
        "exp(-x^2)",
        "x^3 - 4*x",
        "tan(x)",
        "log(abs(x))",
        "sin(x^2)*exp(-0.1*x^2)",
        "(x^4 - 5*x^2 + 4)/(x^2 + 1)",
        "sqrt(abs(x))*sin(10*x)",
        "1/(1+exp(-x))",
        "sin(1/x)",
        "(x^5-5*x^3+4*x)/(x^2+1)",
        "piecewise((x^2, x<0), (x^3, x>=0))",
        "sinh(x)*cosh(x)"
    ]
    for ex in examples:
        # on_click runs before the rerun, so the text area shows the example at once.
        st.button(ex, on_click=_load_example, args=(ex,))

    st.subheader("📖 Recent Equations")
    for item in reversed(st.session_state.history[-5:]):
        st.text(f"{item['equation']}")
        st.caption(f"Plotted on: {item['timestamp']}")

    with st.expander("ℹ How to use Equation2Graph"):
        st.markdown("""
### Instructions:
1. Describe or type your equation
2. The graph updates automatically
3. Adjust plot settings if needed
4. Download the graph as PNG
### Natural Language Support:
- Try: "the square of x plus two times x"
- You can still type equations like: x^2 + 2*x
### Supported Math Functions:
- Basic: +, -, *, /, ^
- Trig: sin(x), cos(x), tan(x), arcsin(x), arccos(x), arctan(x)
- Exponential/log: exp(x), log(x), log10(x)
- Hyperbolic: sinh(x), cosh(x), tanh(x)
- Piecewise: piecewise((expr1, cond1), (expr2, cond2), ...)
- Advanced: sqrt(x), abs(x), combinations of functions
### Complex Equations:
You can enter lengthy and complex equations such as:
- Combinations of multiple functions: `sin(x^2)*exp(-0.1*x^2)`
- Rational functions: `(x^4 - 5*x^2 + 4)/(x^2 + 1)`
- Piecewise functions: `piecewise((x^2, x<0), (x^3, x>=0))`
- Highly oscillatory functions: `sin(1/x)`
- Special functions: `gamma(x)`, `factorial(x)`
- Any valid mathematical expression with variable x
### Tips for Complex Equations:
- For equations with discontinuities, try adjusting the plot range
- For highly oscillatory functions, increase the number of points in Plot Settings
- Use the piecewise function for functions defined differently across domains
- If your equation produces complex values, only the real part will be displayed
- Extremely long expressions (>1000 characters) are supported but may take longer to render
""")
# Footer
st.markdown("---")
st.markdown("Equation2Graph | Now with 🧠 AI-powered equation parser | Made with ❤ by MathVizMinds")