File size: 2,953 Bytes
d6094ad
a975747
 
 
 
 
 
 
 
922583c
a975747
 
 
 
 
17a6709
a975747
 
 
 
 
 
 
 
 
922583c
 
 
 
 
a975747
 
 
17a6709
922583c
a975747
 
 
 
17a6709
922583c
 
 
 
 
 
 
 
 
 
 
 
 
a975747
922583c
a975747
d6094ad
922583c
a975747
922583c
 
 
 
 
a975747
 
922583c
 
a975747
17a6709
 
 
922583c
a975747
922583c
 
 
 
17a6709
d6094ad
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
import gradio as gr
import requests

# Base URL of the FastAPI chat backend. Set the BACKEND_URL environment
# variable to target a different deployment; the hosted default below is used
# when the variable is unset or empty. A trailing slash is stripped because
# endpoints are appended as f"{BACKEND_URL}/...".
BACKEND_URL = (os.environ.get("BACKEND_URL") or "https://0learn-fastchat.hf.space").rstrip("/")

class ConversationState:
    """Hold the id of the active backend conversation and its display title.

    Both attributes are plain and mutable; code using an instance assigns to
    them directly.
    """

    # Title used while no message has been sent yet.
    DEFAULT_TITLE = "New Conversation"

    def __init__(self, conversation_id=None, title=DEFAULT_TITLE):
        """Create a state object.

        Args:
            conversation_id: Identifier of an existing conversation, or None
                when no conversation has been started yet.
            title: Text shown as the conversation title.
        """
        self.conversation_id = conversation_id
        self.title = title

    def __repr__(self):
        """Return a constructor-like representation for debugging."""
        return (
            f"{type(self).__name__}(conversation_id={self.conversation_id!r}, "
            f"title={self.title!r})"
        )

# Single module-level instance that the chat and conversation handlers in this
# file read and update.
# NOTE(review): as a module global this is presumably shared by every browser
# session served by the process - confirm whether per-session state is wanted.
state = ConversationState()

# Upper bound, in seconds, on how long a single backend chat request may take.
_CHAT_TIMEOUT_SECONDS = 60


def _build_chat_messages(history, message):
    """Flatten prior (user, assistant) turns plus the new message into role-tagged dicts."""
    messages = []
    for turn in history or []:
        # The position inside a turn decides the role. A turn whose reply is
        # missing (None) contributes only its user half instead of sending a
        # null content to the backend.
        for role, content in zip(("user", "assistant"), turn):
            if content is not None:
                messages.append({"role": role, "content": content})
    messages.append({"role": "user", "content": message})
    return messages


def chat_with_groq(message, history):
    """Send the conversation to the backend and return the assistant's reply.

    Args:
        message: The new user message.
        history: Prior turns; assumed to be a sequence of
            (user_text, assistant_text) pairs as supplied by the chat UI.

    Returns:
        The value of the "response" field of the backend's JSON reply.

    Raises:
        gr.Error: If the request fails or times out, the backend answers with
            an HTTP error status, or the body is not JSON containing the
            "conversation_id" and "response" keys.

    Side effects:
        Stores the returned conversation id on the module-level ``state`` and,
        when ``history`` is empty, sets ``state.title`` to the message
        (truncated to 30 characters plus "..." when longer).
    """
    payload = {"messages": _build_chat_messages(history, message)}
    if state.conversation_id:
        payload["conversation_id"] = state.conversation_id

    try:
        response = requests.post(
            f"{BACKEND_URL}/chat",
            json=payload,
            timeout=_CHAT_TIMEOUT_SECONDS,  # without this a stalled backend blocks forever
        )
        response.raise_for_status()
        response_data = response.json()
        # Read both fields before touching shared state so a malformed reply
        # cannot leave the state half-updated.
        conversation_id = response_data["conversation_id"]
        reply = response_data["response"]
    except (requests.RequestException, ValueError, KeyError, TypeError) as err:
        raise gr.Error(f"Chat backend request failed: {err}") from err

    state.conversation_id = conversation_id

    # Update title if it's a new conversation
    if not history:
        state.title = message[:30] + "..." if len(message) > 30 else message

    return reply

