|
|
import gradio as gr |
|
|
from transformers import BartTokenizer, BartForConditionalGeneration, AutoModelForSeq2SeqLM, AutoTokenizer |
|
|
import torch |
|
|
from langchain.memory import ConversationBufferMemory |
|
|
|
|
|
|
|
|
# Select the compute device: prefer CUDA when available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the BART-base tokenizer and seq2seq model from the Hugging Face hub.
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
model.to(device)
# This script only ever runs inference (generate); eval() disables dropout
# so outputs are deterministic — the original left the model in train mode.
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# LangChain buffer that stores the full conversation verbatim; chat_with_bart
# reads it (key 'history') to build the prompt and appends each turn after
# generating a reply.
# NOTE(review): the buffer grows without bound — long chats will eventually
# exceed the 1024-token truncation applied at tokenization time.
memory = ConversationBufferMemory()
|
|
|
|
|
|
|
|
def chat_with_bart(input_text):
    """Generate one chatbot reply with BART-base, using buffered history.

    The accumulated conversation is prepended to the new user message so the
    model sees prior turns; the new turn is then saved back into ``memory``.

    Args:
        input_text: The user's latest message (plain string from the UI).

    Returns:
        The decoded model output as a plain string.
    """
    # Pull the transcript of all prior turns from LangChain's buffer.
    conversation_history = memory.load_memory_variables({})['history']

    # Build the prompt: history, then the new user turn, then a cue for the
    # assistant's reply.
    full_input = f"{conversation_history}\nUser: {input_text}\nAssistant:"

    # Tokenize, truncating to BART's 1024-token context, and move to device.
    inputs = tokenizer(full_input, return_tensors="pt", max_length=1024, truncation=True).to(device)

    # Inference only: no_grad() avoids building the autograd graph, cutting
    # memory use. Fix: pass the attention mask the tokenizer produced —
    # the original omitted it, so generate() had to guess which positions
    # were real tokens.
    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=1024,
            num_beams=4,
            early_stopping=True,
            no_repeat_ngram_size=3,
            repetition_penalty=1.2,
        )

    # Decode the single beam-search result, dropping <s>/</s>/pad tokens.
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Persist this exchange so the next call includes it in the history.
    memory.save_context({"input": input_text}, {"output": response})

    return response
|
|
|
|
|
|
|
|
# Wire the chat function into a minimal Gradio UI: one text input, one text
# output, with a title and description shown above the form.
chat_box = gr.Textbox(label="Chat with BART Base")
reply_box = gr.Textbox(label="BART Base's Response")

interface = gr.Interface(
    fn=chat_with_bart,
    inputs=chat_box,
    outputs=reply_box,
    title="BART Base Chatbot with Memory",
    description="This is a simple chatbot powered by the BART Base model with conversational memory, using LangChain.",
)

# Start the local Gradio server; this call blocks until the app is stopped.
interface.launch()
|
|
|
|
|
|
|
|
|
|
|
|