| import streamlit as st |
| import google.generativeai as genai |
| import os |
| import json |
| import base64 |
| from dotenv import load_dotenv |
| from streamlit_local_storage import LocalStorage |
| import re |
| import streamlit.components.v1 as components |
| import math |
|
|
| |
# Page-level configuration; Streamlit expects this before other st.* output.
st.set_page_config(
    page_title="Math Jegna - Your AI Math Tutor",
    page_icon="🧠",
    layout="wide",
)

# Browser localStorage bridge, used below to save and restore chats.
localS = LocalStorage()
|
|
| |
def format_chat_for_download(chat_history):
    """Format the chat history as a Markdown document for download.

    Args:
        chat_history: List of message dicts with "role" and "content" keys.
            Any other keys (such as "visual_html") are ignored.

    Returns:
        A Markdown string: a title line, then one "**Speaker:**" section per
        message, each followed by a horizontal rule.
    """
    sections = ["# Math Mentor Chat\n\n"]
    for message in chat_history:
        role = "You" if message["role"] == "user" else "Math Mentor"
        sections.append(f"**{role}:**\n{message['content']}\n\n---\n\n")
    # A single join avoids quadratic string concatenation on long chats.
    return "".join(sections)
|
|
def convert_role_for_gemini(role):
    """Translate a Streamlit chat role into the role name Gemini expects.

    Streamlit's "assistant" becomes Gemini's "model". Any other role,
    such as "user", is returned unchanged.
    """
    return "model" if role == "assistant" else role
|
|
def should_generate_visual(user_prompt, ai_response):
    """Report whether a visual aid is likely to help with this exchange.

    The prompt and the response are joined with a space and lower-cased.
    The result is True as soon as any keyword below appears in that text as
    a plain substring (so "add" also matches "addition"), otherwise False.
    """
    haystack = " ".join((user_prompt, ai_response)).lower()

    keywords = (
        'add', 'subtract', 'multiply', 'times', 'divide', 'counting', 'numbers',
        'fraction', 'half', 'quarter', 'third', 'parts', 'whole',
        'shape', 'triangle', 'circle', 'square', 'rectangle',
        'money', 'coins', 'dollars', 'cents', 'change',
        'time', 'clock', 'hours', 'minutes', "o'clock",
        'measurement', 'length', 'height', 'weight',
        'place value', 'tens', 'ones', 'hundreds',
        'pattern', 'sequence', 'skip counting',
        'greater than', 'less than', 'equal', 'compare',
        'number line', 'array', 'grid',
    )

    for keyword in keywords:
        if keyword in haystack:
            return True
    return False
|
|
# Patterns used by the visual router, compiled once instead of on every call.
_HH_MM_RE = re.compile(r"(\d{1,2}):(\d{2})")
_OCLOCK_RE = re.compile(r"(\d{1,2})\s*o'clock")
_FRACTION_RE = re.compile(r"(\d+)/(\d+)")
# Accepts "3 x 4", "3 times 4", "3 * 4" and "3 × 4".
_MULTIPLY_RE = re.compile(r"(\d+)\s*(?:x|\*|×|times)\s*(\d+)")
_INTEGER_RE = re.compile(r"\d+")
# Topic-only fallbacks. Word boundaries keep "times" from matching the
# clock and "apart"/"department" from matching the fraction pizza.
_FRACTION_TOPIC_RE = re.compile(r"\bfraction|\bparts?\b")
_SHAPE_TOPIC_RE = re.compile(r"\bshape")
_MONEY_TOPIC_RE = re.compile(r"\bmoney|\bcoin")
_TIME_TOPIC_RE = re.compile(r"\btime\b|\bclock")
# Widest range of numbers drawn on a number line (the line adds 2 on each side).
_MAX_NUMBER_LINE_SPAN = 100


