# Equation2Graph — src/streamlit_app.py
# Author: KarthikGarlapati (commit b357fbc, "Update src/streamlit_app.py", verified)
import base64
import io
import os
import re
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import sympy as sp
from matplotlib.figure import Figure
from sympy.parsing.sympy_parser import parse_expr
from sympy.utilities.lambdify import lambdify
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Create required directories.
# NOTE(review): nothing visible in this file reads or writes "uploads";
# kept so any external code relying on it keeps working.
os.makedirs("uploads", exist_ok=True)

# Page config. Streamlit requires this to be the first st.* call of a run.
st.set_page_config(page_title="Equation2Graph", page_icon="📈", layout="wide")

# Per-session list of plotted equations; each entry is a dict with
# 'equation' and 'timestamp' keys (appended by the plotting section and
# listed under "Recent Equations").
if "history" not in st.session_state:
    st.session_state.history = []
# Natural-language model (optional; the app degrades to rule-based parsing).
@st.cache_resource
def load_nlp_model():
    """Load the ``t5-small`` tokenizer and seq2seq model.

    Wrapped in ``st.cache_resource`` so Streamlit reruns reuse the same
    objects instead of reloading them.

    Returns:
        ``(tokenizer, model)`` when both load, otherwise ``(None, None)``
        after showing a warning banner.
    """
    checkpoint = "t5-small"
    try:
        tok = AutoTokenizer.from_pretrained(checkpoint)
        mdl = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    except Exception as exc:
        st.warning(
            f"⚠ AI model could not be loaded: {exc!s}. "
            "Falling back to basic equation parsing."
        )
        return None, None
    return tok, mdl


# Module-level handles used by convert_nl_to_equation; both are None when
# loading failed.
tokenizer, model = load_nlp_model()
def convert_nl_to_equation(nl_input):
    """Convert a natural-language description into an equation string.

    The seq2seq model is tried first when it loaded. ``t5-small`` is a
    general-purpose checkpoint (nothing here fine-tunes it for equations),
    so its output is only accepted when SymPy can parse it into an
    expression whose only free symbol is ``x``. Otherwise, and whenever the
    model is unavailable or raises, the rule-based ``basic_nl_to_equation``
    result is returned.

    Args:
        nl_input: The user's description, e.g. "x squared plus one".

    Returns:
        An equation string (``^`` may denote powers). It may be empty when
        the fallback parser extracts nothing.
    """
    if tokenizer is None or model is None:
        return basic_nl_to_equation(nl_input)

    try:
        inputs = tokenizer.encode(nl_input, return_tensors="pt", truncation=True)
        outputs = model.generate(inputs, max_length=50)
        candidate = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    except Exception as e:
        st.error(f"Error in AI conversion: {str(e)}")
        return basic_nl_to_equation(nl_input)

    if _is_plottable_equation(candidate):
        return candidate
    # The model produced prose or an unusable expression; use the rules.
    return basic_nl_to_equation(nl_input)


def _is_plottable_equation(candidate):
    """Return True if *candidate* parses to a SymPy expression in ``x`` only."""
    if not candidate:
        return False
    try:
        # NOTE(security): parse_expr uses eval(); the input derives from user text.
        expr = parse_expr(candidate.replace("^", "**"))
    except Exception:
        # Any parse failure simply means "not an equation"; caller falls back.
        return False
    return isinstance(expr, sp.Expr) and expr.free_symbols <= {sp.Symbol("x")}
def basic_nl_to_equation(text):
    """Convert a plain-English equation description using fixed rewrite rules.

    This is the fallback used when the AI model is unavailable or gives an
    unusable answer. Steps:

    1. Lowercase and trim the text.
    2. Rewrite the prefix forms "square root of", "the square of <w>" and
       "the cube of <w>".
    3. Replace operator and function phrases.
    4. Turn spelled-out numbers zero to ten into digits and drop stray "the".
    5. Remove whitespace and every character other than alphanumerics and
       ``+-*/^().``.
    6. Close every parenthesis left open by phrases such as "sin of".

    Args:
        text: Description such as "the square of x plus two times x plus one".

    Returns:
        Equation string such as "x^2+2*x+1". ``^`` is used for powers; the
        caller converts it for SymPy.
    """
    # Phrase -> symbol, applied in this order as plain substring replacements.
    terms = {
        'squared': '^2',
        'cubed': '^3',
        'plus': '+',
        'minus': '-',
        'times': '*',
        'multiplied by': '*',
        'multiply by': '*',
        'divided by': '/',
        'over': '/',
        'to the power of': '^',
        'sin of': 'sin(',
        'cos of': 'cos(',
        'tan of': 'tan(',
        'exponential of': 'exp(',
        'log of': 'log(',
        'absolute value of': 'abs(',
    }
    # Spelled-out integers, matched as whole words so "none"/"often" survive.
    number_words = {
        'zero': '0', 'one': '1', 'two': '2', 'three': '3', 'four': '4',
        'five': '5', 'six': '6', 'seven': '7', 'eight': '8', 'nine': '9',
        'ten': '10',
    }

    # Normalize text
    text = text.lower().strip()

    # Prefix forms name the operation before the operand, so plain
    # substitution cannot express them.
    text = re.sub(r"\b(?:the\s+)?square\s+root\s+of\b", " sqrt( ", text)
    text = re.sub(r"\b(?:the\s+)?square\s+of\s+(\w+)", r"\1^2", text)
    text = re.sub(r"\b(?:the\s+)?cube\s+of\s+(\w+)", r"\1^3", text)

    # Replace operator / function phrases
    for term, symbol in terms.items():
        text = text.replace(term, symbol)

    number_pattern = r"\b(" + "|".join(number_words) + r")\b"
    text = re.sub(number_pattern, lambda m: number_words[m.group(1)], text)

    # A leftover article ("the sin of x") would otherwise be glued onto the
    # next token once spaces are removed. Runs after "to the power of".
    text = re.sub(r"\bthe\b", " ", text)

    # Clean up: drop spaces and any character that is not equation syntax.
    text = text.replace(' ', '')
    text = ''.join(c for c in text if c.isalnum() or c in '+-*/^().')

    # Close *every* unclosed parenthesis, not just one (nested "sin of cos of").
    text += ')' * (text.count('(') - text.count(')'))
    return text