def load_conversation(conversation_id):
    """Fetch a stored conversation from the backend for display in the chat UI.

    Args:
        conversation_id: Identifier of the conversation to load. A falsy value
            means there is nothing to load.

    Returns:
        A ``(history, title)`` tuple. ``history`` is a list of
        ``(user_text, assistant_text)`` pairs; the reply is None for a final
        user message that has no answer yet. ``title`` is the title to show,
        or a short status text when nothing could be loaded.

    Side effects:
        On success, records ``conversation_id`` on the module-level ``state``
        and sets ``state.title`` from the first stored message (truncated to
        30 characters plus "..." when longer).
    """
    if not conversation_id:
        return [], state.title

    try:
        response = requests.get(
            f"{BACKEND_URL}/conversations/{conversation_id}",
            timeout=30,  # do not hang the UI on an unresponsive backend
        )
        if response.status_code != 200:
            return [], "Conversation not found"
        messages = response.json()["messages"]
    except (requests.RequestException, ValueError, KeyError, TypeError):
        # Network failure or an unexpected body: report it the same way the
        # not-found case does, through the title slot.
        return [], "Could not load conversation"

    state.conversation_id = conversation_id
    if not messages:
        # A stored conversation without messages has nothing to derive a
        # title from; keep the current one.
        return [], state.title

    first_text = messages[0]["content"]
    state.title = first_text[:30] + "..." if len(first_text) > 30 else first_text

    # Assumes messages alternate user, assistant, user, ... starting with the
    # user. Even positions are user messages and odd positions their replies;
    # pairing them by turn index keeps every reply with its own question.
    user_turns = messages[0::2]
    replies = messages[1::2]
    history = []
    for turn_index, user_msg in enumerate(user_turns):
        reply_text = replies[turn_index]["content"] if turn_index < len(replies) else None
        history.append((user_msg["content"], reply_text))
    return history, state.title

def start_new_conversation():
    """Forget the active conversation and return an empty chat with the default title.

    Returns:
        A ``(history, title)`` tuple: an empty list and the reset title.
    """
    fresh_title = "New Conversation"
    # Clearing the id makes the next chat request start a new conversation.
    state.conversation_id, state.title = None, fresh_title
    return [], state.title

def continue_conversation():
    """Reload the conversation whose id is currently stored on ``state``.

    Returns:
        The ``(history, title)`` result of ``load_conversation`` for the
        stored id, or an empty history with a status text when no id is set.
    """
    active_id = state.conversation_id
    if not active_id:
        return [], "No previous conversation"
    return load_conversation(active_id)

def _title_heading(title):
    """Render a conversation title in the heading format used when the page first loads."""
    return f"## Current Conversation: {title}"


def _with_title_heading(handler):
    """Wrap a no-argument ``(history, title)`` handler so its title is rendered as the heading."""

    def wrapped():
        history, title = handler()
        return history, _title_heading(title)

    # Keep the handler's own name on the wrapper so anything that reports the
    # callback by name still shows the original.
    wrapped.__name__ = handler.__name__
    return wrapped


# Page layout: title, conversation buttons, current-title heading, chat window,
# message box and a clear button, followed by the event wiring.
with gr.Blocks() as demo:
    gr.Markdown("# Groq Chatbot with Enhanced Conversation Management")

    with gr.Row():
        new_conv_btn = gr.Button("New Conversation")
        continue_conv_btn = gr.Button("Continue Previous Conversation")

    title_display = gr.Markdown(_title_heading(state.title))

    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Type your message here")
    clear = gr.Button("Clear Chat")

    def respond(message, chat_history):
        """Handle a submitted message.

        Args:
            message: Text from the message box.
            chat_history: Current chatbot contents as (user, assistant) pairs;
                None is treated as an empty history.

        Returns:
            A tuple of (new textbox value, updated history, title heading).
            The textbox value is always "" so the box is cleared.
        """
        chat_history = chat_history or []
        # Nothing to send: leave the conversation untouched instead of
        # posting an empty message to the backend.
        if not message or not message.strip():
            return "", chat_history, _title_heading(state.title)
        bot_message = chat_with_groq(message, chat_history)
        chat_history.append((message, bot_message))
        return "", chat_history, _title_heading(state.title)

    # Every callback feeds title_display through _title_heading so the
    # "Current Conversation" heading survives updates instead of being
    # replaced by the bare title text.
    msg.submit(respond, [msg, chatbot], [msg, chatbot, title_display])
    clear.click(lambda: ([], _title_heading(state.title)), None, [chatbot, title_display], queue=False)
    new_conv_btn.click(_with_title_heading(start_new_conversation), None, [chatbot, title_display])
    continue_conv_btn.click(_with_title_heading(continue_conversation), None, [chatbot, title_display])

demo.launch()