def create_visual_manipulative(user_prompt, ai_response):
    """-- SMART VISUAL ROUTER --
    Parse the user prompt and return HTML for the most specific visual aid.

    Checks run from most specific to least specific, and the first match wins:
    an explicit clock time, a proper fraction, a small product, a small
    addition/subtraction, a number-line request, a place-value request, and
    finally a bare topic word.

    Args:
        user_prompt: The student's message. Only this text is inspected.
        ai_response: The tutor's reply. It is currently unused and is kept so
            callers do not need to change.

    Returns:
        An HTML string for ``components.html``, or None when no visual
        applies. If building a visual raises, the error is shown with
        ``st.error`` and None is returned.
    """
    try:
        user_lower = user_prompt.lower()

        # 1. Clock times such as "3:45" or "7 o'clock". If an HH:MM match is
        #    out of range (e.g. a ratio "20:30"), the o'clock form is still tried.
        for pattern in (_HH_MM_RE, _OCLOCK_RE):
            time_match = pattern.search(user_lower)
            if time_match:
                groups = time_match.groups()
                hour = int(groups[0])
                minute = int(groups[1]) if len(groups) > 1 else 0
                if 1 <= hour <= 12 and 0 <= minute <= 59:
                    return create_clock_visual(hour, minute)

        # 2. Proper fractions with small denominators, e.g. "3/4".
        fraction_match = _FRACTION_RE.search(user_lower)
        if fraction_match:
            num, den = int(fraction_match.group(1)), int(fraction_match.group(2))
            if 0 < num <= den <= 16:
                return create_dynamic_fraction_circle(num, den)

        # 3. Small products as a dot array.
        mult_match = _MULTIPLY_RE.search(user_lower)
        if mult_match:
            rows, cols = int(mult_match.group(1)), int(mult_match.group(2))
            if rows <= 10 and cols <= 10:
                return create_multiplication_array(rows, cols)

        # 4. Small additions/subtractions as counting blocks.
        add_words = ('add', 'plus', '+')
        sub_words = ('subtract', 'minus', 'take away', '-')
        if any(word in user_lower for word in add_words + sub_words):
            numbers = _INTEGER_RE.findall(user_prompt)
            if len(numbers) >= 2:
                num1, num2 = int(numbers[0]), int(numbers[1])
                operation = 'add' if any(w in user_lower for w in add_words) else 'subtract'
                # A negative difference cannot be shown with blocks, so skip it
                # rather than rendering a wrong "a - b = 0".
                if num1 <= 20 and num2 <= 20 and (operation == 'add' or num1 >= num2):
                    return create_counting_blocks(num1, num2, operation)

        # 5. Explicit number-line requests.
        if 'number line' in user_lower:
            numbers = [int(n) for n in _INTEGER_RE.findall(user_prompt)]
            # Bound the span so one huge number cannot produce a giant SVG.
            if numbers and max(numbers) - min(numbers) <= _MAX_NUMBER_LINE_SPAN:
                return create_number_line(
                    min(numbers) - 2, max(numbers) + 2, numbers, "Your Numbers on the Line"
                )

        # 6. Place value for numbers up to 999.
        if 'place value' in user_lower:
            numbers = _INTEGER_RE.findall(user_prompt)
            if numbers and int(numbers[0]) <= 999:
                return create_place_value_blocks(int(numbers[0]))

        # 7. Topic-only fallbacks, matched on word boundaries.
        if _FRACTION_TOPIC_RE.search(user_lower):
            return create_dynamic_fraction_circle(1, 2)
        if _SHAPE_TOPIC_RE.search(user_lower):
            return create_shape_explorer()
        if _MONEY_TOPIC_RE.search(user_lower):
            return create_money_counter()
        if _TIME_TOPIC_RE.search(user_lower):
            return create_clock_visual(10, 10)

        return None

    except Exception as e:
        # UI boundary: a broken visual must not break the chat reply.
        st.error(f"Could not create visual: {e}")
        return None
|
|
| |
|
|
def create_counting_blocks(num1, num2, operation):
    """(Dynamic) Create colorful counting blocks for addition/subtraction.

    Args:
        num1: First number, drawn as red blocks.
        num2: Second number, drawn as teal blocks.
        operation: 'add' for addition. Any other value means subtraction.

    Returns:
        An HTML string showing num1, num2 and the answer as groups of blocks.
        For a subtraction whose answer would be negative, a short message is
        returned instead. Blocks cannot show a negative amount, and an answer
        clamped to 0 would display a wrong equation.
    """
    is_addition = operation == 'add'
    if not is_addition and num1 < num2:
        return "<p>I can only show take-away problems when the first number is at least as big as the second!</p>"

    symbol = '+' if is_addition else '−'  # U+2212 MINUS SIGN
    answer = num1 + num2 if is_addition else num1 - num2

    def blocks(count, color):
        # Build `count` identical 25px squares in the given color.
        return ''.join(
            f'<div style="width: 25px; height: 25px; background: {color}; border-radius: 5px;"></div>'
            for _ in range(count)
        )

    label_style = 'width: 100%; text-align:center; color: white; font-weight: bold;'
    html = f"""
<div style="padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 15px; margin: 10px 0;">
<h3 style="color: white; text-align: center; margin-bottom: 20px;">🧮 Counting Blocks: {num1} {symbol} {num2}</h3>
<div style="display: flex; justify-content: center; align-items: center; gap: 20px; flex-wrap: wrap;">
<!-- Blocks for Num1 -->
<div style="display: flex; flex-wrap: wrap; gap: 5px; border: 2px dashed #FFE066; padding: 5px; border-radius: 5px; align-items: center; justify-content: center; min-width: 100px;"><div style="{label_style}">{num1}</div>{blocks(num1, '#FF6B6B')}</div>
<div style="font-size: 40px; color: #FFE066;">{symbol}</div>
<!-- Blocks for Num2 -->
<div style="display: flex; flex-wrap: wrap; gap: 5px; border: 2px dashed #FFE066; padding: 5px; border-radius: 5px; align-items: center; justify-content: center; min-width: 100px;"><div style="{label_style}">{num2}</div>{blocks(num2, '#4ECDC4')}</div>
<div style="font-size: 40px; color: #FFE066;">=</div>
<!-- Blocks for Answer -->
<div style="display: flex; flex-wrap: wrap; gap: 5px; border: 2px solid white; background: rgba(255,255,255,0.2); padding: 5px; border-radius: 5px; align-items: center; justify-content: center; min-width: 100px;"><div style="{label_style}">{answer}</div>{blocks(answer, '#95E1D3')}</div>
</div>
</div>"""
    return html