def parse_and_validate_equation(equation_str):
    """Parse a user-typed equation into a NumPy-callable function of ``x``.

    Accepted forms:

    - ``^`` as the power operator; it is rewritten to ``**``, because ``^``
      means XOR in Python/SymPy syntax.
    - An optional leading ``y =`` or ``f(x) =``.
    - ``e`` meaning Euler's number.

    The result must be a plain expression (not an equation or inequality)
    whose only free symbol is ``x``, so that the lambdified function can be
    evaluated on an array of x values.

    Args:
        equation_str: Equation text, e.g. ``"x^2 + 2*x + 1"``.

    Returns:
        ``(f, expr)``, where ``f`` is the ``lambdify``-generated NumPy function
        and ``expr`` is the SymPy expression. Returns ``(None, None)`` after
        reporting the problem with ``st.error``.
    """
    x = sp.Symbol('x')
    try:
        cleaned = equation_str.replace('^', '**')
        # Keep only the right-hand side of "y = ..." / "f(x) = ...".
        cleaned = re.sub(r"^\s*(?:y|f\s*\(\s*x\s*\))\s*=", "", cleaned)
        # NOTE(security): parse_expr evaluates its input with eval() and is
        # not safe for untrusted input; review before any public deployment.
        expr = parse_expr(cleaned, local_dict={'x': x, 'e': sp.E})

        if not isinstance(expr, sp.Expr):
            raise ValueError(
                "expected an expression in x (e.g. x^2 + 1), "
                "not an equation or inequality"
            )
        extra = expr.free_symbols - {x}
        if extra:
            names = ", ".join(sorted(str(s) for s in extra))
            raise ValueError(f"unknown variable(s) {names}; only x is supported")

        f = lambdify(x, expr, modules=['numpy'])
        return f, expr
    except Exception as e:
        st.error(f"Error parsing equation: {str(e)}")
        return None, None
def plot_equation(f, expr, x_range=(-10, 10), points=1000):
    """Plot ``y = f(x)`` over *x_range* and return the Matplotlib figure.

    Args:
        f: Vectorised callable (from ``lambdify``) taking a NumPy array of x.
        expr: The SymPy expression; used for the legend label and title.
        x_range: ``(x_min, x_max)``. The two values must differ.
        points: Number of evenly spaced samples.

    Returns:
        A ``matplotlib.figure.Figure``, or ``None`` after reporting the
        error with ``st.error``.
    """
    try:
        x_lo, x_hi = x_range
        if x_lo == x_hi:
            raise ValueError("X-axis minimum and maximum must differ")
        x = np.linspace(x_lo, x_hi, points)

        # log(0), 1/x at 0, sqrt of negatives, etc. are expected at some
        # samples; silence NumPy's warnings and show those points as gaps.
        with np.errstate(all='ignore'):
            y = np.asarray(f(x))
        # A constant expression lambdifies to a scalar; stretch it to x.
        y = np.broadcast_to(y, x.shape)
        y = np.where(np.isfinite(y), y, np.nan)

        # A standalone Figure is not registered with pyplot, so the figure
        # built on each Streamlit rerun does not pile up in pyplot's registry.
        fig = Figure(figsize=(10, 6))
        ax = fig.subplots()
        ax.plot(x, y, '-b', label=str(expr))
        ax.grid(True, alpha=0.3)
        ax.axhline(y=0, color='k', linestyle='-', alpha=0.3)
        ax.axvline(x=0, color='k', linestyle='-', alpha=0.3)
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_title(f'Graph of {str(expr)}')
        ax.legend()
        return fig
    except Exception as e:
        st.error(f"Error plotting equation: {str(e)}")
        return None
