kdevoe committed on
Commit
b43a17c
·
verified ·
1 Parent(s): 5167829

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -22
app.py CHANGED
@@ -1,14 +1,14 @@
1
  import gradio as gr
2
- from transformers import GPT2Tokenizer, GPT2LMHeadModel, AutoModelForSeq2SeqLM, AutoTokenizer
3
  import torch
4
  from langchain.memory import ConversationBufferMemory
5
 
6
  # Move model to device (GPU if available)
7
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
8
 
9
- # Load the tokenizer and model for DistilGPT-2
10
- tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
11
- model = GPT2LMHeadModel.from_pretrained("distilgpt2")
12
  model.to(device)
13
 
14
  # Load summarization model (e.g., T5-small)
@@ -27,12 +27,12 @@ def summarize_history(history):
27
  # Set up conversational memory using LangChain's ConversationBufferMemory
28
  memory = ConversationBufferMemory()
29
 
30
- # Define the chatbot function with memory
31
- def chat_with_distilgpt2(input_text):
32
  # Retrieve conversation history
33
  conversation_history = memory.load_memory_variables({})['history']
34
 
35
- # Summarize if history exceeds certain length
36
  if len(conversation_history.split()) > 200:
37
  conversation_history = summarize_history(conversation_history)
38
 
@@ -40,22 +40,19 @@ def chat_with_distilgpt2(input_text):
40
  full_input = f"{conversation_history}\nUser: {input_text}\nAssistant:"
41
 
42
  # Tokenize the input and convert to tensor
43
- input_ids = tokenizer.encode(full_input, return_tensors="pt").to(device)
44
 
45
- # Generate the response using the model with adjusted parameters
46
  outputs = model.generate(
47
- input_ids,
48
- max_length=input_ids.shape[1] + 100, # Limit total length
49
- max_new_tokens=100,
50
- num_return_sequences=1,
51
  no_repeat_ngram_size=3,
52
  repetition_penalty=1.2,
53
  temperature=0.9,
54
  top_k=20,
55
- top_p=0.8,
56
- early_stopping=True,
57
- pad_token_id=tokenizer.eos_token_id,
58
- eos_token_id=tokenizer.eos_token_id
59
  )
60
 
61
  # Decode the model output
@@ -68,14 +65,15 @@ def chat_with_distilgpt2(input_text):
68
 
69
  # Set up the Gradio interface
70
  interface = gr.Interface(
71
- fn=chat_with_distilgpt2,
72
- inputs=gr.Textbox(label="Chat with DistilGPT-2"),
73
- outputs=gr.Textbox(label="DistilGPT-2's Response"),
74
- title="DistilGPT-2 Chatbot with Memory",
75
- description="This is a simple chatbot powered by the DistilGPT-2 model with conversational memory, using LangChain.",
76
  )
77
 
78
  # Launch the Gradio app
79
  interface.launch()
80
 
81
 
 
 
1
  import gradio as gr
2
+ from transformers import BartTokenizer, BartForConditionalGeneration, AutoModelForSeq2SeqLM, AutoTokenizer
3
  import torch
4
  from langchain.memory import ConversationBufferMemory
5
 
6
  # Move model to device (GPU if available)
7
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
8
 
9
+ # Load the tokenizer and model for BART Base
10
+ tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
11
+ model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
12
  model.to(device)
13
 
14
  # Load summarization model (e.g., T5-small)
 
27
  # Set up conversational memory using LangChain's ConversationBufferMemory
28
  memory = ConversationBufferMemory()
29
 
30
+ # Define the chatbot function with memory using BART Base
31
+ def chat_with_bart(input_text):
32
  # Retrieve conversation history
33
  conversation_history = memory.load_memory_variables({})['history']
34
 
35
+ # Summarize if history exceeds a certain length
36
  if len(conversation_history.split()) > 200:
37
  conversation_history = summarize_history(conversation_history)
38
 
 
40
  full_input = f"{conversation_history}\nUser: {input_text}\nAssistant:"
41
 
42
  # Tokenize the input and convert to tensor
43
+ inputs = tokenizer(full_input, return_tensors="pt", max_length=1024, truncation=True).to(device)
44
 
45
+ # Generate the response using the BART model
46
  outputs = model.generate(
47
+ inputs["input_ids"],
48
+ max_length=1024,
49
+ num_beams=4,
50
+ early_stopping=True,
51
  no_repeat_ngram_size=3,
52
  repetition_penalty=1.2,
53
  temperature=0.9,
54
  top_k=20,
55
+ top_p=0.8
 
 
 
56
  )
57
 
58
  # Decode the model output
 
65
 
66
  # Set up the Gradio interface
67
  interface = gr.Interface(
68
+ fn=chat_with_bart,
69
+ inputs=gr.Textbox(label="Chat with BART Base"),
70
+ outputs=gr.Textbox(label="BART Base's Response"),
71
+ title="BART Base Chatbot with Memory",
72
+ description="This is a simple chatbot powered by the BART Base model with conversational memory, using LangChain.",
73
  )
74
 
75
  # Launch the Gradio app
76
  interface.launch()
77
 
78
 
79
+