|
|
def create_dynamic_fraction_circle(numerator, denominator):
    """(Dynamic) Generates an SVG of a pizza/pie to represent a fraction.

    Args:
        numerator: Number of shaded (red) slices.
        denominator: Number of equal slices the pizza is cut into.

    Returns:
        An HTML string with the pizza SVG and a caption. If the fraction is
        not a proper fraction with 0 < numerator <= denominator, a short
        message is returned instead.
    """
    if not (0 < numerator <= denominator):
        return "<p>I can only show proper fractions!</p>"
    width, height, radius = 150, 150, 60
    cx, cy = width / 2, height / 2

    if denominator == 1:
        # A single 360-degree arc starts and ends at the same point, and SVG
        # omits an arc whose endpoints coincide, so the whole pizza would not
        # render. Draw one full circle instead.
        slices_html = f'<circle cx="{cx}" cy="{cy}" r="{radius}" fill="#FF6B6B" stroke="#333" stroke-width="2"/>'
    else:
        angle_step = 360 / denominator
        slices = []
        for i in range(denominator):
            # Offset by -90 degrees so the first slice starts at 12 o'clock.
            start_rad = math.radians(i * angle_step - 90)
            end_rad = math.radians((i + 1) * angle_step - 90)
            x1, y1 = cx + radius * math.cos(start_rad), cy + radius * math.sin(start_rad)
            x2, y2 = cx + radius * math.cos(end_rad), cy + radius * math.sin(end_rad)
            fill_color = "#FF6B6B" if i < numerator else "#DDDDDD"
            # With denominator >= 2 each slice spans at most 180 degrees, so
            # the large-arc flag is 0. Sweep flag 1 draws the arc clockwise on screen.
            path_d = f"M {cx},{cy} L {x1:.2f},{y1:.2f} A {radius},{radius} 0 0,1 {x2:.2f},{y2:.2f} Z"
            slices.append(f'<path d="{path_d}" fill="{fill_color}" stroke="#333" stroke-width="2"/>')
        slices_html = "".join(slices)

    html = f"""<div style="padding: 20px; background: linear-gradient(135deg, #A8EDEA 0%, #FED6E3 100%); border-radius: 15px; margin: 10px 0;"><h3 style="color: #333; text-align: center;">Fraction Pizza: {numerator}/{denominator}</h3><div style="display: flex; justify-content: center;"><svg width="{width}" height="{height}">{slices_html}</svg></div><p style="color: #333; text-align: center; margin-top: 15px; font-size: 18px;">The pizza is cut into <b>{denominator}</b> equal slices, and we are showing <b>{numerator}</b> of them! 🍕</p></div>"""
    return html
|
|
def create_clock_visual(hours, minutes):
    """(Dynamic) Create a clock showing a specific time.

    Args:
        hours: Hour to display (the router passes 1-12; 12 points straight up).
        minutes: Minute to display (0-59).

    Returns:
        An HTML string with an SVG analog clock showing all twelve numerals
        and a zero-padded HH:MM caption.
    """
    min_angle = minutes * 6                          # 360 degrees / 60 minutes
    hour_angle = (hours % 12 + minutes / 60) * 30    # 360 degrees / 12 hours, plus drift

    # Place every numeral (not just 12/3/6/9) on a ring of radius 82 around
    # the centre (125, 125), so the minute hand's position can be read too.
    numerals = []
    for n in range(1, 13):
        rad = math.radians(n * 30)
        x = 125 + 82 * math.sin(rad)
        y = 125 - 82 * math.cos(rad)
        numerals.append(
            f'<text x="{x:.1f}" y="{y:.1f}" text-anchor="middle" dominant-baseline="central" '
            f'font-size="20" font-weight="bold" fill="#333">{n}</text>'
        )
    numerals_html = "".join(numerals)

    html = f"""<div style="padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 15px; margin: 10px 0;"><h3 style="color: white; text-align: center; margin-bottom: 20px;">🕐 Learning Time!</h3><div style="display: flex; justify-content: center;"><svg width="250" height="250" viewBox="0 0 250 250" style="background: white; border-radius: 50%; border: 8px solid #FFE066;"><circle cx="125" cy="125" r="110" fill="white" stroke="#333" stroke-width="2"/>{numerals_html}<line x1="125" y1="125" x2="125" y2="40" stroke="#FF6B6B" stroke-width="6" stroke-linecap="round" transform="rotate({hour_angle}, 125, 125)"/><line x1="125" y1="125" x2="125" y2="25" stroke="#4ECDC4" stroke-width="4" stroke-linecap="round" transform="rotate({min_angle}, 125, 125)"/><circle cx="125" cy="125" r="8" fill="#333"/></svg></div><div style="text-align: center; margin-top: 20px;"><p style="color: #FFE066; font-size: 24px; font-weight: bold;">This clock shows {hours:02d}:{minutes:02d}</p><p style="color: white; font-size: 16px;">The short <span style="color:#FF6B6B">red</span> hand points to the hour. The long <span style="color:#4ECDC4">blue</span> hand points to the minutes.</p></div></div>"""
    return html
