File size: 3,812 Bytes
a5fd16c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os 
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import streamlit as st
from src.cot import generate_answer
from src.consistency import self_consistent_answer
from src.examples import DEFAULT_EXAMPLES
from dotenv import load_dotenv
load_dotenv()

# Page chrome.
st.set_page_config(page_title="Reasonify")
st.title("Reasoning Playground")

# Sidebar: generation settings read by the rest of the script.
st.sidebar.header("Settings")
mode = st.sidebar.selectbox("Prompting mode", ["Base", "Chain Of Thought"])
model_id = st.sidebar.selectbox("Model", ["llama3-8b-8192"])
# st.slider's positional order is (label, min_value, max_value, value), so the
# bounds and default are passed by keyword to keep the default inside the range.
temperature = st.sidebar.slider(
    "Temperature",
    min_value=0.1,
    max_value=0.7,
    value=0.5,
    step=0.1,
)
num_paths = st.sidebar.selectbox("Number of Reasoning Paths", [1, 3, 5])

# Few-shot configuration. `examples` is bound up front (and stays empty for
# Base mode and zero-shot CoT) so it can be read regardless of the chosen mode.
zero_shot = False
examples = []
if mode == "Chain Of Thought":
    zero_shot = st.sidebar.selectbox("Zero-Shot prompting", [False, True], index=0)

    if not zero_shot:
        st.sidebar.markdown("### Few-Shot Examples")

        # Offer at most three editable (question, answer) pairs. Slicing avoids
        # an IndexError when DEFAULT_EXAMPLES holds fewer than three entries.
        for i, (default_q, default_a) in enumerate(DEFAULT_EXAMPLES[:3]):
            with st.sidebar.expander(f"Example {i+1}", expanded=(i == 0)):
                q_edit = st.text_area(f"Question {i+1}", default_q, key=f"q_{i}")
                a_edit = st.text_area(f"Answer {i+1}", default_a, key=f"a_{i}")
                examples.append((q_edit, a_edit))

# Main panel: read the question, run the selected prompting strategy when the
# user clicks "Generate", then render the result.
st.markdown("### Enter your question")
question = st.text_input("Ask something...")

reason, ans = None, None

if st.button("Generate"):
    if not question.strip():
        st.warning("Please enter a question.")
    else:
        # Use a separate name for the internal mode tag so the sidebar
        # selection held in `mode` is not overwritten.
        run_mode = "cot" if mode == "Chain Of Thought" else "base"
        single_path = run_mode == "cot" and num_paths == 1

        # st.spinner is a context manager: the indicator is shown only while
        # the body of the `with` block is executing.
        with st.spinner("Thinking..."):
            if run_mode == "base":
                ans = generate_answer(question, model_id, temperature)

            elif single_path:
                cot_kwargs = {
                    "question": question,
                    "model_id": model_id,
                    "temperature": temperature,
                    "max_tokens": 200,
                    "mode": "cot",
                    "zero_shot": zero_shot,
                }
                if not zero_shot:
                    # Pass the examples as edited in the sidebar, not the
                    # untouched defaults, so the user's edits take effect.
                    cot_kwargs["exampler"] = examples
                reason, ans = generate_answer(**cot_kwargs)

            else:
                # Multiple reasoning paths: use the sidebar examples when any
                # were collected, otherwise fall back to the defaults.
                # NOTE(review): zero_shot is not forwarded to
                # self_consistent_answer here — confirm whether it should be.
                reason, ans = self_consistent_answer(
                    question=question,
                    model_id=model_id,
                    temperature=temperature,
                    max_tokens=200,
                    exampler=examples or DEFAULT_EXAMPLES,
                    num_samples=num_paths,
                )

        st.markdown("##  Output")

        if run_mode == "base":
            st.success(f"**Answer:** {ans}")
        elif single_path:
            with st.expander(f"Final Answer: {ans}"):
                st.write(reason)
        else:
            st.markdown("### Self-Consistent Reasoning Paths")
            for i, (r, a) in enumerate(zip(reason, ans)):
                with st.expander(f"Path {i+1}: {a.strip()}"):
                    st.write(r.strip())

            # Vote on whitespace-normalised answers so the tally matches what
            # is displayed. Iterating the list (not a set) makes ties resolve
            # to the earliest answer instead of depending on set ordering.
            votes = [a.strip() for a in ans]
            if votes:
                winner = max(votes, key=votes.count)
                st.success(f"**Final Answer (most common):** {winner}")
            else:
                st.warning("No reasoning paths were returned.")