File size: 3,254 Bytes
a19918d
 
 
 
1fe665e
a19918d
 
1fe665e
 
 
a19918d
 
 
 
 
 
e216e88
a19918d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77f4a16
a19918d
 
 
 
 
77f4a16
 
a19918d
77f4a16
 
 
 
a19918d
77f4a16
 
 
 
 
 
 
 
a19918d
77f4a16
 
 
 
 
a19918d
 
 
 
 
 
 
 
 
 
 
 
 
1fe665e
 
 
a19918d
 
1fe665e
 
 
 
 
 
 
 
e216e88
1fe665e
 
a19918d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import aiohttp
import asyncio
import json
import gradio as gr
import os

# Define the API URLs
BASE_URL = os.getenv("URL")  # Ensure the environment variable "URL" is set
if not BASE_URL:
    raise ValueError("Environment variable 'URL' not set")
TOKEN_URL = BASE_URL + "/get-token"  # endpoint that issues a session token
CHAT_URL = BASE_URL + "/conversation"  # streaming chat endpoint

# Initialize the token and message history
token: str = ""  # session token; fetched lazily on first chat() call, refreshed from stream responses
messHistory: list[dict[str, str]] = []  # shared transcript of {"role": ..., "content": ...} entries

async def chat(messList):
    """Send the message history to the chat API and collect the streamed reply.

    Lazily fetches a session token on first use, POSTs *messList* to the
    streaming chat endpoint, and accumulates newline-delimited JSON chunks
    into the assistant's full reply, which is appended to ``messHistory``
    and returned. Returns an error string on a non-200 response.
    """
    global token
    async with aiohttp.ClientSession() as session:
        # Request token if not already set
        if token == "":
            async with session.get(TOKEN_URL) as resp:
                data = await resp.json()
                token = data["token"]

        body = {
            "token": token,
            "message": messList,
            "stream": True
        }

        fullmessage = ""
        # Make the POST request to the chat API
        async with session.post(CHAT_URL, json=body) as resp:
            if resp.status != 200:
                return "Error occurred during the chat process."

            # Accumulate raw chunks; the stream is newline-delimited JSON.
            buffer = ""
            done = False
            async for chunk in resp.content.iter_any():
                buffer += chunk.decode("utf-8")

                # Drain every complete line currently in the buffer.
                while True:
                    try:
                        index = buffer.index('\n')
                    except ValueError:
                        # No complete line yet — wait for the next chunk.
                        break
                    json_str = buffer[:index].strip()
                    buffer = buffer[index + 1:]

                    if json_str == "[DONE]":
                        # End-of-stream sentinel: stop reading entirely.
                        # (Previously this only exited the inner loop, so the
                        # outer async-for kept consuming the stream.)
                        done = True
                        break

                    try:
                        data_dict = json.loads(json_str)
                    except json.JSONDecodeError:
                        # Malformed line (already consumed from the buffer):
                        # skip it instead of stalling until the next chunk.
                        continue
                    fullmessage += data_dict.get("message", "")
                    token = data_dict.get("resp_token", token)  # Update token

                if done:
                    break

            messHistory.append({"role": "assistant", "content": fullmessage})  # Append assistant response
            return fullmessage

def gradio_chat(user_input, mode):
    """Synchronous wrapper for the async chat function, integrated with Gradio.

    Appends the user's message (tagged with the selected *mode*) to the shared
    history, runs the async chat coroutine to completion, and returns the
    assistant's reply.
    """
    messHistory.append({"role": "user", "content": f"[{mode}] {user_input}"})
    # asyncio.run creates, runs, and *closes* a fresh event loop per call;
    # the previous new_event_loop/run_until_complete pair never closed the
    # loop, leaking one loop per message.
    return asyncio.run(chat(messHistory))

# Gradio interface for user interaction
def chat_interface(user_input, mode):
    """Forward the user's message and chat mode to the blocking chat wrapper."""
    response = gradio_chat(user_input, mode)
    return response

with gr.Blocks() as demo:
    gr.Markdown("# Chat with AI")

    with gr.Row():
        radio_mode = gr.Radio(["Friendly", "Formal", "Humorous"], label="Chat Mode", value="Friendly")

    with gr.Row():
        user_box = gr.Textbox(label="Your message")
    with gr.Row():
        response_box = gr.Markdown(label="Assistant response")  # Use Markdown for the response

    send_btn = gr.Button("Send")
    # Pass the Radio component itself as an input so the handler receives the
    # user's *current* selection. The previous lambda read radio_mode.value,
    # which is frozen at the initial "Friendly" value, so the mode chosen in
    # the UI was silently ignored. (It also nested gr.Interface inside
    # gr.Blocks, which is not a supported layout.)
    send_btn.click(fn=chat_interface, inputs=[user_box, radio_mode], outputs=response_box)

# Launch the Gradio app
demo.launch()