|
|
def create_multiplication_array(rows, cols):
    """(NEW & Dynamic) Generates an SVG grid of dots to show multiplication.

    Args:
        rows: Number of dot rows (first factor).
        cols: Number of dots in each row (second factor).

    Returns:
        An HTML string with a rows x cols dot array and the product.
    """
    cell_size, gap = 25, 5
    pitch = cell_size + gap  # distance between neighbouring dot centres
    svg_width = cols * pitch
    svg_height = rows * pitch
    dots_html = "".join(
        f'<circle cx="{c * pitch + cell_size / 2}" cy="{r * pitch + cell_size / 2}" r="{cell_size / 2 - 2}" fill="#FF6B6B"/>'
        for r in range(rows)
        for c in range(cols)
    )
    product = rows * cols
    html = f"""<div style="padding: 20px; background: linear-gradient(135deg, #FF9A9E 0%, #FECFEF 100%); border-radius: 15px; margin: 10px 0;"><h3 style="color:#333; text-align: center;">Multiplication Array: {rows} × {cols} = {product}</h3><div style="display: flex; justify-content: center; padding: 10px;"><svg width="{svg_width}" height="{svg_height}">{dots_html}</svg></div><p style="color: #333; text-align: center; font-size: 18px;">See? There are <b>{rows}</b> rows of <b>{cols}</b> dots. That's <b>{product}</b> dots in total!</p></div>"""
    return html
