fix transformers
Browse files
app.py
CHANGED
|
@@ -32,6 +32,8 @@ from operator import itemgetter
|
|
| 32 |
|
| 33 |
import gradio as gr
|
| 34 |
|
|
|
|
|
|
|
| 35 |
global df
|
| 36 |
bearer_token = 'AAAAAAAAAAAAAAAAAAAAACEigwEAAAAACoP8KHJYLOKCL4OyB9LEPV00VB0%3DmyeDROUvw4uipHwvbPPfnTuY0M9ORrLuXrMvcByqZhwo3SUc4F'
|
| 37 |
client = tweepy.Client(bearer_token=bearer_token)
|
|
@@ -392,11 +394,10 @@ def full_lda(df):
|
|
| 392 |
return top_tweets
|
| 393 |
|
| 394 |
def topic_summarization(topic_groups):
|
| 395 |
-
|
| 396 |
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
model = model.to(device)
|
| 400 |
translator = Translator()
|
| 401 |
|
| 402 |
headlines = []
|
|
@@ -410,8 +411,8 @@ def topic_summarization(topic_groups):
|
|
| 410 |
max_len = 256
|
| 411 |
|
| 412 |
encoding = tokenizer.encode_plus(text, return_tensors = "pt")
|
| 413 |
-
input_ids = encoding["input_ids"]
|
| 414 |
-
attention_masks = encoding["attention_mask"]
|
| 415 |
|
| 416 |
beam_outputs = model.generate(
|
| 417 |
input_ids = input_ids,
|
|
|
|
| 32 |
|
| 33 |
import gradio as gr
|
| 34 |
|
| 35 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 36 |
+
|
| 37 |
global df
|
| 38 |
bearer_token = 'AAAAAAAAAAAAAAAAAAAAACEigwEAAAAACoP8KHJYLOKCL4OyB9LEPV00VB0%3DmyeDROUvw4uipHwvbPPfnTuY0M9ORrLuXrMvcByqZhwo3SUc4F'
|
| 39 |
client = tweepy.Client(bearer_token=bearer_token)
|
|
|
|
| 394 |
return top_tweets
|
| 395 |
|
| 396 |
def topic_summarization(topic_groups):
|
| 397 |
+
|
| 398 |
|
| 399 |
+
tokenizer = AutoTokenizer.from_pretrained("Michau/t5-base-en-generate-headline")
|
| 400 |
+
model = AutoModelForSeq2SeqLM.from_pretrained("Michau/t5-base-en-generate-headline")
|
|
|
|
| 401 |
translator = Translator()
|
| 402 |
|
| 403 |
headlines = []
|
|
|
|
| 411 |
max_len = 256
|
| 412 |
|
| 413 |
encoding = tokenizer.encode_plus(text, return_tensors = "pt")
|
| 414 |
+
input_ids = encoding["input_ids"]
|
| 415 |
+
attention_masks = encoding["attention_mask"]
|
| 416 |
|
| 417 |
beam_outputs = model.generate(
|
| 418 |
input_ids = input_ids,
|