| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
| import torch |
| import logging |
|
|
class TextSummarizer:
    """Abstractive text summarizer built on a Hugging Face seq2seq model."""

    def __init__(self, model_name="facebook/bart-large-cnn", device="cpu"):
        """
        Initialize the summarization model directly, without using pipeline.

        Args:
            model_name (str): Hugging Face model id for summarization.
            device (str): torch device to run inference on. Defaults to
                "cpu", preserving the original hard-coded behavior.

        Raises:
            RuntimeError: if the tokenizer or model cannot be loaded
                (chained from the underlying exception).
        """
        try:
            self.device = torch.device(device)

            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

            self.model = self.model.to(self.device)
            # Inference only: disable dropout / switch off training mode.
            self.model.eval()

            logging.info("Summarization model initialized successfully")

        except Exception as e:
            logging.error(f"Failed to load summarization model: {str(e)}")
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(f"Failed to load summarization model: {str(e)}") from e

    def generate_summary(self, text, max_length=400, min_length=100):
        """
        Generate a summary for the given text.

        Long inputs are split into ~1024-character chunks (NOTE: character
        count is only a rough proxy for the model's 1024-token input limit;
        the tokenizer's `truncation=True` below is the real safeguard), each
        chunk is summarized, and the partial summaries are joined with spaces.

        Args:
            text (str): Input text to summarize.
            max_length (int): Maximum total summary length in tokens,
                budgeted across chunks.
            min_length (int): Minimum total summary length in tokens.

        Returns:
            str: The generated summary. On empty input or internal error a
                human-readable message string is returned instead — this
                method never raises, preserving the original contract.
        """
        try:
            if not text or not text.strip():
                return "No text provided for summarization."

            # Keep the invariant min_length <= max_length.
            min_length = min(min_length, max_length)

            # Rough character-based chunking; over-long chunks are
            # re-truncated by the tokenizer, so they cannot overflow the
            # model's positional limit.
            max_chunk_length = 1024
            chunks = [text[i:i + max_chunk_length]
                      for i in range(0, len(text), max_chunk_length)]
            summaries = []

            # Budget the requested lengths across chunks ONCE, outside the
            # loop (the original recomputed this every iteration). Clamp so
            # integer division can never yield max_length=0 or an inverted
            # min/max range — e.g. with the default max_length=400, any
            # input longer than ~400 chunks previously passed
            # max_length=0 to generate().
            n_chunks = len(chunks)
            per_chunk_max = max(max_length // n_chunks, 1)
            per_chunk_min = min(max(min_length // n_chunks, 0), per_chunk_max)

            with torch.no_grad():  # inference only: skip autograd bookkeeping
                for chunk in chunks:
                    if not chunk.strip():
                        continue  # skip whitespace-only chunks

                    inputs = self.tokenizer(chunk, max_length=1024,
                                            truncation=True,
                                            return_tensors="pt")
                    inputs = inputs.to(self.device)

                    summary_ids = self.model.generate(
                        inputs["input_ids"],
                        num_beams=4,
                        max_length=per_chunk_max,
                        min_length=per_chunk_min,
                        length_penalty=2.0,
                        early_stopping=True,
                    )

                    summaries.append(
                        self.tokenizer.decode(summary_ids[0],
                                              skip_special_tokens=True))

            return " ".join(summaries)

        except Exception as e:
            logging.error(f"Error during summarization: {str(e)}")
            return f"Error during summarization: {str(e)}"