from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import logging


class TextSummarizer:
    """Abstractive summarizer that runs a Hugging Face seq2seq model on CPU."""

    def __init__(self, model_name="facebook/bart-large-cnn"):
        """
        Initialize summarization model directly without using pipeline

        Args:
            model_name (str): Hugging Face model for summarization

        Raises:
            RuntimeError: If the tokenizer or the model cannot be loaded.
        """
        try:
            # Force CPU usage and disable GPU
            self.device = torch.device('cpu')

            # Initialize tokenizer and model separately
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

            # Move model to CPU and switch to inference (eval) mode
            self.model = self.model.to(self.device)
            self.model.eval()

            logging.info("Summarization model initialized successfully")
        except Exception as e:
            logging.error(f"Failed to load summarization model: {str(e)}")
            # Chain the original exception so the root cause stays visible.
            raise RuntimeError(f"Failed to load summarization model: {str(e)}") from e

    def generate_summary(self, text, max_length=400, min_length=100, max_input_tokens=1024):
        """
        Generate summary for given text

        Long inputs are split into consecutive windows of at most
        ``max_input_tokens`` tokens. Each window is summarized independently,
        and the partial summaries are joined with single spaces. The overall
        length budget is divided evenly among the windows.

        Args:
            text (str): Input text to summarize
            max_length (int): Maximum length of the summary, in tokens, shared
                across all chunks.
            min_length (int): Minimum length of the summary, in tokens, shared
                across all chunks; clamped to ``max_length``.
            max_input_tokens (int): Maximum number of tokens fed to the model
                per chunk. It is further capped at the tokenizer's
                ``model_max_length``.

        Returns:
            str: The generated summary. If ``text`` is empty or blank, returns
            "No text provided for summarization.". If anything fails, returns
            "Error during summarization: ..." (the error is logged, not raised).
        """
        try:
            # Validate input text
            if not text or len(text.strip()) == 0:
                return "No text provided for summarization."

            # Ensure min_length is less than max_length
            min_length = min(min_length, max_length)

            # Never exceed what the model itself can accept.
            input_limit = min(max_input_tokens, self.tokenizer.model_max_length)
            chunks = self._split_into_token_chunks(text, input_limit)
            if not chunks:
                return "No text provided for summarization."

            # Split the overall budget evenly across the real (non-blank) chunks.
            # The floor of 1 keeps generate() from receiving max_length=0 when
            # there are more chunks than budget tokens.
            n_chunks = len(chunks)
            chunk_max = max(max_length // n_chunks, 1)
            chunk_min = min(min_length // n_chunks, chunk_max)

            with torch.no_grad():  # Disable gradient calculation
                summaries = [
                    self._summarize_chunk(chunk, input_limit, chunk_max, chunk_min)
                    for chunk in chunks
                ]

            return " ".join(summaries)

        except Exception as e:
            # logging.exception keeps the traceback; callers still get a string.
            logging.exception(f"Error during summarization: {str(e)}")
            return f"Error during summarization: {str(e)}"

    def _split_into_token_chunks(self, text, input_limit):
        """Split ``text`` into non-blank text chunks, each of which fits in ``input_limit`` tokens."""
        # Leave room for the special tokens (e.g. BOS/EOS) that are added back
        # when each chunk is encoded for the model.
        window = max(input_limit - self.tokenizer.num_special_tokens_to_add(), 1)
        # verbose=False silences the "sequence longer than max length" warning;
        # the over-long sequence is split into windows below, on purpose.
        token_ids = self.tokenizer(text, add_special_tokens=False, verbose=False)["input_ids"]
        chunks = (
            self.tokenizer.decode(token_ids[start:start + window], skip_special_tokens=True)
            for start in range(0, len(token_ids), window)
        )
        return [chunk for chunk in chunks if chunk.strip()]

    def _summarize_chunk(self, chunk, input_limit, max_length, min_length):
        """Run beam-search generation on one chunk and return the decoded summary."""
        # Re-encoding after decode can shift the token count slightly, so
        # truncation acts as a hard safety cap at the model's input limit.
        inputs = self.tokenizer(
            chunk, max_length=input_limit, truncation=True, return_tensors="pt"
        ).to(self.device)
        summary_ids = self.model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            num_beams=4,
            max_length=max_length,
            min_length=min_length,
            length_penalty=2.0,
            early_stopping=True,
        )
        return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)