import logging

import nltk
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Configure logger to print to console
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Download necessary NLTK data (if not already downloaded)
nltk.download('punkt')


class Summarizer:
    """Chunked, recursive text summarizer built on a seq2seq model."""

    def __init__(self, model_path, tokenizer_path):
        """Load the fine-tuned model and its tokenizer.

        Args:
            model_path: Location passed to
                ``AutoModelForSeq2SeqLM.from_pretrained``.
            tokenizer_path: Location passed to
                ``AutoTokenizer.from_pretrained``. It may be the same as
                ``model_path`` but is loaded independently.
        """
        logger.info(
            "Initializing Summarizer with model_path: %s and tokenizer_path: %s",
            model_path,
            tokenizer_path,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
        # Kept as an attribute so existing callers that read it keep working.
        self.tokenizer_path = tokenizer_path

        # Set device to GPU if available
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        logger.info("Model loaded on device: %s", self.device)

    def model_summarize(self, text_chunk, max_length=200, min_length=30, num_beams=4,
                        temperature=0.7, top_k=50, top_p=0.95):
        """Summarize one text chunk with the fine-tuned model.

        The chunk is prefixed with ``"summarize : "``, tokenized (truncated to
        512 tokens) and passed to ``generate`` with sampling enabled.

        Args:
            text_chunk: Text to summarize.
            max_length: Forwarded to ``generate`` as ``max_length``.
            min_length: Forwarded to ``generate`` as ``min_length``.
            num_beams: Forwarded to ``generate`` as ``num_beams``.
            temperature: Forwarded to ``generate`` as ``temperature``.
            top_k: Forwarded to ``generate`` as ``top_k``.
            top_p: Forwarded to ``generate`` as ``top_p``.

        Returns:
            The decoded summary string with special tokens removed.
        """
        logger.info("Summarizing text chunk of %d words.", len(text_chunk.split()))

        # The tokenizer loaded in __init__ is reused; re-reading it from disk
        # on every call only added latency.
        input_text = "summarize : " + text_chunk
        inputs = self.tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_length=max_length,
                # Honour the caller's value (previously hard-coded to 30).
                min_length=min_length,
                num_beams=num_beams,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                do_sample=True,
                early_stopping=True,
            )

        summary = self.tokenizer.decode(
            output_ids[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        logger.info("Summary generated.")
        return summary

    def split_into_sentences(self, text):
        """Split ``text`` into a list of sentences using NLTK."""
        sentences = nltk.sent_tokenize(text)
        logger.info("Text split into %d sentences.", len(sentences))
        return sentences

    def chunk_sentences(self, sentences, max_words=300, min_words=75):
        """Group consecutive sentences into word-bounded chunks.

        Sentences are appended to the current chunk while its total word count
        stays at or below ``max_words``. A sentence that would overflow the
        chunk starts a new one, so a single sentence longer than ``max_words``
        becomes a chunk of its own. Chunks with fewer than ``min_words`` words
        are discarded, and their sentences do not appear in the result.

        Args:
            sentences: Iterable of sentence strings.
            max_words: Maximum number of whitespace-separated words allowed
                when extending a chunk.
            min_words: Minimum number of words a chunk needs to be kept.

        Returns:
            List of chunk strings (sentences joined by single spaces).
        """
        logger.info("Starting sentence chunking.")
        chunks = []
        current_chunk = []
        current_word_count = 0

        for sentence in sentences:
            sentence_word_count = len(sentence.split())

            # The sentence still fits: extend the current chunk.
            if current_word_count + sentence_word_count <= max_words:
                current_chunk.append(sentence)
                current_word_count += sentence_word_count
                continue

            # The sentence does not fit: close the current chunk first.
            if current_chunk:
                if current_word_count >= min_words:
                    chunks.append(" ".join(current_chunk))
                    logger.info("Created a chunk with %d words.", current_word_count)
                else:
                    logger.warning(
                        "Discarding a chunk with %d words (minimum is %d).",
                        current_word_count,
                        min_words,
                    )

            # Start a new chunk with the current sentence.
            current_chunk = [sentence]
            current_word_count = sentence_word_count

        # After the loop, keep the last chunk if it meets the minimum requirement.
        if current_chunk:
            if current_word_count >= min_words:
                chunks.append(" ".join(current_chunk))
                logger.info("Final chunk created with %d words.", current_word_count)
            else:
                logger.warning(
                    "Discarding final chunk with %d words (minimum is %d).",
                    current_word_count,
                    min_words,
                )

        logger.info("Total chunks created: %d", len(chunks))
        return chunks

    def recursive_summarize(self, text, threshold=75):
        """Summarize ``text`` repeatedly until it is at most ``threshold`` words.

        The text is split into sentences, grouped into chunks, and each chunk
        is summarized; the chunk summaries are joined and, if still too long,
        summarized again. Recursion stops when any of these holds:

        * the combined summary has at most ``threshold`` words;
        * the combined summary is a single sentence (even if it is longer
          than ``threshold``);
        * the combined summary is not shorter than its input, which would
          otherwise recurse without making progress.

        Args:
            text: Text to summarize.
            threshold: Target maximum word count.

        Returns:
            The original text if it is already short enough or cannot be
            split into sentences, otherwise the final summary.
        """
        word_count = len(text.split())
        logger.info("Recursive summarization called on text with %d words.", word_count)

        if word_count <= threshold:
            logger.info("Text is below the threshold; returning original text.")
            return text

        sentences = self.split_into_sentences(text)
        if not sentences:
            logger.warning("No sentences found; returning original text.")
            return text  # Edge case if sentence splitting fails

        chunks = self.chunk_sentences(sentences)
        if not chunks:
            # Every chunk fell below the chunker's minimum size. Without this
            # fallback the joined summary would be an empty string.
            logger.warning("No chunk met the minimum size; summarizing the whole text as one chunk.")
            chunks = [text]

        logger.info("Generating summaries for each chunk.")
        summaries = [self.model_summarize(chunk) for chunk in chunks]
        combined_summary = " ".join(summaries)
        summary_word_count = len(combined_summary.split())
        logger.info("Combined summary length: %d words.", summary_word_count)

        # Check if the combined summary is a single sentence; if so, stop recursion.
        summary_sentences = self.split_into_sentences(combined_summary)
        if len(summary_sentences) == 1:
            logger.info(
                "Combined summary consists of a single sentence; "
                "returning summary without further recursion."
            )
            return combined_summary

        if summary_word_count <= threshold:
            logger.info("Combined summary meets threshold; summarization complete.")
            return combined_summary

        if summary_word_count >= word_count:
            # Each recursive call must strictly shrink the text; otherwise the
            # recursion could continue until Python's recursion limit is hit.
            logger.warning("Summary is not shorter than its input; stopping recursion.")
            return combined_summary

        logger.info("Combined summary exceeds threshold; recursing further.")
        return self.recursive_summarize(combined_summary, threshold)

    def iterative_summarization(self, text, threshold=75):
        """Alias for ``recursive_summarize``.

        Kept to maintain compatibility with fetch_top_news.py.
        """
        logger.info("Starting iterative summarization.")
        return self.recursive_summarize(text, threshold)


# if __name__ == "__main__":
#     # Example test block to verify functionality.
#     text = """Your test text here."""
#     summarizer = Summarizer("beta./model", "beta./model")
#     final_summary = summarizer.iterative_summarization(text, threshold=50)
#     print(final_summary)