File size: 4,095 Bytes
0efd337
c0f5c7a
07b00c0
0e4ab50
2ad20f5
ce16e77
 
 
a47f900
c0f5c7a
07b00c0
a47f900
 
c0f5c7a
a47f900
 
c0f5c7a
a47f900
 
 
 
 
 
 
ec853a0
0e4ab50
 
 
a47f900
 
ec853a0
0e4ab50
 
a47f900
2c5cee9
a47f900
07b00c0
4d1575b
2ad20f5
a47f900
 
fa7af89
a47f900
fa7af89
 
 
 
 
 
 
a47f900
 
 
fa7af89
 
a47f900
 
fa7af89
a47f900
 
2ad20f5
616f6e6
 
 
 
2ad20f5
616f6e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0efd337
616f6e6
 
 
2ad20f5
 
0e4ab50
616f6e6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import gradio as gr
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
import torch
from langchain.memory import ConversationBufferMemory

# Select the compute device (GPU if available, otherwise CPU)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the tokenizer (you can use the pre-trained tokenizer for GPT-2 family)
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")

# Take the architecture configuration from the distilgpt2 hub config, since the
# local checkpoint ships without its own config.json.
config = GPT2Config.from_pretrained("distilgpt2")

# Instantiate the architecture; its initial weights are replaced by the local
# state dict loaded below.
model = GPT2LMHeadModel(config)

# Load the fine-tuned weights from the local checkpoint file.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust
# (recent torch versions accept weights_only=True for tensor-only state dicts).
model_path = "./pytorch_model_100.bin"  # Path to local model file
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)

# Move the model to the selected device and switch it to inference mode.
# A directly constructed nn.Module starts in training mode (unlike one built via
# from_pretrained), so without eval() any dropout layers would stay active and
# randomly perturb every generated reply.
model.to(device)
model.eval()

# Set up conversational memory using LangChain's ConversationBufferMemory
memory = ConversationBufferMemory()

# Define the chatbot function with memory and additional parameters
def chat_with_distilgpt2(input_text, temperature, top_p, top_k):
    """Generate one reply from the model and return the updated transcript.

    The model sees only the current question, formatted as a
    "Question: <input>" / "Answer:" prompt. The stored conversation history is
    used for display but is not fed back into the prompt.

    Args:
        input_text: The user's message.
        temperature: Sampling temperature (from the UI slider).
        top_p: Nucleus-sampling probability mass (from the UI slider).
        top_k: Number of highest-probability tokens to sample from (from the UI
            slider; cast to int because slider values may arrive as floats).

    Returns:
        The memory transcript recorded before this turn, followed by the new
        "You:" / "Bot:" exchange.
    """
    # Snapshot of the history *before* this turn; the new turn is appended below.
    conversation_history = memory.load_memory_variables({})["history"]

    # Single-turn prompt: history is intentionally not included.
    prompt = f"Question: {input_text}\nAnswer:"

    # Calling the tokenizer (rather than .encode) also yields an attention mask;
    # GPT-2 uses eos as its pad token, so generate() cannot infer the mask itself.
    encoded = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_length = encoded["input_ids"].shape[1]

    outputs = model.generate(
        **encoded,
        max_new_tokens=15,  # sole length budget; a separate max_length would conflict
        num_return_sequences=1,
        no_repeat_ngram_size=3,
        repetition_penalty=1.2,
        # temperature/top_p/top_k only take effect when sampling is enabled;
        # under the default greedy decoding the sliders would be ignored.
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        top_k=int(top_k),  # the top-k warper rejects non-integer values
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens so the reply does not echo the prompt.
    response = tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    ).strip()

    # Record this exchange so it appears in the next call's history.
    memory.save_context({"input": input_text}, {"output": response})

    # NOTE(review): earlier turns are rendered by ConversationBufferMemory with its
    # own speaker prefixes, while the newest turn uses "You:"/"Bot:" — confirm the
    # mixed labelling is intended.
    chat_history = conversation_history + f"\nYou: {input_text}\nBot: {response}\n"

    return chat_history

# Set up the Gradio interface with the input box below the output box
with gr.Blocks() as interface:
    chatbot_output = gr.Textbox(label="Conversation", lines=15, placeholder="Chat history will appear here...", interactive=False)

    # Add the instruction message above the input box
    gr.Markdown("**Instructions:** Press `Shift + Enter` to submit, and `Enter` for a new line.")

    # Input box for the user
    user_input = gr.Textbox(label="Your Input", placeholder="Type your message here...", lines=2, show_label=True)

    # Sliders for temperature, top_p, and top_k, laid out side by side.
    # gr.Row is a layout context manager: components created inside the `with`
    # block are placed in the row. Its options are keyword-only, so passing a
    # list of components positionally raises a TypeError.
    with gr.Row():
        temperature_slider = gr.Slider(0.1, 1.0, step=0.1, value=1.0, label="Temperature")
        top_p_slider = gr.Slider(0.0, 1.0, step=0.1, value=1.0, label="Top-p")
        top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k")

    def update_chat(input_text, chat_history, temperature, top_p, top_k):
        """Produce a reply and return (new transcript, cleared input box).

        ``chat_history`` (the current contents of the output box) is accepted to
        match the event's inputs but is not used; the transcript is rebuilt from
        the conversation memory by ``chat_with_distilgpt2``.
        """
        updated_history = chat_with_distilgpt2(input_text, temperature, top_p, top_k)
        return updated_history, ""

    # Submit event: update the conversation box and clear the input box.
    user_input.submit(
        update_chat,
        inputs=[user_input, chatbot_output, temperature_slider, top_p_slider, top_k_slider],
        outputs=[chatbot_output, user_input],
    )

# Launch the Gradio app
interface.launch()