import gradio as gr
import json

from typing import List, Optional

import torch
from transformers import AutoTokenizer, AutoModel

from langchain.chains import LLMChain
from langchain.memory import ConversationBufferWindowMemory
from langchain.llms.openai import OpenAI
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.prompts import PromptTemplate

# Load ChatGLM once at module level; `model`, `tokenizer`, and `torch_gc`
# are used by the ChatGLM wrapper below. The checkpoint name and GPU
# placement are assumptions -- adjust them for your environment.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda().eval()


def torch_gc():
    """Release cached GPU memory between generations."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

class ChatGLM(LLM):
    # LLM is pydantic-backed: fields need type annotations to be recognized,
    # so `top_p` and `history` are annotated here.
    max_token: int = 10000
    temperature: float = 0.1
    top_p: float = 0.9
    history: List = []

    @property
    def _llm_type(self) -> str:
        return "ChatGLM"

    def _call(self,
              prompt: str,
              stop: Optional[List[str]] = None) -> str:
        response, updated_history = model.chat(
            tokenizer,
            prompt,
            history=self.history,
            max_length=self.max_token,
            top_p=self.top_p,
            temperature=self.temperature,
        )
        torch_gc()
        if stop is not None:
            response = enforce_stop_tokens(response, stop)
        self.history = updated_history
        return response
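LangChain's `enforce_stop_tokens` truncates a reply at the first occurrence of any stop string. A minimal stdlib stand-in illustrating that behavior (not LangChain's actual implementation, just the idea):

```python
import re

def truncate_at_stop(text: str, stop: list) -> str:
    # Split on any stop token and keep only the text before the first one,
    # mirroring what enforce_stop_tokens does to the ChatGLM responses above.
    return re.split("|".join(map(re.escape, stop)), text)[0]

print(truncate_at_stop("Hello there.\nHuman: hi", ["\nHuman:"]))  # → Hello there.
```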


def chatbots_conversation(num_turns: int, chat_gpt_prompt: str, chat_glm_prompt: str):
    # LLMChain expects a PromptTemplate, not a raw string, and the template
    # needs a {history} slot so ConversationBufferWindowMemory can inject
    # the last k exchanges.
    prompt_template = PromptTemplate(
        input_variables=["history", "human_input"],
        template=chat_gpt_prompt + "\n{history}\nHuman: {human_input}\nAssistant:",
    )
    chatgpt_chain = LLMChain(
        llm=OpenAI(temperature=0),
        prompt=prompt_template,
        verbose=True,
        memory=ConversationBufferWindowMemory(k=2),
    )

    chat_glm = ChatGLM()

    conversation_history = []

    # Each turn: ChatGPT replies to ChatGLM's last message, then ChatGLM
    # replies to ChatGPT. The initial ChatGLM prompt seeds the first turn.
    for _ in range(num_turns):
        chatgpt_response = chatgpt_chain.predict(human_input=chat_glm_prompt)
        conversation_history.append({"bot": "chatgpt", "text": chatgpt_response})

        chat_glm_response = chat_glm(chatgpt_response)
        conversation_history.append({"bot": "chatglm", "text": chat_glm_response})
        chat_glm_prompt = chat_glm_response

    return conversation_history
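The alternating hand-off is easiest to verify with stub bots in place of the real models. This sketch (hypothetical tagging lambdas standing in for `chatgpt_chain.predict` and `chat_glm`, same dict shape as `conversation_history`) shows the ping-pong structure:

```python
def stub_conversation(num_turns, seed):
    # Each "bot" just tags the message it received, so the hand-off
    # order is easy to inspect in the resulting history.
    gpt = lambda msg: f"gpt({msg})"
    glm = lambda msg: f"glm({msg})"
    history, prompt = [], seed
    for _ in range(num_turns):
        gpt_reply = gpt(prompt)
        history.append({"bot": "chatgpt", "text": gpt_reply})
        glm_reply = glm(gpt_reply)
        history.append({"bot": "chatglm", "text": glm_reply})
        prompt = glm_reply  # ChatGLM's reply seeds the next ChatGPT turn

    return history

print(stub_conversation(1, "hi"))
# → [{'bot': 'chatgpt', 'text': 'gpt(hi)'}, {'bot': 'chatglm', 'text': 'glm(gpt(hi))'}]
```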


def save_to_json(conversations, file_name="conversations.json"):
    # ensure_ascii=False keeps non-ASCII (e.g. Chinese) replies readable
    # in the saved file instead of escaped \uXXXX sequences.
    with open(file_name, "w", encoding="utf-8") as outfile:
        json.dump(conversations, outfile, indent=4, ensure_ascii=False)
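The saved file is just the list of per-turn dicts serialized as JSON. A quick round-trip check of that format (the sample texts are made up):

```python
import json

sample = [
    {"bot": "chatgpt", "text": "Hello!"},
    {"bot": "chatglm", "text": "你好!"},
]
# ensure_ascii=False preserves the Chinese characters verbatim in the output.
encoded = json.dumps(sample, indent=4, ensure_ascii=False)
assert json.loads(encoded) == sample  # serialization is lossless
print("你好" in encoded)  # → True
```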


def gradio_wrapper(num_turns: int, chat_gpt_prompt: str, chat_glm_prompt: str, save_conversations: bool = False):
    conversations = chatbots_conversation(num_turns, chat_gpt_prompt, chat_glm_prompt)

    if save_conversations:
        save_to_json(conversations)

    formatted_conversations = "\n".join([f"{conv['bot']}: {conv['text']}" for conv in conversations])

    return formatted_conversations


iface = gr.Interface(
    gradio_wrapper,
    # The deprecated gr.inputs/gr.outputs namespaces are replaced with the
    # current component API (the old gr.outputs.Textbox has no `lines` arg).
    inputs=[
        gr.Slider(1, 10, step=1, label="Number of turns"),
        gr.Textbox(lines=3, label="ChatGPT Prompt"),
        gr.Textbox(lines=3, label="ChatGLM Prompt"),
        gr.Checkbox(label="Save Conversations"),
    ],
    outputs=gr.Textbox(lines=10, label="Conversations"),
    title="Chatbot Conversation",
    description="A conversation between OpenAI's ChatGPT and THUDM's ChatGLM",
)

iface.launch()