File size: 6,797 Bytes
c555633
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80e2cd3
c555633
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import logging
import nltk
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Configure logger to print to console
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Ensure the NLTK sentence tokenizer model is available.
# Fix: the original called nltk.download('punkt') unconditionally on every
# import, which performs a network check (and prints) each run. Probe the
# local data first and only download when the model is actually missing.
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')

class Summarizer:
    """Recursively condenses long text with a fine-tuned seq2seq model.

    The text is split into sentence chunks, each chunk is summarized
    independently, and the concatenated summaries are fed back through
    the same pipeline until the result is short enough.
    """

    # Chunking bounds (words). A chunk is closed before it would exceed
    # MAX_CHUNK_WORDS; chunks shorter than MIN_CHUNK_WORDS are discarded.
    MAX_CHUNK_WORDS = 300
    MIN_CHUNK_WORDS = 75

    def __init__(self, model_path, tokenizer_path):
        """
        Initialize the summarizer with a fine-tuned model and tokenizer.

        Parameters
        ----------
        model_path : str
            Directory or hub id to load the seq2seq model from.
        tokenizer_path : str
            Directory or hub id to load the tokenizer from.
        """
        logger.info("Initializing Summarizer with model_path: %s and tokenizer_path: %s",
                    model_path, tokenizer_path)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
        self.tokenizer_path = tokenizer_path

        # Run on GPU when one is available.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        logger.info("Model loaded on device: %s", self.device)

    def model_summarize(self, text_chunk,
                        max_length=200,
                        min_length=30,
                        num_beams=4,
                        temperature=0.7,
                        top_k=50,
                        top_p=0.95):
        """
        Summarize a single text chunk with the fine-tuned model.

        Parameters
        ----------
        text_chunk : str
            Text to summarize.
        max_length, min_length : int
            Bounds (in tokens) on the generated summary.
        num_beams, temperature, top_k, top_p : generation parameters
            Passed through to ``model.generate`` (sampled beam search).

        Returns
        -------
        str
            The decoded summary.
        """
        logger.info("Summarizing text chunk of %d words.", len(text_chunk.split()))

        # Fix: the original re-loaded the tokenizer from disk on every call,
        # discarding the one built in __init__. Reuse the cached instance.
        input_text = "summarize : " + text_chunk
        inputs = self.tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_length=max_length,
                # Fix: the min_length parameter was previously ignored
                # (hardcoded to 30 here); honor the caller's value.
                min_length=min_length,
                num_beams=num_beams,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                do_sample=True,
                early_stopping=True
            )

        summary = self.tokenizer.decode(output_ids[0], skip_special_tokens=True,
                                        clean_up_tokenization_spaces=True)
        logger.info("Summary generated.")
        return summary

    def split_into_sentences(self, text):
        """
        Split *text* into sentences using NLTK's sentence tokenizer.

        Returns
        -------
        list[str]
        """
        sentences = nltk.sent_tokenize(text)
        logger.info("Text split into %d sentences.", len(sentences))
        return sentences

    def chunk_sentences(self, sentences):
        """
        Group sentences into chunks.

        Each chunk holds as many whole sentences as possible while staying
        at or below MAX_CHUNK_WORDS (300) words. Only chunks with at least
        MIN_CHUNK_WORDS (75) words are kept; shorter chunks are discarded.
        (The original docstring said 50 words; the code has always enforced 75.)

        NOTE: a single sentence longer than MAX_CHUNK_WORDS becomes a chunk
        on its own, exceeding the limit — truncation happens later in the
        tokenizer, not here.

        Parameters
        ----------
        sentences : list[str]

        Returns
        -------
        list[str]
            Chunks, each a space-joined run of sentences.
        """
        logger.info("Starting sentence chunking.")
        chunks = []
        current_chunk = []
        current_word_count = 0

        for sentence in sentences:
            sentence_word_count = len(sentence.split())

            # If adding this sentence keeps the chunk within the limit, add it.
            if current_word_count + sentence_word_count <= self.MAX_CHUNK_WORDS:
                current_chunk.append(sentence)
                current_word_count += sentence_word_count
            else:
                # Close the current chunk if it meets the minimum size.
                if current_word_count >= self.MIN_CHUNK_WORDS:
                    chunks.append(" ".join(current_chunk))
                    logger.info("Created a chunk with %d words.", current_word_count)
                # Start a new chunk with the current sentence.
                current_chunk = [sentence]
                current_word_count = sentence_word_count

        # Flush the trailing chunk if it meets the minimum size.
        if current_word_count >= self.MIN_CHUNK_WORDS:
            chunks.append(" ".join(current_chunk))
            logger.info("Final chunk created with %d words.", current_word_count)

        logger.info("Total chunks created: %d", len(chunks))
        return chunks

    def recursive_summarize(self, text, threshold=75):
        """
        Recursively summarize *text* until its word count is below *threshold*.

        Recursion also stops when the combined summary is a single sentence
        (even if it is still above the threshold), or when a pass fails to
        shrink the text (guards against infinite recursion).

        Parameters
        ----------
        text : str
        threshold : int
            Target maximum word count.

        Returns
        -------
        str
        """
        input_word_count = len(text.split())
        logger.info("Recursive summarization called on text with %d words.", input_word_count)
        if input_word_count <= threshold:
            logger.info("Text is below the threshold; returning original text.")
            return text

        sentences = self.split_into_sentences(text)
        if not sentences:
            logger.warning("No sentences found; returning original text.")
            return text  # Edge case if sentence splitting fails

        chunks = self.chunk_sentences(sentences)
        logger.info("Generating summaries for each chunk.")
        summaries = [self.model_summarize(chunk) for chunk in chunks]
        combined_summary = " ".join(summaries)
        combined_word_count = len(combined_summary.split())
        logger.info("Combined summary length: %d words.", combined_word_count)

        # Stop if the combined summary is a single sentence.
        summary_sentences = self.split_into_sentences(combined_summary)
        if len(summary_sentences) == 1:
            logger.info("Combined summary consists of a single sentence; "
                        "returning summary without further recursion.")
            return combined_summary

        if combined_word_count > threshold:
            # Fix: guard against non-terminating recursion — if this pass
            # did not shrink the text (e.g. chunking discarded everything
            # or the model padded its output), stop rather than loop forever.
            if combined_word_count >= input_word_count:
                logger.warning("Summary did not shrink (%d -> %d words); "
                               "stopping recursion.", input_word_count, combined_word_count)
                return combined_summary
            logger.info("Combined summary exceeds threshold; recursing further.")
            return self.recursive_summarize(combined_summary, threshold)
        else:
            logger.info("Combined summary meets threshold; summarization complete.")
            return combined_summary

    def iterative_summarization(self, text, threshold=75):
        """
        Alias for recursive_summarize to maintain compatibility with fetch_top_news.py.
        """
        logger.info("Starting iterative summarization.")
        return self.recursive_summarize(text, threshold)

# if __name__ == "__main__":
#     # Example test block to verify functionality.
#     text = """Your test text here."""
#     summarizer = Summarizer("beta./model", "beta./model")
#     final_summary = summarizer.iterative_summarization(text, threshold=50)
#     print(final_summary)