import gradio as gr
from transformers import BartTokenizer, BartForConditionalGeneration, AutoModelForSeq2SeqLM, AutoTokenizer
import torch
from langchain.memory import ConversationBufferMemory
# Select the compute device: GPU if available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the tokenizer and model for BART Base and move the model to the device.
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
model.to(device)

# Conversational memory: stores the full running transcript under the
# "history" key, retrieved on every turn by chat_with_bart.
# NOTE(review): a t5-small history-summarization step was prototyped here and
# removed as dead code; reintroduce a summarizer if transcripts are expected
# to exceed the model's 1024-token context window.
memory = ConversationBufferMemory()
# Chatbot turn handler: conditions BART Base on the stored conversation history.
def chat_with_bart(input_text):
    """Generate a reply to *input_text* using BART Base plus conversation memory.

    Args:
        input_text: The user's latest message.

    Returns:
        The decoded model response string. As a side effect the
        (input, response) pair is appended to the LangChain memory so later
        turns can condition on it.
    """
    # Retrieve the accumulated transcript from LangChain memory.
    conversation_history = memory.load_memory_variables({})['history']
    # Prepend the history so the model sees prior turns in the prompt.
    full_input = f"{conversation_history}\nUser: {input_text}\nAssistant:"
    # Tokenize; truncation caps the prompt at BART's 1024-token context window.
    inputs = tokenizer(full_input, return_tensors="pt", max_length=1024, truncation=True).to(device)
    # Pass the attention mask explicitly alongside input_ids — omitting it
    # makes transformers warn and fall back to guessing the mask, which can
    # misbehave when the pad token appears in the input.
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=1024,
        num_beams=4,
        early_stopping=True,       # stop once all beams produce EOS
        no_repeat_ngram_size=3,    # block verbatim 3-gram repetition
        repetition_penalty=1.2,    # mildly discourage repeated tokens
    )
    # Decode the best beam, dropping special tokens.
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Persist this exchange so future calls can condition on it.
    memory.save_context({"input": input_text}, {"output": response})
    return response
# Build the Gradio UI: a single text box in, a single text box out,
# wired to the chat_with_bart handler.
user_box = gr.Textbox(label="Chat with BART Base")
reply_box = gr.Textbox(label="BART Base's Response")

interface = gr.Interface(
    fn=chat_with_bart,
    inputs=user_box,
    outputs=reply_box,
    title="BART Base Chatbot with Memory",
    description="This is a simple chatbot powered by the BART Base model with conversational memory, using LangChain.",
)

# Start the web app.
interface.launch()