def get_download_link(fig):
    """Return an HTML anchor that downloads *fig* as ``equation_plot.png``.

    The PNG bytes are embedded as a base64 ``data:`` URI, so no file is
    written on the server. The caller renders the returned string with
    ``st.markdown(..., unsafe_allow_html=True)``.

    Args:
        fig: Object with a Matplotlib-style ``savefig(buf, format=..., bbox_inches=...)``.

    Returns:
        The ``<a ...>`` HTML string.
    """
    with io.BytesIO() as buf:
        fig.savefig(buf, format='png', bbox_inches='tight')
        b64 = base64.b64encode(buf.getvalue()).decode()
    return f'<a href="data:image/png;base64,{b64}" download="equation_plot.png">📥 Download Plot</a>'
# Header
st.title("📈 Equation2Graph")
st.markdown("Visualize mathematical equations instantly, from symbolic or natural language input!")

# Session-state keys. The direct-input box is keyed so that callbacks can
# fill it (from the natural-language converter or an example button).
# Streamlit runs callbacks before the next rerun, when writing a widget's
# state is allowed.
EQUATION_KEY = "equation_input"
NL_KEY = "nl_input"
NL_RESULT_KEY = "nl_equation"


def _on_nl_input_change():
    """Convert the edited plain-English text and copy the result into the equation box."""
    text = st.session_state.get(NL_KEY, "")
    converted = convert_nl_to_equation(text) if text else ""
    st.session_state[NL_RESULT_KEY] = converted
    if converted:
        st.session_state[EQUATION_KEY] = converted


def _use_example(example):
    """Button callback: put *example* into the equation box."""
    st.session_state[EQUATION_KEY] = example


# Columns
col1, col2 = st.columns([2, 1])

with col1:
    # Natural language input (converted once per edit, via the callback)
    st.subheader("🗣 Describe your equation (optional)")
    nl_input = st.text_input(
        "Enter your equation in plain English:",
        placeholder="e.g., the square of x plus two times x plus one",
        key=NL_KEY,
        on_change=_on_nl_input_change,
    )
    if nl_input:
        converted = st.session_state.get(NL_RESULT_KEY, "")
        if converted:
            st.success(f"Converted to equation: {converted}")
        else:
            st.error("Failed to convert input")

    # Direct equation input; pre-filled by the converter / example buttons
    # and freely editable afterwards.
    equation = st.text_input(
        "Or enter your equation directly:",
        placeholder="e.g., x^2 + 2*x + 1",
        key=EQUATION_KEY,
    )

    # Plot settings
    with st.expander("⚙ Plot Settings"):
        col_a, col_b = st.columns(2)
        with col_a:
            x_min = st.number_input("X-axis minimum", value=-10.0)
        with col_b:
            x_max = st.number_input("X-axis maximum", value=10.0)
        points = st.slider("Number of points", 100, 2000, 1000)

    # Process and plot
    if equation:
        f, expr = parse_and_validate_equation(equation)
        # Compare with None instead of truth-testing. bool() of the SymPy
        # expression 0 is False, so "0" would never plot, and bool() of a
        # relational such as x > 1 raises TypeError.
        if f is not None and expr is not None:
            fig = plot_equation(f, expr, (x_min, x_max), points)
            if fig is not None:
                st.pyplot(fig)
                st.markdown(get_download_link(fig), unsafe_allow_html=True)
                # Release pyplot's reference so figures don't accumulate across reruns.
                plt.close(fig)
                if not any(h['equation'] == equation for h in st.session_state.history):
                    st.session_state.history.append({
                        'equation': equation,
                        'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                    })

with col2:
    st.subheader("📚 Example Equations")
    examples = [
        "x^2 + 2*x + 1",
        "sin(x) + cos(x)",
        "exp(-x^2)",
        "x^3 - 4*x",
        "tan(x)",
        "log(abs(x))",
    ]
    for ex in examples:
        # on_click runs before the rerun, so the equation box in col1 shows
        # the example as soon as the page re-renders.
        st.button(ex, key=f"example_{ex}", on_click=_use_example, args=(ex,))

    st.subheader("📖 Recent Equations")
    for item in reversed(st.session_state.history[-5:]):
        st.text(f"{item['equation']}")
        st.caption(f"Plotted on: {item['timestamp']}")
# Help text is kept at column 0 in a constant: indenting Markdown by four or
# more spaces would turn it into a code block.
HOW_TO_USE_MD = """
### Instructions:
1. Describe or type your equation
2. The graph updates automatically
3. Adjust plot settings if needed
4. Download the graph as PNG
### Natural Language Support:
- Try: "the square of x plus two times x"
- You can still type equations like: x^2 + 2*x
### Supported Math Functions:
- Basic: +, -, *, /, ^
- Trig: sin(x), cos(x), tan(x)
- Exponential/log: exp(x), log(x)
- Other: abs(x), sqrt(x)
"""

with st.expander("ℹ How to use Equation2Graph"):
    st.markdown(HOW_TO_USE_MD)

# Footer
st.markdown("---")
st.markdown("Equation2Graph | Now with 🧠 AI-powered equation parser | Made by webwhiz")