|
|
def create_number_line(start, end, points, title="Number Line", max_ticks=20):
    """(NEW & Dynamic) Creates a simple number line SVG.

    Args:
        start: Leftmost integer on the line.
        end: Rightmost integer. If it is not greater than start, it is
            bumped to start + 1.
        points: Numbers to highlight with a labelled red dot.
        title: Heading shown above the line.
        max_ticks: Maximum number of tick intervals. Spans up to this size get
            a tick on every integer (identical to the original behavior).
            Wider spans use a step of 1, 2, 5, 10, 20, 50, ... so labels do
            not overlap and the SVG size stays bounded.

    Returns:
        An HTML string containing the SVG number line.
    """
    width = 600
    padding = 30

    if start >= end:
        end = start + 1
    span = end - start
    scale = (width - 2 * padding) / span

    def to_x(n):
        # Map a number to its horizontal pixel position on the line.
        return padding + (n - start) * scale

    # Smallest "nice" step (1, 2, 5 x 10^k) giving at most max_ticks intervals.
    max_ticks = max(1, max_ticks)
    k = 0
    step = 1
    while span > step * max_ticks:
        k += 1
        step = (1, 2, 5)[k % 3] * 10 ** (k // 3)
    first_tick = -(-start // step) * step  # smallest multiple of step >= start

    ticks_html = "".join(
        f'<g transform="translate({to_x(i)}, 50)"><line y2="10" stroke="#aaa"/><text y="30" text-anchor="middle" fill="#555">{i}</text></g>'
        for i in range(first_tick, end + 1, step)
    )
    points_html = "".join(
        f'<g transform="translate({to_x(p)}, 50)"><circle r="8" fill="#FF6B6B" stroke="white" stroke-width="2"/><text y="-15" text-anchor="middle" font-weight="bold" fill="#D63031">{p}</text></g>'
        for p in points
    )
    html = f"""<div style="padding: 20px; background: #f7f1e3; border-radius: 15px; margin: 10px 0;"><h3 style="text-align: center; color: #333;">{title}</h3><svg width="{width}" height="100"><line x1="{padding}" y1="50" x2="{width-padding}" y2="50" stroke="#333" stroke-width="2"/>{ticks_html}{points_html}</svg></div>"""
    return html
|
|
def create_place_value_blocks(number):
    """(FIXED & Dynamic) Create place value blocks for understanding numbers.

    The number is split into hundreds, tens and ones. Each hundred is drawn
    as a 10x10 flat, each ten as a 10-cell stick and each one as a small
    cube. A column whose digit is 0 is left out. A summary line such as
    "1 Hundreds + 2 Tens + 3 Ones = 123" is appended.

    Args:
        number: Integer to decompose (the router only sends values <= 999).

    Returns:
        An HTML string.
    """
    hundreds, below_hundred = divmod(number, 100)
    tens, ones = divmod(below_hundred, 10)

    def labelled_column(label, count, row_style, pieces_html):
        # One titled column: a heading with the digit, then its row of pieces.
        return f"""
<div style="text-align: center;">
<h4>{label}: {count}</h4>
<div style="{row_style}">{pieces_html}</div>
</div>
"""

    h_block_html = t_block_html = o_block_html = ""

    if hundreds > 0:
        flat_cells = "<div style='background:#F5A6A6'></div>" * 100
        flat = f"""
<div style="width: 100px; height: 100px; background: #FF6B6B; border: 2px solid #D63031; display: grid; grid-template-columns: repeat(10, 1fr); gap: 2px; padding: 2px;">
{flat_cells}
</div>
"""
        h_block_html = labelled_column(
            "Hundreds", hundreds, "display: flex; gap: 5px;", flat * hundreds
        )

    if tens > 0:
        stick_cells = "<div style='background:#A2E8E4'></div>" * 10
        stick = f"""
<div style="width: 10px; height: 100px; background: #4ECDC4; border: 2px solid #00B894; display: grid; grid-template-rows: repeat(10, 1fr); gap: 2px; padding: 2px;">
{stick_cells}
</div>
"""
        t_block_html = labelled_column(
            "Tens", tens, "display: flex; gap: 5px; align-items: flex-end;", stick * tens
        )

    if ones > 0:
        cube = '<div style="width: 10px; height: 10px; background: #FFE066; border: 2px solid #FDCB6E;"></div>'
        o_block_html = labelled_column(
            "Ones",
            ones,
            "display: flex; gap: 5px; align-items: flex-end; flex-wrap: wrap; width: 50px; justify-content: center;",
            cube * ones,
        )

    return f"""
<div style="padding: 20px; background: linear-gradient(135deg, #dfe6e9 0%, #b2bec3 100%); border-radius: 15px; margin: 10px 0;">
<h3 style="color: #333; text-align: center;">Place Value Blocks for {number}</h3>
<div style="display: flex; justify-content: center; align-items: flex-end; gap: 20px; flex-wrap: wrap; padding: 20px 0; min-height: 150px;">
{h_block_html}
{t_block_html}
{o_block_html}
</div>
<div style="text-align: center; margin-top: 15px; padding: 10px; background: rgba(0,0,0,0.1); border-radius: 10px;">
<h4 style="color: #333; margin:0;">
{hundreds} Hundreds + {tens} Tens + {ones} Ones = {number}
</h4>
</div>
</div>
"""
|
|
def create_shape_explorer():
    """(Static) Create colorful shape recognition tool.

    Returns:
        An HTML string with one card (name, SVG drawing, short hint) per
        shape: circle, square, triangle and rectangle.
    """
    # (name, SVG drawing inside an 80x80 canvas, hint for the child)
    shapes = (
        ("Circle",
         '<circle cx="40" cy="40" r="35" fill="#FF6B6B" stroke="#333" stroke-width="3"/>',
         "Round and smooth!"),
        ("Square",
         '<rect x="12.5" y="12.5" width="55" height="55" fill="#4ECDC4" stroke="#333" stroke-width="3"/>',
         "4 equal sides!"),
        ("Triangle",
         '<polygon points="40,15 15,65 65,65" fill="#FFD93D" stroke="#333" stroke-width="3"/>',
         "3 sides and corners!"),
        ("Rectangle",
         '<rect x="10" y="25" width="60" height="30" fill="#95E1D3" stroke="#333" stroke-width="3"/>',
         "4 sides, opposite sides equal!"),
    )
    cards_html = "".join(
        '<div style="text-align: center; padding: 15px; background: white; border-radius: 10px; box-shadow: 0 4px 8px rgba(0,0,0,0.1);">'
        f'<h4 style="color: #333; margin-bottom: 10px;">{name}</h4>'
        f'<svg width="80" height="80">{drawing}</svg>'
        f'<p style="color: #666; font-size: 12px; margin-top: 10px;">{hint}</p></div>'
        for name, drawing, hint in shapes
    )
    html = (
        '<div style="padding: 20px; background: linear-gradient(135deg, #A8EDEA 0%, #FED6E3 100%); border-radius: 15px; margin: 10px 0;">'
        '<h3 style="color: #333; text-align: center; margin-bottom: 20px;">🔷 Shape Explorer!</h3>'
        '<div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(150px, 1fr)); gap: 20px; max-width: 600px; margin: 0 auto;">'
        f'{cards_html}</div>'
        '<p style="color: #333; text-align: center; margin-top: 20px; font-size: 18px;">Can you find these shapes around you? 🔍✨</p></div>'
    )
    return html
|
|
def create_money_counter():
    """(Static) Create coin counting visual.

    Returns:
        An HTML string with one card per US coin (penny, nickel, dime,
        quarter). Each card shows a colored disc labelled with the coin's value.
    """
    # (name, disc size in px, fill, border, label color, label, caption)
    coins = (
        ("Penny", 50, "#CD7F32", "#8B4513", "white", "1¢", "1 cent"),
        ("Nickel", 55, "#C0C0C0", "#808080", "#333", "5¢", "5 cents"),
        ("Dime", 45, "#C0C0C0", "#808080", "#333", "10¢", "10 cents"),
        ("Quarter", 60, "#C0C0C0", "#808080", "#333", "25¢", "25 cents"),
    )
    cards_html = "".join(
        '<div style="text-align: center; padding: 15px; background: white; border-radius: 10px;">'
        f'<h4 style="color: #333;">{name}</h4>'
        f'<div style="width: {size}px; height: {size}px; background: {fill}; border-radius: 50%; margin: 10px auto; '
        f'display: flex; align-items: center; justify-content: center; border: 3px solid {border};">'
        f'<span style="color: {text_color}; font-weight: bold;">{label}</span></div>'
        f'<p style="color: #666; font-size: 12px;">{caption}</p></div>'
        for name, size, fill, border, text_color, label, caption in coins
    )
    html = (
        '<div style="padding: 20px; background: linear-gradient(135deg, #FFE259 0%, #FFA751 100%); border-radius: 15px; margin: 10px 0;">'
        '<h3 style="color: #333; text-align: center; margin-bottom: 20px;">💰 Money Counter!</h3>'
        f'<div style="display: flex; justify-content: center; gap: 30px; flex-wrap: wrap;">{cards_html}</div>'
        '<p style="color: #333; text-align: center; margin-top: 20px; font-size: 18px;">Practice counting coins to make different amounts! 🪙✨</p></div>'
    )
    return html
|
|
| |
| |
|
|
| |
load_dotenv()  # copy variables from a local .env file into os.environ, if present
api_key = None

# Prefer Streamlit secrets; fall back to the environment.
try:
    api_key = st.secrets["GOOGLE_API_KEY"]
except Exception:
    # Best-effort lookup. KeyError means the secret is missing, and
    # FileNotFoundError means there is no secrets file. NOTE(review): newer
    # Streamlit versions may raise their own error type when no secrets file
    # exists, which is why this catches Exception; confirm for the pinned version.
    api_key = os.getenv("GOOGLE_API_KEY")

# Tutor persona and guardrails sent to Gemini as the system instruction.
_SYSTEM_INSTRUCTION = """
You are "Math Jegna", an AI specializing exclusively in K-12 mathematics.
Your one and only function is to solve and explain math problems for children.
You are an AI math tutor that uses the Professor B methodology developed by Everard Barrett. This methodology is designed to activate children's natural learning capacities and present mathematics as a contextual, developmental story that makes sense.

IMPORTANT: When explaining mathematical concepts to young learners, mention that colorful visual aids will be provided to help illustrate the concept. Use phrases like:
- "Let me show you this with some colorful blocks..."
- "A fun visual will help you see how this works..."
- "I'll create a picture to help you understand this fraction..."

Focus on concepts appropriate for K-12 students:
- Basic counting and number recognition
- Simple addition and subtraction (using manipulatives)
- Multiplication as arrays or groups
- Basic shapes and geometry
- Place value with hundreds, tens, ones
- Money counting and coin recognition
- Time telling with analog clocks
- Simple patterns and sequences
- Basic measurement concepts

Always use age-appropriate language and relate math to real-world examples children understand.

Core Philosophy and Principles
1. Contextual Learning Approach
Present math as a story: Every mathematical concept should be taught as part of a continuing narrative that builds connections between ideas
Use concrete manipulatives: Always relate abstract concepts to physical, visual representations
Truth-telling: Present arithmetic computations simply and truthfully without confusing steps

2. Natural Learning Activation
Leverage natural capacities: Recognize that each child has mental capabilities designed to learn naturally
Story-based retention: Use stories and visual representations that children can easily remember
Reduced anxiety: Make math fun and engaging, not scary or confusing

3. Hands-on Learning
Mental gymnastics: Use finger counting, visual blocks, and interactive elements
No rote memorization: Focus on understanding through play and exploration
Build confidence: Celebrate small victories and progress

You are strictly forbidden from answering any question that is not mathematical in nature.
If you receive a non-mathematical question, you MUST decline with: "I can only answer math questions for students. Please ask me about numbers, shapes, counting, or other math topics!"

Keep explanations simple, encouraging, and fun for young learners.
"""

if not api_key:
    st.error("🚨 Google API Key not found! Please add it to your secrets or a local .env file.")
    st.stop()  # halts this script run; nothing below executes without a key

genai.configure(api_key=api_key)
# NOTE(review): confirm "gemini-1.5-flash" is still served for this API key.
model = genai.GenerativeModel(
    model_name="gemini-1.5-flash",
    system_instruction=_SYSTEM_INSTRUCTION,
)
|
|
| |
# Greeting shown in a brand-new chat.
_WELCOME_MESSAGE = (
    "Hello! I'm Math Jegna, your friendly math helper! 🧠✨ I love helping "
    "students learn math with colorful pictures and fun activities. What would "
    "you like to learn about today? Maybe counting, shapes, or solving a math problem? 😊"
)


def _sanitize_chat(raw_messages):
    """Keep only well-formed {"role", "content"} messages from untrusted data.

    Shared chats arrive in a URL, so their contents are attacker-controlled.
    "visual_html" is deliberately dropped: it is rendered with
    components.html and would otherwise allow arbitrary HTML/JS injection.

    Raises:
        ValueError: if raw_messages is not a list.
    """
    if not isinstance(raw_messages, list):
        raise ValueError("Shared chat must be a list of messages")
    return [
        {"role": m["role"], "content": m["content"]}
        for m in raw_messages
        if isinstance(m, dict)
        and m.get("role") in ("user", "assistant")
        and isinstance(m.get("content"), str)
    ]


def _load_shared_chat():
    """Return the sanitized chat from ?shared_chat=..., or None if absent or invalid."""
    shared_chat_b64 = st.query_params.get("shared_chat")
    if not shared_chat_b64:
        return None
    try:
        # Restore base64 padding in case the link was copied without the trailing "=".
        padded = shared_chat_b64 + "=" * (-len(shared_chat_b64) % 4)
        decoded_chat_json = base64.urlsafe_b64decode(padded).decode("utf-8")
        messages = _sanitize_chat(json.loads(decoded_chat_json))
    except (TypeError, ValueError):
        # binascii.Error, UnicodeDecodeError and JSONDecodeError are all
        # ValueError subclasses. A malformed link falls back to normal start-up.
        return None
    finally:
        st.query_params.clear()
    return messages or None


def _load_saved_chats():
    """Return (chats, active_key) from browser localStorage, or None if unusable."""
    saved_data_json = localS.getItem("math_mentor_chats")
    if not saved_data_json:
        return None
    try:
        saved_data = json.loads(saved_data_json)
    except (TypeError, ValueError):
        return None
    if not isinstance(saved_data, dict):
        return None
    chats = saved_data.get("chats")
    if not isinstance(chats, dict) or not chats:
        return None
    active_key = saved_data.get("active_chat_key")
    if active_key not in chats:
        # A stale key would break the later chats[active_chat_key] lookups.
        active_key = next(iter(chats))
    return chats, active_key


# First run of this session: a shared link wins, then saved chats, then a fresh chat.
if "chats" not in st.session_state:
    shared_chat = _load_shared_chat()
    saved = None if shared_chat else _load_saved_chats()
    if shared_chat:
        st.session_state.chats = {"Shared Chat": shared_chat}
        st.session_state.active_chat_key = "Shared Chat"
    elif saved:
        st.session_state.chats, st.session_state.active_chat_key = saved
    else:
        st.session_state.chats = {
            "New Chat": [{"role": "assistant", "content": _WELCOME_MESSAGE}]
        }
        st.session_state.active_chat_key = "New Chat"
|
|
| |
@st.dialog("Rename Chat")
def rename_chat(chat_key):
    """Dialog that renames chat_key in st.session_state.chats.

    Surrounding whitespace in the new name is ignored. An empty name or an
    existing name is rejected with an error. The chat keeps its position in
    the sidebar, and the active chat only changes if it was the renamed one.
    """
    st.write(f"Enter a new name for '{chat_key}':")
    new_name = st.text_input("New Name", key=f"rename_input_{chat_key}").strip()
    if st.button("Save", key=f"save_rename_{chat_key}"):
        if not new_name:
            st.error("Name cannot be empty.")
        elif new_name in st.session_state.chats:
            st.error("A chat with this name already exists.")
        else:
            # Rebuild the dict so the chat keeps its place in the sidebar order
            # (pop + insert would move it to the end).
            st.session_state.chats = {
                (new_name if key == chat_key else key): messages
                for key, messages in st.session_state.chats.items()
            }
            if st.session_state.active_chat_key == chat_key:
                st.session_state.active_chat_key = new_name
            st.rerun()
|
|
| |
@st.dialog("Delete Chat")
def delete_chat(chat_key):
    """Confirmation dialog that deletes chat_key.

    If the deleted chat was active, the first remaining chat becomes active.
    If no chats remain, a fresh "New Chat" is created and made active.
    """
    st.warning(f"Are you sure you want to delete '{chat_key}'? This cannot be undone.")
    if st.button("Yes, Delete", type="primary", key=f"confirm_delete_{chat_key}"):
        # Default of None: a repeated submit for an already-removed chat must not raise.
        st.session_state.chats.pop(chat_key, None)

        if st.session_state.active_chat_key == chat_key:
            if st.session_state.chats:
                st.session_state.active_chat_key = next(iter(st.session_state.chats))
            else:
                # Never leave the app with zero chats.
                st.session_state.chats["New Chat"] = [
                    {"role": "assistant", "content": "Hello! Let's start a new math adventure! 🚀"}
                ]
                st.session_state.active_chat_key = "New Chat"
        st.rerun()
|
|
| |
with st.sidebar:
    st.title("🧮 Math Jegna")
    st.write("Your K-8 AI Math Tutor")
    st.divider()

    # One row per chat: switch (highlighted when active), rename, delete.
    # Iterate over a snapshot because the dialogs can modify the dict.
    for chat_key in list(st.session_state.chats.keys()):
        col1, col2, col3 = st.columns([0.6, 0.2, 0.2])
        with col1:
            button_type = "primary" if st.session_state.active_chat_key == chat_key else "secondary"
            if st.button(chat_key, key=f"switch_{chat_key}", use_container_width=True, type=button_type):
                st.session_state.active_chat_key = chat_key
                st.rerun()
        with col2:
            if st.button("✏️", key=f"rename_{chat_key}", help="Rename Chat"):
                rename_chat(chat_key)
        with col3:
            if st.button("🗑️", key=f"delete_{chat_key}", help="Delete Chat"):
                delete_chat(chat_key)

    if st.button("➕ New Chat", use_container_width=True):
        # Take the first free "Chat N" name instead of appending "*" to a taken one.
        chat_number = len(st.session_state.chats) + 1
        while f"Chat {chat_number}" in st.session_state.chats:
            chat_number += 1
        new_chat_name = f"Chat {chat_number}"
        st.session_state.chats[new_chat_name] = [
            {"role": "assistant", "content": "Ready for a new math problem! What's on your mind? 😊"}
        ]
        st.session_state.active_chat_key = new_chat_name
        st.rerun()

    st.divider()

    # Persist every chat plus the active key to browser localStorage.
    if st.button("💾 Save Chats", use_container_width=True):
        data_to_save = {
            "chats": st.session_state.chats,
            "active_chat_key": st.session_state.active_chat_key,
        }
        localS.setItem("math_mentor_chats", json.dumps(data_to_save))
        st.toast("Chats saved to your browser!", icon="✅")

    # Markdown export of the active chat.
    active_chat_history = st.session_state.chats[st.session_state.active_chat_key]
    download_str = format_chat_for_download(active_chat_history)
    # Keep only word characters and "-" so a chat name cannot put path
    # separators or other odd characters into the file name.
    safe_file_stem = re.sub(r"[^\w-]+", "_", st.session_state.active_chat_key)
    st.download_button(
        label="📥 Download Chat",
        data=download_str,
        file_name=f"{safe_file_stem}_history.md",
        mime="text/markdown",
        use_container_width=True,
    )

    # Encode the active chat into a shareable URL.
    if st.button("🔗 Share Chat", use_container_width=True):
        # Share only role/content: the stored visual HTML would bloat the URL.
        shareable_chat = [
            {"role": m["role"], "content": m["content"]} for m in active_chat_history
        ]
        chat_b64 = base64.urlsafe_b64encode(json.dumps(shareable_chat).encode()).decode()
        # TODO: replace YOUR_SPACE_HERE with the deployed app's URL.
        share_url = f"https://huggingface.co/spaces/YOUR_SPACE_HERE?shared_chat={chat_b64}"
        st.code(share_url)
        st.info("Copy the URL above to share this specific chat! (You might need to update the base URL)")
|
|
|
|
active_key = st.session_state.active_chat_key
st.header(f"Chatting with Math Jegna: _{active_key}_")

# Replay the stored conversation, re-rendering any saved visual aid inside
# the same chat bubble as its message.
for entry in st.session_state.chats[active_key]:
    with st.chat_message(entry["role"]):
        st.markdown(entry["content"])
        saved_visual = entry.get("visual_html")
        if saved_visual:
            components.html(saved_visual, height=400, scrolling=True)
|
|
| |
if prompt := st.chat_input("Ask a K-8 math question..."):
    # Same list object as in session_state, so appends below persist.
    active_messages = st.session_state.chats[st.session_state.active_chat_key]

    # Build the Gemini history BEFORE appending the new prompt:
    # send_message(prompt) adds the prompt itself, so including it here too
    # would send it to the model twice.
    gemini_chat_history = [
        {"role": convert_role_for_gemini(m["role"]), "parts": [m["content"]]}
        for m in active_messages
    ]

    active_messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        with st.spinner("Math Jegna is thinking..."):
            try:
                chat_session = model.start_chat(history=gemini_chat_history)
                response = chat_session.send_message(prompt, stream=True)

                # Stream the reply into one placeholder with a typing cursor.
                full_response = ""
                response_container = st.empty()
                for chunk in response:
                    full_response += chunk.text
                    response_container.markdown(full_response + " ▌")
                response_container.markdown(full_response)

                # Optionally attach a visual aid chosen from the prompt.
                visual_html_content = None
                if should_generate_visual(prompt, full_response):
                    visual_html_content = create_visual_manipulative(prompt, full_response)
                    if visual_html_content:
                        components.html(visual_html_content, height=400, scrolling=True)

                active_messages.append({
                    "role": "assistant",
                    "content": full_response,
                    "visual_html": visual_html_content,
                })

            except genai.types.generation_types.BlockedPromptException:
                error_message = "I can only answer math questions for students. Please ask me about numbers, shapes, or other math topics!"
                st.error(error_message)
                active_messages.append({"role": "assistant", "content": error_message, "visual_html": None})
            except Exception as e:
                # UI boundary: show the failure instead of crashing the app.
                st.error(f"An error occurred: {e}")