File size: 3,534 Bytes
48217bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd6e68b
48217bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd6e68b
48217bb
 
 
 
 
 
cd6e68b
 
 
 
 
48217bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
import time

import openai
import gradio as gr

# Display prefixes prepended to each speaker's messages in the chat history;
# they are stripped back out (first occurrence only) when building API prompts.
alice_prefix = 'Alice: '
bob_prefix = 'Bob: '

def seed_submit(seed, history):
    """Start the conversation by appending Alice's opening line to the history.

    Gradio click handler for the 'Start conversation' button.

    Args:
        seed: User-typed opener for Alice. Empty or whitespace-only input
            falls back to the default 'Hi Bob!' (the original only checked
            for the exact empty string, so '   ' produced a blank opener).
        history: Chatbot history as a list of [alice_msg, bob_msg] pairs.

    Returns:
        ("", new_history): the cleared textbox value and the history with the
        opener appended as an unanswered Alice turn (bob_msg is None).
    """
    opener = seed.strip() or 'Hi Bob!'
    return "", history + [[f'{alice_prefix}{opener}', None]]

def bot_update(history, key, alice_system, alice_temp, alice_model, bob_system, bob_temp, bob_model):
    """Run an automated Alice/Bob conversation, yielding after every turn.

    Generator used as a Gradio event handler: each ``yield history``
    re-renders the chatbot with the latest exchange. Turns alternate — if
    the last Alice message is unanswered it is Bob's turn, otherwise Alice
    continues the thread. The loop stops once the prompt built for the
    current speaker reaches 10 messages.

    Args:
        history: Chatbot history, a list of [alice_msg, bob_msg] pairs
            (bob_msg is None while awaiting Bob's reply).
        key: OpenAI API key passed per-request.
        alice_system / bob_system: System prompts for each agent.
        alice_temp / bob_temp: Sampling temperatures.
        alice_model / bob_model: Model names.

    Yields:
        The updated history after every completed turn.
    """
    while True:
        if history[-1][1] is None:  # last Alice message unanswered -> Bob's turn
            # From Bob's perspective Alice's lines are 'user', his own 'assistant'.
            messages = [{"role": "system", "content": bob_system}]
            for msg, resp in history:
                messages.append({"role": "user", "content": msg.replace(alice_prefix, '', 1)})
                if resp:
                    messages.append({"role": "assistant", "content": resp.replace(bob_prefix, '', 1)})
            completion = openai.ChatCompletion.create(
                api_key=key,
                model=bob_model,
                messages=messages,
                temperature=bob_temp,
            )
            bob_response = completion['choices'][0]['message']['content']
            history[-1][1] = f'{bob_prefix}{bob_response}'
        else:  # Alice's turn: roles mirrored — Bob is 'user', Alice 'assistant'
            messages = [{"role": "system", "content": alice_system}]
            for n, (msg, resp) in enumerate(history):
                if n > 0:  # skip the conversation opener
                    messages.append({"role": "assistant", "content": msg.replace(alice_prefix, '', 1)})
                if resp:
                    messages.append({"role": "user", "content": resp.replace(bob_prefix, '', 1)})
            completion = openai.ChatCompletion.create(
                api_key=key,
                model=alice_model,
                messages=messages,
                temperature=alice_temp,
            )
            alice_response = completion['choices'][0]['message']['content']
            history.append([f'{alice_prefix}{alice_response}', None])

        # Always yield the turn just produced. The original did
        # `return history` on the final iteration without yielding, and
        # Gradio only streams yielded values from generator handlers, so
        # the last message was computed but never displayed.
        yield history
        if len(messages) >= 10:
            return

# UI layout: API-key box on top; left column holds per-agent prompt/settings
# panels (Alice above Bob), right column holds the chat view and the seed input.
with gr.Blocks() as demo:
    # API key is read per-request in bot_update rather than set globally.
    key = gr.Textbox(placeholder='Enter your OpenAI API Key') 
    with gr.Row():
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    # Alice's persona: system prompt, sampling temperature, model choice.
                    alice_system = gr.Textbox(label='Alice\'s Prompt', value='Your name is Alice, you love to talk, and give your opinion', lines=5, interactive=True)
                    alice_temp = gr.Slider(value=1, minimum=0, maximum=1, step=0.1, label='Temperature')
                    alice_model = gr.Radio(choices=['gpt-4', 'gpt-3.5-turbo'], value='gpt-4')
            
            with gr.Row():
                with gr.Column():
                    # Bob's persona: mirrors Alice's controls.
                    bob_system = gr.Textbox(label='Bob\'s Prompt', value='Your name is Bob, you love to listen, and ask for opinions', lines=5, interactive=True)
                    bob_temp = gr.Slider(value=1, minimum=0, maximum=1, step=0.1, label='Temperature')
                    bob_model = gr.Radio(choices=['gpt-4', 'gpt-3.5-turbo'], value='gpt-4')

        with gr.Column():
            chatbot = gr.Chatbot()    
            seed = gr.Textbox(placeholder="How does Alice open the conversation?")            
            start = gr.Button('Start conversation')

    # Chain: seed_submit appends Alice's opener (queue=False for instant UI
    # update), then bot_update streams the automated conversation.
    start.click(seed_submit, [seed, chatbot], [seed, chatbot], queue=False).then(
        bot_update, [chatbot, key, alice_system, alice_temp, alice_model, bob_system, bob_temp, bob_model], [chatbot]
    )

# Queue is required for generator (streaming) handlers; one conversation at a time.
demo.queue(concurrency_count=1)
demo.launch(show_error=True)