pavanhitloop commited on
Commit
c491574
·
1 Parent(s): f77573d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +190 -0
app.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, MBartForConditionalGeneration
3
+ import torch
4
+ import gradio as gr
5
+ import requests
6
+ import json
7
+
8
+
9
+ class LTRC_Translation_API():
10
+ def __init__(self, url = 'https://ssmt.iiit.ac.in/onemt', src_lang = 'eng', tgt_lang = 'te'):
11
+ self.lang_map = {'te': 'tel', 'en': 'eng', 'ta': 'tam', 'ml': 'mal', 'mr': 'mar', 'kn': 'kan', 'hi': 'hin'}
12
+ self.url = url
13
+
14
+ self.headers = {
15
+ 'Content-Type': 'application/json',
16
+ 'Accept': 'application/json'
17
+ }
18
+
19
+ lang = self.lang_map.get(tgt_lang, 'te')
20
+
21
+ self.src_lang = src_lang
22
+ self.tgt_lang = lang
23
+
24
+ def translate(self, text):
25
+ try:
26
+ data = {'text': text, 'source_language': self.src_lang, 'target_language': self.tgt_lang}
27
+
28
+ response = requests.post(self.url, headers = self.headers, json = data)
29
+ translated_text = json.loads(response.text).get('data', '')
30
+
31
+ return translated_text
32
+
33
+ except Exception as e:
34
+ print("Exception: ", e)
35
+
36
+ return ''
37
+
38
+
39
+ class Headline_Generation():
40
+ def __init__(self, model_name = "ai4bharat/MultiIndicHeadlineGenerationSS"):
41
+ self.model_name = model_name
42
+
43
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
44
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False, use_fast=False, keep_accents=True)
45
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
46
+ self.model.to(self.device)
47
+ self.model.eval()
48
+
49
+ self.bos_id = self.tokenizer._convert_token_to_id_with_added_voc("<s>")
50
+ self.eos_id = self.tokenizer._convert_token_to_id_with_added_voc("</s>")
51
+ self.pad_id = self.tokenizer._convert_token_to_id_with_added_voc("<pad>")
52
+
53
+ self.lang_map = {'as': '<2as>', 'bn': '<2bn>', 'en': '<2en>', 'gu': '<2gu>', 'hi': '<2hi>', 'kn': '<2kn>', 'ml': '<2ml>', 'mr': '<2mr>', 'or': '<2or>', 'pa': '<2pa>', 'ta': '<2ta>', 'te': '<2te>'}
54
+
55
+ print("Headline Generation model loaded...!")
56
+
57
+
58
+ def get_headline(self, text, lang_id):
59
+
60
+ inp = self.tokenizer(text, add_special_tokens=False, return_tensors="pt", padding=True).to(self.device)
61
+ inp = inp['input_ids']
62
+
63
+ lang_code = self.lang_map.get(lang_id, '')
64
+
65
+ text = text + "</s> " + lang_code
66
+ # print("Text: ", text)
67
+
68
+ model_output = self.model.generate(
69
+ inp,
70
+ use_cache=True,
71
+ num_beams=5,
72
+ max_length=32,
73
+ min_length=1,
74
+ early_stopping=True,
75
+ pad_token_id = self.pad_id,
76
+ bos_token_id = self.bos_id,
77
+ eos_token_id = self.eos_id,
78
+ decoder_start_token_id = self.tokenizer._convert_token_to_id_with_added_voc(lang_code)
79
+ )
80
+
81
+ decoded_output = self.tokenizer.decode(
82
+ model_output[0],
83
+ skip_special_tokens=True,
84
+ clean_up_tokenization_spaces=False
85
+ )
86
+
87
+ return decoded_output
88
+
89
+
90
+ class Summarization():
91
+ def __init__(self, model_name = "ai4bharat/MultiIndicSentenceSummarizationSS"):
92
+ self.model_name = model_name
93
+
94
+ self.device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
95
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False, use_fast=False, keep_accents=True)
96
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
97
+ self.model.to(self.device)
98
+ self.model.eval()
99
+
100
+ self.bos_id = self.tokenizer._convert_token_to_id_with_added_voc("<s>")
101
+ self.eos_id = self.tokenizer._convert_token_to_id_with_added_voc("</s>")
102
+ self.pad_id = self.tokenizer._convert_token_to_id_with_added_voc("<pad>")
103
+
104
+ self.lang_map = {'as': '<2as>', 'bn': '<2bn>', 'en': '<2en>', 'gu': '<2gu>', 'hi': '<2hi>', 'kn': '<2kn>', 'ml': '<2ml>', 'mr': '<2mr>', 'or': '<2or>', 'pa': '<2pa>', 'ta': '<2ta>', 'te': '<2te>'}
105
+
106
+ print("Summarization model loaded...!")
107
+
108
+
109
+ def get_summary(self, text, lang_id):
110
+
111
+ inp = self.tokenizer(text, add_special_tokens=False, return_tensors="pt", padding=True).to(self.device)
112
+ inp = inp['input_ids']
113
+
114
+ lang_code = self.lang_map.get(lang_id, '')
115
+
116
+ text = text + "</s> " + lang_code
117
+ # print("Text: ", text)
118
+
119
+ model_output = self.model.generate(
120
+ inp,
121
+ use_cache=True,
122
+ num_beams=5,
123
+ max_length=32,
124
+ min_length=1,
125
+ early_stopping=True,
126
+ pad_token_id = self.pad_id,
127
+ bos_token_id = self.bos_id,
128
+ eos_token_id = self.eos_id,
129
+ decoder_start_token_id = self.tokenizer._convert_token_to_id_with_added_voc(lang_code)
130
+ )
131
+
132
+ decoded_output = self.tokenizer.decode(
133
+ model_output[0],
134
+ skip_special_tokens=True,
135
+ clean_up_tokenization_spaces=False
136
+ )
137
+
138
+ return decoded_output
139
+
140
+
141
+ def get_prediction(text, lang_id, translate = False):
142
+ # if len(sys.argv)<3:
143
+ # print("Usage: python app.py <text_file_path> <lang_id>")
144
+ # print("Text file should contain the article news")
145
+ # exit()
146
+
147
+ # txt_path = sys.argv[1]
148
+ # lang_id = sys.argv[2]
149
+
150
+ # if not os.path.exists(txt_path):
151
+ # print("Path: {} do not exists".format(txt_path))
152
+ # exit()
153
+
154
+ # text = ''
155
+ # with open(txt_path, 'r', encoding='utf-8') as fp:
156
+ # text = fp.read().strip()
157
+
158
+ headline_generator = Headline_Generation()
159
+ summarizer = Summarization()
160
+ if translate == True:
161
+ translator = LTRC_Translation_API(tgt_lang = lang_id)
162
+ text = translator.translate(text)
163
+
164
+ headline = headline_generator.get_headline(text, lang_id)
165
+ summary = summarizer.get_summary(text, lang_id)
166
+
167
+
168
+ # print("Article: ", text)
169
+ # print("Summary: ", summary)
170
+ # print("Headline: ", headline)
171
+
172
+ # return "Headline: " + headline + "\nSummary: " + summary
173
+ return [text, summary, headline]
174
+
175
+ interface = gr.Interface(
176
+ get_prediction,
177
+ inputs=[
178
+ gr.Textbox(lines = 8, label = "News Article Text", info = "Provide the news article text here. Check the `Translate` if the source language is english."),
179
+ gr.Dropdown(
180
+ ['as', 'bn', 'en', 'gu', 'hi', 'kn', 'ml', 'mr', 'or', 'pa', 'ta', 'te'], label="Language code", info="select the target language code"
181
+ ),
182
+ gr.Checkbox(label="Translate", info="Is translation required?")
183
+ ],
184
+ outputs=[
185
+ gr.Textbox(lines = 8, label = "Source Article Text", info = "Source article text (if `Translate` is enabled then the source will be translated to target language)"),
186
+ gr.Textbox(lines = 4, label = "Summary", info = "Summary of the given article (translated if `Translate` is enabled)"),
187
+ gr.Textbox(lines = 2, label = "Headline", info = "Generated headline of the given article (translated if `Translate` is enabled)")
188
+ ]
189
+ )
190
+ interface.launch(share=True)