kdevoe commited on
Commit
1516740
·
verified ·
1 Parent(s): 97784fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -14
app.py CHANGED
@@ -11,18 +11,18 @@ tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
11
  model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
12
  model.to(device)
13
 
14
- # Load summarization model (e.g., T5-small)
15
- summarizer_tokenizer = AutoTokenizer.from_pretrained("t5-small")
16
- summarizer_model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").to(device)
17
 
18
- def summarize_history(history):
19
- input_ids = summarizer_tokenizer.encode(
20
- "summarize: " + history,
21
- return_tensors="pt"
22
- ).to(device)
23
- summary_ids = summarizer_model.generate(input_ids, max_length=50, min_length=25, length_penalty=5., num_beams=2)
24
- summary = summarizer_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
25
- return summary
26
 
27
  # Set up conversational memory using LangChain's ConversationBufferMemory
28
  memory = ConversationBufferMemory()
@@ -32,9 +32,9 @@ def chat_with_bart(input_text):
32
  # Retrieve conversation history
33
  conversation_history = memory.load_memory_variables({})['history']
34
 
35
- # Summarize if history exceeds a certain length
36
- if len(conversation_history.split()) > 200:
37
- conversation_history = summarize_history(conversation_history)
38
 
39
  # Combine the (possibly summarized) history with the current user input
40
  full_input = f"{conversation_history}\nUser: {input_text}\nAssistant:"
 
11
  model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
12
  model.to(device)
13
 
14
+ # # Load summarization model (e.g., T5-small)
15
+ # summarizer_tokenizer = AutoTokenizer.from_pretrained("t5-small")
16
+ # summarizer_model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").to(device)
17
 
18
+ # def summarize_history(history):
19
+ # input_ids = summarizer_tokenizer.encode(
20
+ # "summarize: " + history,
21
+ # return_tensors="pt"
22
+ # ).to(device)
23
+ # summary_ids = summarizer_model.generate(input_ids, max_length=50, min_length=25, length_penalty=5., num_beams=2)
24
+ # summary = summarizer_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
25
+ # return summary
26
 
27
  # Set up conversational memory using LangChain's ConversationBufferMemory
28
  memory = ConversationBufferMemory()
 
32
  # Retrieve conversation history
33
  conversation_history = memory.load_memory_variables({})['history']
34
 
35
+ # # Summarize if history exceeds a certain length
36
+ # if len(conversation_history.split()) > 200:
37
+ # conversation_history = summarize_history(conversation_history)
38
 
39
  # Combine the (possibly summarized) history with the current user input
40
  full_input = f"{conversation_history}\nUser: {input_text}\nAssistant:"