Ismetdh commited on
Commit
c555633
·
verified ·
1 Parent(s): 43fbc78

Update summarizer.py

Browse files
Files changed (1) hide show
  1. summarizer.py +155 -154
summarizer.py CHANGED
@@ -1,154 +1,155 @@
1
- import logging
2
- import nltk
3
- import torch
4
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
-
6
- # Configure logger to print to console
7
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
8
- logger = logging.getLogger(__name__)
9
-
10
- # Download necessary NLTK data (if not already downloaded)
11
- nltk.download('punkt')
12
-
13
- class Summarizer:
14
- def __init__(self, model_path, tokenizer_path):
15
- """
16
- Initialize the summarizer with a fine-tuned model and tokenizer.
17
- Both model and tokenizer are loaded from the same directory.
18
- """
19
- logger.info(f"Initializing Summarizer with model_path: {model_path} and tokenizer_path: {tokenizer_path}")
20
- self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
21
- self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
22
- self.tokenizer_path = tokenizer_path
23
-
24
- # Set device to GPU if available
25
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
26
- self.model.to(self.device)
27
- logger.info(f"Model loaded on device: {self.device}")
28
-
29
- def model_summarize(self, text_chunk,
30
- max_length=200,
31
- min_length=30,
32
- num_beams=4,
33
- temperature=0.7,
34
- top_k=50,
35
- top_p=0.95):
36
- """
37
- Summarizes a text chunk using the fine-tuned model.
38
- The prompt instructs the model to include explicit noun references.
39
- """
40
- logger.info(f"Summarizing text chunk of {len(text_chunk.split())} words.")
41
-
42
- # Re-load tokenizer from the given path (as in original code)
43
- self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)
44
- input_text = "summarize : " + text_chunk
45
- inputs = self.tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
46
- inputs = {k: v.to(self.device) for k, v in inputs.items()}
47
-
48
- with torch.no_grad():
49
- output_ids = self.model.generate(
50
- **inputs,
51
- max_length=max_length,
52
- num_beams=num_beams,
53
- temperature=temperature,
54
- top_k=top_k,
55
- top_p=top_p,
56
- do_sample=True,
57
- early_stopping=True
58
- )
59
-
60
- summary = self.tokenizer.decode(output_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
61
- logger.info("Summary generated.")
62
- return summary
63
-
64
- def split_into_sentences(self, text):
65
- """
66
- Splits the text into sentences using NLTK.
67
- """
68
- sentences = nltk.sent_tokenize(text)
69
- logger.info(f"Text split into {len(sentences)} sentences.")
70
- return sentences
71
-
72
- def chunk_sentences(self, sentences):
73
- """
74
- Groups sentences into chunks.
75
- Each chunk contains as many sentences as possible while keeping its total word count below 300.
76
- Only chunks with at least 50 words are kept; chunks with fewer words are discarded.
77
- """
78
- logger.info("Starting sentence chunking.")
79
- chunks = []
80
- current_chunk = []
81
- current_word_count = 0
82
-
83
- for sentence in sentences:
84
- sentence_word_count = len(sentence.split())
85
-
86
- # If adding this sentence keeps the chunk under 300 words, add it.
87
- if current_word_count + sentence_word_count <= 300:
88
- current_chunk.append(sentence)
89
- current_word_count += sentence_word_count
90
- else:
91
- # If current chunk meets the minimum word requirement, add it to the chunks list.
92
- if current_word_count >= 50:
93
- chunks.append(" ".join(current_chunk))
94
- logger.info(f"Created a chunk with {current_word_count} words.")
95
- # Start a new chunk with the current sentence.
96
- current_chunk = [sentence]
97
- current_word_count = sentence_word_count
98
-
99
- # After the loop, add the last chunk if it meets the minimum requirement.
100
- if current_word_count >= 75:
101
- chunks.append(" ".join(current_chunk))
102
- logger.info(f"Final chunk created with {current_word_count} words.")
103
-
104
- logger.info(f"Total chunks created: {len(chunks)}")
105
- return chunks
106
-
107
- def recursive_summarize(self, text, threshold=50):
108
- """
109
- Recursively summarizes the text until its word count is below the threshold.
110
- If the combined summary consists of a single sentence (even if its length is above the threshold),
111
- the recursion stops.
112
- """
113
- logger.info(f"Recursive summarization called on text with {len(text.split())} words.")
114
- if len(text.split()) <= threshold:
115
- logger.info("Text is below the threshold; returning original text.")
116
- return text
117
-
118
- sentences = self.split_into_sentences(text)
119
- if not sentences:
120
- logger.warning("No sentences found; returning original text.")
121
- return text # Edge case if sentence splitting fails
122
-
123
- chunks = self.chunk_sentences(sentences)
124
- logger.info("Generating summaries for each chunk.")
125
- summaries = [self.model_summarize(chunk) for chunk in chunks]
126
- combined_summary = " ".join(summaries)
127
- logger.info(f"Combined summary length: {len(combined_summary.split())} words.")
128
-
129
- # Check if the combined summary is a single sentence; if so, stop recursion.
130
- summary_sentences = self.split_into_sentences(combined_summary)
131
- if len(summary_sentences) == 1:
132
- logger.info("Combined summary consists of a single sentence; returning summary without further recursion.")
133
- return combined_summary
134
-
135
- if len(combined_summary.split()) > threshold:
136
- logger.info("Combined summary exceeds threshold; recursing further.")
137
- return self.recursive_summarize(combined_summary, threshold)
138
- else:
139
- logger.info("Combined summary meets threshold; summarization complete.")
140
- return combined_summary
141
-
142
- def iterative_summarization(self, text, threshold=50):
143
- """
144
- Alias for recursive_summarize to maintain compatibility with fetch_top_news.py.
145
- """
146
- logger.info("Starting iterative summarization.")
147
- return self.recursive_summarize(text, threshold)
148
-
149
- # if __name__ == "__main__":
150
- # # Example test block to verify functionality.
151
- # text = """Your test text here."""
152
- # summarizer = Summarizer("beta./model", "beta./model")
153
- # final_summary = summarizer.iterative_summarization(text, threshold=50)
154
- # print(final_summary)
 
 
1
+ import logging
2
+ import nltk
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
+
6
# Configure logger to print to console
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Ensure the NLTK sentence tokenizer data is present. Checking with
# nltk.data.find first avoids a network round-trip on every import;
# nltk.download is only called when the data is actually missing.
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')
12
+
13
class Summarizer:
    """Recursive abstractive summarizer built on a fine-tuned seq2seq model.

    Long input is split into sentences, the sentences are grouped into
    word-count-bounded chunks, each chunk is summarized independently, and
    the concatenated summaries are re-summarized until the result fits under
    a word-count threshold (or no further progress is possible).
    """

    # Word-count limits used by chunk_sentences().
    MAX_CHUNK_WORDS = 300   # a chunk never exceeds this many words
    MIN_CHUNK_WORDS = 75    # chunks shorter than this are discarded

    def __init__(self, model_path, tokenizer_path):
        """
        Initialize the summarizer with a fine-tuned model and tokenizer.

        Args:
            model_path: directory or hub id of the seq2seq model.
            tokenizer_path: directory or hub id of the tokenizer
                (typically the same directory as the model).
        """
        logger.info("Initializing Summarizer with model_path: %s and tokenizer_path: %s",
                    model_path, tokenizer_path)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
        # Kept for backward compatibility with code that reads this attribute.
        self.tokenizer_path = tokenizer_path

        # Run on GPU when available.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        logger.info("Model loaded on device: %s", self.device)

    def model_summarize(self, text_chunk,
                        max_length=200,
                        min_length=30,
                        num_beams=4,
                        temperature=0.7,
                        top_k=50,
                        top_p=0.95):
        """
        Summarize a single text chunk with the fine-tuned model.

        Args:
            text_chunk: plain-text passage to summarize.
            max_length: maximum length (in tokens) of the generated summary.
            min_length: minimum length (in tokens) of the generated summary.
            num_beams: beam-search width.
            temperature, top_k, top_p: sampling parameters (do_sample=True).

        Returns:
            The decoded summary string.
        """
        logger.info("Summarizing text chunk of %d words.", len(text_chunk.split()))

        # The tokenizer is loaded once in __init__; the per-call re-load from
        # disk in the original code was redundant and has been removed.
        input_text = "summarize : " + text_chunk
        inputs = self.tokenizer(input_text, return_tensors="pt",
                                max_length=512, truncation=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_length=max_length,
                min_length=min_length,  # fix: was hard-coded 30 and missing its comma
                num_beams=num_beams,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                do_sample=True,
                early_stopping=True,
            )

        summary = self.tokenizer.decode(output_ids[0], skip_special_tokens=True,
                                        clean_up_tokenization_spaces=True)
        logger.info("Summary generated.")
        return summary

    def split_into_sentences(self, text):
        """
        Split *text* into sentences using NLTK's punkt tokenizer.

        Returns:
            A list of sentence strings (possibly empty).
        """
        sentences = nltk.sent_tokenize(text)
        logger.info("Text split into %d sentences.", len(sentences))
        return sentences

    def chunk_sentences(self, sentences):
        """
        Group *sentences* into chunks bounded by word count.

        Each chunk holds as many consecutive sentences as possible while
        staying at or below MAX_CHUNK_WORDS (300) words. Only chunks with at
        least MIN_CHUNK_WORDS (75) words are kept; shorter ones are discarded.

        Returns:
            A list of chunk strings (sentences joined with spaces).
        """
        logger.info("Starting sentence chunking.")
        chunks = []
        current_chunk = []
        current_word_count = 0

        for sentence in sentences:
            sentence_word_count = len(sentence.split())

            if current_word_count + sentence_word_count <= self.MAX_CHUNK_WORDS:
                # Sentence still fits in the current chunk.
                current_chunk.append(sentence)
                current_word_count += sentence_word_count
            else:
                # Flush the current chunk if it is large enough to be useful.
                if current_word_count >= self.MIN_CHUNK_WORDS:
                    chunks.append(" ".join(current_chunk))
                    logger.info("Created a chunk with %d words.", current_word_count)
                # Start a new chunk with the current sentence.
                current_chunk = [sentence]
                current_word_count = sentence_word_count

        # Flush the trailing chunk if it meets the minimum requirement.
        if current_word_count >= self.MIN_CHUNK_WORDS:
            chunks.append(" ".join(current_chunk))
            logger.info("Final chunk created with %d words.", current_word_count)

        logger.info("Total chunks created: %d", len(chunks))
        return chunks

    def recursive_summarize(self, text, threshold=75):
        """
        Recursively summarize *text* until its word count is <= *threshold*.

        Recursion also stops when the combined summary is a single sentence,
        when chunking yields nothing to summarize, or when a pass fails to
        shrink the text (guards against infinite recursion).

        Returns:
            The final summary string (or *text* unchanged if already short).
        """
        word_count = len(text.split())
        logger.info("Recursive summarization called on text with %d words.", word_count)
        if word_count <= threshold:
            logger.info("Text is below the threshold; returning original text.")
            return text

        sentences = self.split_into_sentences(text)
        if not sentences:
            logger.warning("No sentences found; returning original text.")
            return text  # Edge case if sentence splitting fails

        chunks = self.chunk_sentences(sentences)
        if not chunks:
            # All chunks were discarded as too small; nothing to summarize.
            logger.warning("No usable chunks produced; returning original text.")
            return text

        logger.info("Generating summaries for each chunk.")
        summaries = [self.model_summarize(chunk) for chunk in chunks]
        combined_summary = " ".join(summaries)
        combined_word_count = len(combined_summary.split())
        logger.info("Combined summary length: %d words.", combined_word_count)

        # A single-sentence summary cannot usefully be summarized further.
        summary_sentences = self.split_into_sentences(combined_summary)
        if len(summary_sentences) == 1:
            logger.info("Combined summary consists of a single sentence; "
                        "returning summary without further recursion.")
            return combined_summary

        if combined_word_count > threshold:
            if combined_word_count >= word_count:
                # No progress this pass — recursing again would loop forever.
                logger.warning("Summary did not shrink the text; stopping recursion.")
                return combined_summary
            logger.info("Combined summary exceeds threshold; recursing further.")
            return self.recursive_summarize(combined_summary, threshold)

        logger.info("Combined summary meets threshold; summarization complete.")
        return combined_summary

    def iterative_summarization(self, text, threshold=75):
        """
        Alias for recursive_summarize to maintain compatibility with
        fetch_top_news.py.
        """
        logger.info("Starting iterative summarization.")
        return self.recursive_summarize(text, threshold)
149
+
150
+ # if __name__ == "__main__":
151
+ # # Example test block to verify functionality.
152
+ # text = """Your test text here."""
153
+ # summarizer = Summarizer("beta./model", "beta./model")
154
+ # final_summary = summarizer.iterative_summarization(text, threshold=50)
155
+ # print(final_summary)