Spaces:
Sleeping
Sleeping
File size: 4,208 Bytes
0efd337 c0f5c7a 07b00c0 0e4ab50 2ad20f5 ce16e77 fa7af89 c0f5c7a 07b00c0 fa7af89 c0f5c7a fa7af89 c0f5c7a fa7af89 ec853a0 0e4ab50 fa7af89 ec853a0 0e4ab50 fa7af89 2c5cee9 fa7af89 07b00c0 4d1575b 2ad20f5 fa7af89 fb600ee fa7af89 90d8219 fb600ee 5b2a7b6 fa7af89 fb600ee 2ad20f5 fa7af89 0e4ab50 fa7af89 0e4ab50 fa7af89 2ad20f5 5b2a7b6 2ad20f5 fa7af89 5b2a7b6 fa7af89 0efd337 2ad20f5 0e4ab50 fa7af89 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 | import gradio as gr
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
import torch
from langchain.memory import ConversationBufferMemory
# Move model to device (GPU if available)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Load the tokenizer (same tokenizer for both models since both are GPT-2 based)
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
# Load the baseline model (pre-trained DistilGPT2)
baseline_model = GPT2LMHeadModel.from_pretrained("distilgpt2").to(device)
# Load the fine-tuned model using its configuration and state dictionary
# You should have a local fine-tuned model file for this (pytorch_model_100.bin)
# The fine-tuned model reuses the stock distilgpt2 architecture/config; only the
# weights differ, so the same tokenizer works for both.
fine_tuned_config = GPT2Config.from_pretrained("distilgpt2")
fine_tuned_model = GPT2LMHeadModel(fine_tuned_config)
# Load the fine-tuned weights
model_path = "./pytorch_model_100.bin" # Path to your fine-tuned model file
# NOTE(review): torch.load unpickles arbitrary data — only load checkpoints you
# trust; consider weights_only=True on torch versions that support it.
state_dict = torch.load(model_path, map_location=device)
fine_tuned_model.load_state_dict(state_dict)
fine_tuned_model.to(device)
# NOTE(review): neither model is put in eval() mode here — dropout stays active
# during generation; confirm whether that is intentional.
# Set up conversational memory using LangChain's ConversationBufferMemory
memory = ConversationBufferMemory()
# Define the chatbot function with both baseline and fine-tuned models
def chat_with_both_models(input_text, temperature, top_p, top_k):
    """Generate a reply to *input_text* from both models and return the pair.

    Args:
        input_text: The user's question (plain string from the Gradio textbox).
        temperature: Sampling temperature (softmax sharpening).
        top_p: Nucleus-sampling cumulative-probability cutoff.
        top_k: Top-k sampling cutoff.

    Returns:
        Tuple of (baseline_response, fine_tuned_response), each the full decoded
        text (prompt included) from the corresponding model.

    Side effects:
        Appends the exchange to the module-level LangChain ``memory`` buffer.
    """
    # Note: conversation history is intentionally NOT included in the prompt;
    # each turn is answered stand-alone (the memory buffer is still updated below).
    no_memory_input = f"Question: {input_text}\nAnswer:"
    input_ids = tokenizer.encode(no_memory_input, return_tensors="pt").to(device)

    def _generate(model):
        """Run one sampled generation with the shared decoding settings."""
        return model.generate(
            input_ids,
            # max_new_tokens alone bounds the reply length; the original also
            # passed a conflicting max_length, which max_new_tokens overrides.
            max_new_tokens=15,
            num_return_sequences=1,
            no_repeat_ngram_size=3,
            repetition_penalty=1.2,
            # do_sample=True is required — without it generation is greedy and
            # temperature/top_p/top_k (the UI sliders) are silently ignored.
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
        )

    # Decode includes the prompt text ("Question: ...\nAnswer: ...").
    baseline_response = tokenizer.decode(_generate(baseline_model)[0], skip_special_tokens=True)
    fine_tuned_response = tokenizer.decode(_generate(fine_tuned_model)[0], skip_special_tokens=True)

    # Record the exchange so the buffer accumulates the full conversation.
    memory.save_context({"input": input_text}, {"baseline_output": baseline_response, "fine_tuned_output": fine_tuned_response})
    return baseline_response, fine_tuned_response
# Gradio UI wiring: one text prompt plus three sampling-parameter sliders in,
# one response box per model out. Launch blocks and serves the app.
_chat_inputs = [
    gr.Textbox(label="Chat with DistilGPT-2"),
    gr.Slider(0.1, 1.0, step=0.1, value=1.0, label="Temperature"),
    gr.Slider(0.0, 1.0, step=0.1, value=1.0, label="Top-p"),
    gr.Slider(1, 100, step=1, value=50, label="Top-k"),
]
_chat_outputs = [
    gr.Textbox(label="Baseline DistilGPT-2's Response"),
    gr.Textbox(label="Fine-tuned DistilGPT-2's Response"),
]
interface = gr.Interface(
    fn=chat_with_both_models,
    inputs=_chat_inputs,
    outputs=_chat_outputs,
    title="DistilGPT-2 Chatbot: Baseline vs Fine-tuned",
    description="This app compares the responses of a baseline DistilGPT-2 and a fine-tuned version for each input prompt. You can adjust temperature, top-p, and top-k using the sliders.",
)
interface.launch()
|