odaly commited on
Commit
3e22531
·
verified ·
1 Parent(s): 349f68b

update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -18
app.py CHANGED
@@ -10,27 +10,21 @@ tokenizer = T5Tokenizer.from_pretrained("t5-base")
10
  model = T5ForConditionalGeneration.from_pretrained("t5-base")
11
 
12
 
13
- def parse_text(text):
14
- sentences = sent_tokenize(text)
15
- tokens = []
16
- for sentence in sentences:
17
- tokens.extend(sentence.split())
18
- return ". ".join(tokens)
19
-
20
-
21
- def summarize(text):
22
- input_ids = tokenizer.encode("summarize: " + text, return_tensors="pt", max_length=512, truncation=True)
23
- summary_ids = model.generate(input_ids=input_ids, max_length=150, min_length=30, length_penalty=2.0, num_beams=4, early_stopping=True)
24
- output = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
25
  return output
26
 
27
 
28
- def format_messages_for_summary(messages):
29
- summary_text_list = []
30
  for message in messages:
31
  if message["role"] == "assistant":
32
- summary_text_list.append(message["content"].lower())
33
- return ' '.join(summary_text_list)
 
 
34
 
35
 
36
  def main():
@@ -51,14 +45,14 @@ def main():
51
  }
52
  ]
53
 
54
- response = summarize(user_input)
55
 
56
  st.session_state['messages'].append({
57
  "role": "assistant",
58
  "content": response
59
  })
60
 
61
- st.write(format_messages_for_summary(st.session_state['messages']))
62
 
63
 
64
  def save_session():
 
10
  model = T5ForConditionalGeneration.from_pretrained("t5-base")
11
 
12
 
13
+ def generate_response(text):
14
+ input_ids = tokenizer.encode(text, return_tensors="pt", max_length=512, truncation=True)
15
+ response_ids = model.generate(input_ids=input_ids, max_length=150, min_length=30, length_penalty=2.0, num_beams=4, early_stopping=True)
16
+ output = tokenizer.decode(response_ids[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
17
  return output
18
 
19
 
20
+ def format_messages_for_display(messages):
21
+ formatted_text = []
22
  for message in messages:
23
  if message["role"] == "assistant":
24
+ formatted_text.append(f"Assistant: {message['content']}")
25
+ else:
26
+ formatted_text.append(f"User: {message['content']}")
27
+ return "\n".join(formatted_text)
28
 
29
 
30
  def main():
 
45
  }
46
  ]
47
 
48
+ response = generate_response(user_input)
49
 
50
  st.session_state['messages'].append({
51
  "role": "assistant",
52
  "content": response
53
  })
54
 
55
+ st.write(format_messages_for_display(st.session_state['messages']))
56
 
57
 
58
  def save_session():