odaly commited on
Commit
199d849
·
verified ·
1 Parent(s): eca93c2

update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -22
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import streamlit as st
2
  from transformers import T5ForConditionalGeneration, T5Tokenizer
3
  from nltk.tokenize import sent_tokenize
 
4
 
5
  nltk.download('punkt')
6
 
@@ -18,14 +19,9 @@ def parse_text(text):
18
 
19
 
20
  def summarize(text):
21
- input_ids = tokenizer.parse(text=text, return_tensors="pt")
22
- labels = input_ids.clone()
23
- decoder_input_ids = model.generate(input_ids=input_ids)
24
-
25
-
26
- output = tokenizer.batch_decode(decoder_output_ids.cpu(), skip_specialTokens=True)
27
-
28
-
29
  return output[0]
30
 
31
 
@@ -34,24 +30,20 @@ def format_messages_for_summary(messages):
34
  for message in messages:
35
  if message["role"] == "assistant":
36
  summary_text_list.append(message["content"].lower())
37
-
38
-
39
  return ''.join(summary_text_list)
40
 
41
 
42
  def main():
43
  st.title("T5 Chat Interface")
44
 
 
 
45
 
46
- user_input = st.text_area("Enter your prompt:")
47
-
48
-
49
- submitted = st.form_submit_button(label="Submit")
50
-
51
 
52
  if submitted:
53
-
54
-
55
  messages = [
56
  {
57
  "role": "user",
@@ -59,14 +51,12 @@ def main():
59
  }
60
  ]
61
 
62
-
63
  response = summarize(user_input)
64
 
65
-
66
  st.session_state['messages'].append({
67
  "role": "assistant",
68
- "content": response})
69
-
70
 
71
  st.write(format_messages_for_summary(st.session_state['messages']))
72
 
@@ -76,4 +66,4 @@ def save_session():
76
 
77
 
78
  if __name__ == '__main__':
79
- main()
 
1
  import streamlit as st
2
  from transformers import T5ForConditionalGeneration, T5Tokenizer
3
  from nltk.tokenize import sent_tokenize
4
+ import nltk
5
 
6
  nltk.download('punkt')
7
 
 
19
 
20
 
21
  def summarize(text):
22
+ input_ids = tokenizer.encode(text, return_tensors="pt")
23
+ decoder_output_ids = model.generate(input_ids=input_ids)
24
+ output = tokenizer.batch_decode(decoder_output_ids, skip_special_tokens=True)
 
 
 
 
 
25
  return output[0]
26
 
27
 
 
30
  for message in messages:
31
  if message["role"] == "assistant":
32
  summary_text_list.append(message["content"].lower())
 
 
33
  return ''.join(summary_text_list)
34
 
35
 
36
  def main():
37
  st.title("T5 Chat Interface")
38
 
39
+ if 'messages' not in st.session_state:
40
+ st.session_state['messages'] = []
41
 
42
+ with st.form(key='input_form'):
43
+ user_input = st.text_area("Enter your prompt:")
44
+ submitted = st.form_submit_button(label="Submit")
 
 
45
 
46
  if submitted:
 
 
47
  messages = [
48
  {
49
  "role": "user",
 
51
  }
52
  ]
53
 
 
54
  response = summarize(user_input)
55
 
 
56
  st.session_state['messages'].append({
57
  "role": "assistant",
58
+ "content": response
59
+ })
60
 
61
  st.write(format_messages_for_summary(st.session_state['messages']))
62
 
 
66
 
67
 
68
  if __name__ == '__main__':
69
+ main()