Spaces:
Sleeping
Sleeping
File size: 4,095 Bytes
0efd337 c0f5c7a 07b00c0 0e4ab50 2ad20f5 ce16e77 a47f900 c0f5c7a 07b00c0 a47f900 c0f5c7a a47f900 c0f5c7a a47f900 ec853a0 0e4ab50 a47f900 ec853a0 0e4ab50 a47f900 2c5cee9 a47f900 07b00c0 4d1575b 2ad20f5 a47f900 fa7af89 a47f900 fa7af89 a47f900 fa7af89 a47f900 fa7af89 a47f900 2ad20f5 616f6e6 2ad20f5 616f6e6 0efd337 616f6e6 2ad20f5 0e4ab50 616f6e6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
import gradio as gr
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
import torch
from langchain.memory import ConversationBufferMemory
# Move model to device (GPU if available)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Load the tokenizer (you can use the pre-trained tokenizer for GPT-2 family)
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
# Manually create a configuration for the model (since we don't have config.json)
config = GPT2Config.from_pretrained("distilgpt2")
# Initialize the model using the manually created configuration
model = GPT2LMHeadModel(config)
# Load the weights from the pytorch_model.bin file
model_path = "./pytorch_model_100.bin" # Path to local model file
state_dict = torch.load(model_path, map_location=device) # Load the state_dict
model.load_state_dict(state_dict) # Load the state dict into the model
# Move model to the device (GPU or CPU)
model.to(device)
# Set up conversational memory using LangChain's ConversationBufferMemory
memory = ConversationBufferMemory()
# Define the chatbot function with memory and additional parameters
# Define the chatbot function with memory and sampling parameters
def chat_with_distilgpt2(input_text, temperature, top_p, top_k):
    """Generate a reply to *input_text* and return the updated transcript.

    Parameters
    ----------
    input_text : str
        The user's current message.
    temperature : float
        Softmax temperature forwarded to ``model.generate``.
    top_p : float
        Nucleus-sampling probability mass.
    top_k : int
        Top-k sampling cutoff.

    Returns
    -------
    str
        The full conversation history including this new exchange.
    """
    # Retrieve conversation history (used only for the displayed transcript;
    # generation is conditioned on the current question alone)
    conversation_history = memory.load_memory_variables({})['history']
    # Prompt built from the current user input only
    no_memory_input = f"Question: {input_text}\nAnswer:"
    # Tokenize the input and move the tensor to the model's device
    input_ids = tokenizer.encode(no_memory_input, return_tensors="pt").to(device)
    # Generate the response. Fixes vs. the original call:
    #  - max_length and max_new_tokens were both set (conflicting limits;
    #    max_new_tokens silently took precedence) — keep only max_new_tokens.
    #  - do_sample=True is required for temperature/top_p/top_k to have any
    #    effect; without it decoding is greedy and the UI sliders are ignored.
    #  - early_stopping only applies to beam search, so it is dropped.
    outputs = model.generate(
        input_ids,
        max_new_tokens=15,               # cap on newly generated tokens
        do_sample=True,                  # enable sampling so sliders matter
        num_return_sequences=1,
        no_repeat_ngram_size=3,
        repetition_penalty=1.2,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        temperature=temperature,         # temperature from slider
        top_p=top_p,                     # top_p from slider
        top_k=top_k,                     # top_k from slider
    )
    # Decode only the newly generated tokens. The original decoded the whole
    # sequence, so the "Question: ...\nAnswer:" prompt leaked into the reply.
    response = tokenizer.decode(
        outputs[0][input_ids.shape[1]:], skip_special_tokens=True
    ).strip()
    # Update the memory with the user input and model response
    memory.save_context({"input": input_text}, {"output": response})
    # Format the chat history for display
    chat_history = conversation_history + f"\nYou: {input_text}\nBot: {response}\n"
    return chat_history
# Set up the Gradio interface with the input box below the output box
# Set up the Gradio interface with the input box below the output box
with gr.Blocks() as interface:
    # Read-only transcript of the conversation
    chatbot_output = gr.Textbox(label="Conversation", lines=15, placeholder="Chat history will appear here...", interactive=False)
    # Add the instruction message above the input box
    gr.Markdown("**Instructions:** Press `Shift + Enter` to submit, and `Enter` for a new line.")
    # Input box for the user
    user_input = gr.Textbox(label="Your Input", placeholder="Type your message here...", lines=2, show_label=True)
    # Sliders for temperature, top_p, and top_k, laid out side by side.
    # Fix: gr.Row must be used as a context manager *around* component
    # creation; the original called gr.Row([...]) after the components were
    # already rendered, which has no layout effect in Blocks.
    with gr.Row():
        temperature_slider = gr.Slider(0.1, 1.0, step=0.1, value=1.0, label="Temperature")
        top_p_slider = gr.Slider(0.0, 1.0, step=0.1, value=1.0, label="Top-p")
        top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k")

    # Forward the turn to the model; return the new transcript and clear
    # the input box (second output resets user_input to "").
    def update_chat(input_text, chat_history, temperature, top_p, top_k):
        updated_history = chat_with_distilgpt2(input_text, temperature, top_p, top_k)
        return updated_history, ""

    # Submit when pressing Shift + Enter
    user_input.submit(update_chat,
                      inputs=[user_input, chatbot_output, temperature_slider, top_p_slider, top_k_slider],
                      outputs=[chatbot_output, user_input])

# Launch the Gradio app
interface.launch()
|