Spaces:
Sleeping
Sleeping
import os

import torch
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Hugging Face model repo for the fine-tuned BART summarizer; the access
# token is read from the environment (set as a Space secret).
MODEL = "xTorch8/fine-tuned-bart"
TOKEN = os.getenv("TOKEN")
MAX_TOKENS = 1024

# Load the checkpoint and its tokenizer once at startup.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL, token=TOKEN)
tokenizer = AutoTokenizer.from_pretrained(MODEL, token=TOKEN)
def summarize_text(text):
    """Summarize ``text`` with the fine-tuned BART model.

    Long inputs are split into overlapping character chunks (using a rough
    4-characters-per-token heuristic), each chunk is summarized separately,
    and if the concatenated chunk summaries are still long they are run
    through the model once more to produce a single condensed summary.

    Parameters:
        text (str): The text to summarize.

    Returns:
        str: The summary, or a human-readable error message if
        summarization fails (Gradio's "text" output expects a string).
    """
    try:
        # Fast path: nothing to summarize. (The original loop also yields
        # "" for empty input; this just skips the model entirely.)
        if not text:
            return ""
        # ~4 characters per token is a rough heuristic for English text.
        chunk_size = MAX_TOKENS * 4
        overlap = chunk_size // 4  # 25% overlap so content at chunk edges isn't lost
        step = chunk_size - overlap
        chunks = [text[i:i + chunk_size] for i in range(0, len(text), step)]

        summaries = []
        for chunk in chunks:
            inputs = tokenizer(chunk, return_tensors="pt", truncation=True,
                               max_length=1024, padding=True)
            with torch.no_grad():
                summary_ids = model.generate(
                    **inputs,
                    max_length=1500,
                    length_penalty=2.0,
                    num_beams=4,
                    early_stopping=True,
                )
            summaries.append(tokenizer.decode(summary_ids[0], skip_special_tokens=True))

        final_text = " ".join(summaries)
        # NOTE(review): this compares a character count against a token
        # budget — consistent with the char-based chunking heuristic above,
        # but worth confirming the intended unit.
        if len(final_text) <= MAX_TOKENS:
            return final_text

        # Second pass: condense the concatenated chunk summaries.
        inputs = tokenizer(final_text, return_tensors="pt", truncation=True,
                           max_length=1024, padding=True)
        with torch.no_grad():
            summary_ids = model.generate(
                **inputs,
                min_length=300,
                max_length=1500,
                length_penalty=2.0,
                num_beams=4,
                early_stopping=True,
            )
        return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    except Exception as e:
        # Bug fix: the original returned the Exception object itself;
        # return a readable string so the Gradio text output renders it.
        return f"Error: {e}"
# Gradio UI: a single multi-line textbox in, plain-text summary out.
demo = gr.Interface(
    fn=summarize_text,
    inputs=gr.Textbox(lines=20, label="Input Text"),
    outputs="text",
    title="BART Summarizer",
)

# Launch the app only when run as a script (Spaces imports this module).
if __name__ == "__main__":
    demo.launch()