File size: 2,936 Bytes
0efd337
b43a17c
07b00c0
0e4ab50
2ad20f5
ce16e77
 
 
b43a17c
 
 
ce16e77
07b00c0
1516740
 
 
ec853a0
1516740
 
 
 
 
 
 
 
ec853a0
0e4ab50
 
 
b43a17c
 
ec853a0
0e4ab50
 
1516740
 
 
ec853a0
 
0e4ab50
 
07b00c0
b43a17c
2ad20f5
b43a17c
fb600ee
b43a17c
 
 
 
fb600ee
97784fe
 
 
 
 
fb600ee
2ad20f5
0e4ab50
2ad20f5
0e4ab50
 
 
 
2ad20f5
 
 
 
b43a17c
 
 
 
 
0efd337
 
2ad20f5
 
0e4ab50
07b00c0
b43a17c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import gradio as gr
from transformers import BartTokenizer, BartForConditionalGeneration, AutoModelForSeq2SeqLM, AutoTokenizer
import torch
from langchain.memory import ConversationBufferMemory

# Pick the compute device once: CUDA when PyTorch reports it, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# BART Base tokenizer and seq2seq model; the model lives on `device`.
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
model = model.to(device)

# --- Disabled: optional history summarizer (kept for reference) -------------
# # Load summarization model (e.g., T5-small)
# summarizer_tokenizer = AutoTokenizer.from_pretrained("t5-small")
# summarizer_model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").to(device)

# def summarize_history(history):
#     input_ids = summarizer_tokenizer.encode(
#         "summarize: " + history,
#         return_tensors="pt"
#     ).to(device)
#     summary_ids = summarizer_model.generate(input_ids, max_length=50, min_length=25, length_penalty=5., num_beams=2)
#     summary = summarizer_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
#     return summary

# Process-wide conversation store (LangChain ConversationBufferMemory); it is
# a single module-level object, so every call to the chatbot reads and writes
# the same history.
memory = ConversationBufferMemory()

# Define the chatbot function with memory using BART Base
def chat_with_bart(input_text):
    """Generate a reply to ``input_text`` using the stored conversation history.

    The history held in the module-level ``memory`` is prepended to the new
    user message, the combined prompt is tokenized and fed to the module-level
    ``model``, and the first generated sequence is decoded and returned.

    When the prompt is longer than the token budget, the *oldest* tokens are
    dropped so that the current user message and the trailing ``Assistant:``
    cue always reach the model.

    Args:
        input_text: The user's message for this turn.

    Returns:
        The decoded text of the first generated sequence, with special tokens
        removed.

    Side effects:
        Appends this turn (user input and generated response) to ``memory``.
    """
    # Same 1024-token budget the tokenizer call used before.
    # NOTE(review): assumed to match the model's positional limit; confirm if
    # the checkpoint is ever swapped.
    max_input_tokens = 1024

    # Retrieve conversation history
    conversation_history = memory.load_memory_variables({})['history']

    # # Summarize if history exceeds a certain length
    # if len(conversation_history.split()) > 200:
    #     conversation_history = summarize_history(conversation_history)

    # Combine the (possibly summarized) history with the current user input
    full_input = f"{conversation_history}\nUser: {input_text}\nAssistant:"

    # Tokenize without built-in truncation: the tokenizer's default truncation
    # cuts from the right, which would discard the newest text (the current
    # user message) once the history grows past the budget.
    encoded = tokenizer(full_input, return_tensors="pt")
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]

    overflow = input_ids.shape[1] - max_input_tokens
    if overflow > 0:
        # Keep the first token (sequence start marker) and drop the `overflow`
        # oldest tokens that follow it; the tail, including the end marker,
        # is preserved.
        input_ids = torch.cat(
            [input_ids[:, :1], input_ids[:, 1 + overflow:]], dim=1
        )
        attention_mask = torch.cat(
            [attention_mask[:, :1], attention_mask[:, 1 + overflow:]], dim=1
        )

    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    # Generate the response using the BART model (beam search; sampling
    # parameters such as temperature/top_k/top_p are left at their defaults).
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=1024,
        num_beams=4,
        early_stopping=True,
        no_repeat_ngram_size=3,
        repetition_penalty=1.2,
    )

    # Decode the model output
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Update the memory with the user input and model response
    memory.save_context({"input": input_text}, {"output": response})

    return response

# Build the Gradio UI: one text box for the user's message and one for the
# model's reply, both wired to chat_with_bart.
user_box = gr.Textbox(label="Chat with BART Base")
reply_box = gr.Textbox(label="BART Base's Response")

interface = gr.Interface(
    fn=chat_with_bart,
    inputs=user_box,
    outputs=reply_box,
    title="BART Base Chatbot with Memory",
    description="This is a simple chatbot powered by the BART Base model with conversational memory, using LangChain.",
)

# Start serving the app.
interface.launch()