ZinebSN commited on
Commit
57f1d8d
·
1 Parent(s): ebf8670

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -0
app.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+ # Import GPT2 Model and Tokenizer
6
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel
7
+
8
+ # Import T5 Model and Tokenizer
9
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
10
+
11
+ st.title("Text Summarizer")
12
+
13
+ models = {
14
+ "T5 Small": "ZinebSN/T5_summarizer",
15
+ "GPT2": "ZinebSN/GPT2_summarizer"
16
+ }
17
+
18
+ selected_model = st.radio("Select Model", list(models.keys()))
19
+
20
+ model_name = models[selected_model]
21
+
22
+ if selected_model=='GPT2':
23
+ tokenizer = GPT2Tokenizer.from_pretrained(model_name)
24
+ model = GPT2LMHeadModel.from_pretrained(model_name)
25
+ else:
26
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
27
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
28
+
29
+ # Inference function for GPT2
30
+ def gpt2_summarize(input_text, tokenizer, model, length):
31
+ text=tokenizer.encode_plus(f'<bos> {input_text} <sep>', truncation=True, max_length=1024).input_ids
32
+ text_length=len(text)
33
+ text = torch.tensor(text, dtype=torch.long, device=device)
34
+ text = text.unsqueeze(0)
35
+ generated = text
36
+ model = model.to(device)
37
+ with torch.no_grad():
38
+ for _ in range(length):
39
+ inputs = {'input_ids': generated}
40
+ outputs = model(**inputs)
41
+ next_token_logits = outputs[0][0, -1, :]
42
+ next_token = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=1)
43
+ generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
44
+ generated=generated[:, -1024:]
45
+ generated = generated[0, text_length:]
46
+ text = tokenizer.convert_ids_to_tokens(generated,skip_special_tokens=True)
47
+ text = tokenizer.convert_tokens_to_string(text)
48
+ return text
49
+
50
+ # Inference function for T5
51
+ def t5_summarize(input_text, tokenizer, model):
52
+ inputs=tokenizer('summarize: '+input_text, truncation=True, padding='max_length', max_length=600, return_tensors='pt')
53
+ output_sequence=model.generate(input_ids=inputs["input_ids"],attention_mask=inputs["attention_mask"], max_new_tokens=100)
54
+ summary = tokenizer.batch_decode(output_sequence, skip_special_tokens=True)
55
+ return summary[0]
56
+
57
+
58
+ input_text=st.text_area("Input the text to summarize","", height=300)
59
+ if st.button("Summarize"):
60
+ st.text("It may take a minute or two.")
61
+ nwords=len(input_text.split(" "))
62
+ if selected_model=='GPT2':
63
+ summary=gpt2_summarize(input_text, tokenizer, model, 30)
64
+ else:
65
+ summary=t5_summarize(input_text, tokenizer, model)
66
+
67
+ st.header("Summary")
68
+ st.markdown(summary)