File size: 1,801 Bytes
a735c63
 
b958ec4
a735c63
fdb91dd
48e3923
e293e40
 
48e3923
 
 
 
 
 
 
 
e293e40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b62ff1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import streamlit as st
import textgrad as tg
import os

# Build the GPT-4o engine and register it as TextGrad's global backward engine.
# Streamlit re-executes this script top to bottom on every interaction, so this
# runs on each rerun; override=True presumably lets the engine be re-registered
# without complaint (TODO confirm against textgrad's set_backward_engine).
_backward_engine = tg.get_engine("gpt-4o")
tg.set_backward_engine(_backward_engine, override=True)

# Hardcoded examples
default_initial_solution = """To solve the equation 3x^2 - 7x + 2 = 0, we use the quadratic formula:
x = (-b ± √(b^2 - 4ac)) / 2a
a = 3, b = -7, c = 2
x = (7 ± √((-7)^2 + 4(3)(2))) / 6
x = (7 ± √73) / 6
The solutions are:
x1 = (7 + √73)
x2 = (7 - √73)"""

default_loss_system_prompt = """You will evaluate a solution to a math question. 
Do not attempt to solve it yourself, do not give a solution, only identify errors. Be super concise."""

def _load_example(state_key, text):
    """Overwrite the session-state entry *state_key* with *text*.

    Used as a button ``on_click`` callback. Callbacks run before the script
    reruns, so the text area keyed by *state_key* renders *text* on the next
    run, even if the user had edited it in the meantime.
    """
    st.session_state[state_key] = text


# Display clickable examples. Writing to the widget's key through a callback
# (instead of passing a changing `value=` to an unkeyed widget) makes a
# repeated click restore the example after the user has edited the text area.
st.button("Use Example Initial Solution",
          on_click=_load_example,
          args=("initial_solution", default_initial_solution))

st.button("Use Example Loss System Prompt",
          on_click=_load_example,
          args=("loss_system_prompt", default_loss_system_prompt))

# Input boxes. The keys bind each text area's content to st.session_state;
# when a key is not set yet, the text area starts empty ("").
initial_solution = st.text_area("Initial Solution", key="initial_solution")
loss_system_prompt = st.text_area("Loss System Prompt", key="loss_system_prompt")
num_epochs = st.number_input("Epochs", min_value=1, value=1)

# Enter button: run TextGrad optimization on the entered solution and show
# the revised text.
if st.button("Enter"):
    if not initial_solution.strip():
        # Nothing to optimize; avoid spending LLM calls on an empty variable.
        st.warning("Please provide an initial solution before running.")
    else:
        # The solution text is the only trainable variable.
        solution = tg.Variable(initial_solution,
                               requires_grad=True,
                               role_description="solution to the math question")

        # The evaluation prompt is held fixed (requires_grad=False).
        loss_fn = tg.TextLoss(tg.Variable(loss_system_prompt,
                                          requires_grad=False,
                                          role_description="system prompt"))
        optimizer = tg.TGD([solution])

        # Training loop: each epoch critiques the current solution and
        # rewrites it.
        with st.spinner("Optimizing solution..."):
            for _ in range(num_epochs):
                # Clear feedback from the previous epoch. Without this, textual
                # gradients accumulate, and each step is also driven by
                # critiques of earlier, already-revised versions of the text.
                optimizer.zero_grad()
                loss = loss_fn(solution)
                loss.backward()
                optimizer.step()

        # Output box
        st.text_area("Result", solution.value)