Ismetdh commited on
Commit
dc15e04
·
verified ·
1 Parent(s): d002892

Upload summarizer.py

Browse files
Files changed (1) hide show
  1. summarizer.py +154 -0
summarizer.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import logging
import nltk
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Configure the root logger so INFO-level messages print to the console.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Download the NLTK 'punkt' sentence-tokenizer data used by sent_tokenize below.
# NOTE(review): this runs at import time on every load; nltk skips the download
# when the data is already present. Newer NLTK releases (>=3.9) may also require
# the 'punkt_tab' resource — confirm against the deployed nltk version.
nltk.download('punkt')
13
class Summarizer:
    def __init__(self, model_path, tokenizer_path):
        """
        Initialize the summarizer with a fine-tuned model and tokenizer.
        Both model and tokenizer are loaded from the same directory.

        Args:
            model_path: Path (or hub id) passed to AutoModelForSeq2SeqLM.from_pretrained.
            tokenizer_path: Path (or hub id) passed to AutoTokenizer.from_pretrained.
        """
        logger.info(f"Initializing Summarizer with model_path: {model_path} and tokenizer_path: {tokenizer_path}")
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
        # Kept for backward compatibility with any external code that reads it.
        self.tokenizer_path = tokenizer_path

        # Set device to GPU if available.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        logger.info(f"Model loaded on device: {self.device}")

    def model_summarize(self, text_chunk,
                        max_length=200,
                        min_length=30,
                        num_beams=4,
                        temperature=0.7,
                        top_k=50,
                        top_p=0.95):
        """
        Summarize a single text chunk using the fine-tuned model.

        Fixes over the previous revision:
        - The tokenizer is no longer re-loaded from disk on every call; the
          instance loaded in __init__ is reused.
        - min_length was accepted but never forwarded to generate(); it is now
          passed through so the parameter actually takes effect.

        Args:
            text_chunk: Raw text to summarize.
            max_length: Upper bound (in tokens) on the generated summary.
            min_length: Lower bound (in tokens) on the generated summary.
            num_beams, temperature, top_k, top_p: Generation hyperparameters
                (do_sample=True, so the sampling parameters are active).

        Returns:
            The decoded summary string.
        """
        logger.info(f"Summarizing text chunk of {len(text_chunk.split())} words.")

        input_text = "summarize : " + text_chunk
        inputs = self.tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_length=max_length,
                min_length=min_length,  # previously ignored
                num_beams=num_beams,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                do_sample=True,
                early_stopping=True
            )

        summary = self.tokenizer.decode(output_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
        logger.info("Summary generated.")
        return summary

    def split_into_sentences(self, text):
        """Split the text into sentences using NLTK's sent_tokenize."""
        sentences = nltk.sent_tokenize(text)
        logger.info(f"Text split into {len(sentences)} sentences.")
        return sentences

    def chunk_sentences(self, sentences):
        """
        Group sentences into chunks.

        Each chunk holds as many sentences as possible while keeping its total
        word count at or below 300. Only chunks with at least 50 words are
        kept; shorter chunks are discarded. (The final chunk previously used
        an inconsistent 75-word cutoff; it now uses the same documented
        50-word minimum as every other chunk.)

        Args:
            sentences: List of sentence strings.

        Returns:
            List of chunk strings (sentences joined by single spaces).
        """
        logger.info("Starting sentence chunking.")
        chunks = []
        current_chunk = []
        current_word_count = 0

        for sentence in sentences:
            sentence_word_count = len(sentence.split())

            # If adding this sentence keeps the chunk under 300 words, add it.
            if current_word_count + sentence_word_count <= 300:
                current_chunk.append(sentence)
                current_word_count += sentence_word_count
            else:
                # Keep the finished chunk only if it meets the 50-word minimum.
                if current_word_count >= 50:
                    chunks.append(" ".join(current_chunk))
                    logger.info(f"Created a chunk with {current_word_count} words.")
                # Start a new chunk with the current sentence.
                current_chunk = [sentence]
                current_word_count = sentence_word_count

        # After the loop, add the last chunk if it meets the minimum requirement.
        if current_word_count >= 50:
            chunks.append(" ".join(current_chunk))
            logger.info(f"Final chunk created with {current_word_count} words.")

        logger.info(f"Total chunks created: {len(chunks)}")
        return chunks

    def recursive_summarize(self, text, threshold=50):
        """
        Recursively summarize the text until its word count is below threshold.

        Recursion also stops when the combined summary is a single sentence,
        even if that sentence is still longer than the threshold.

        Args:
            text: Text to summarize.
            threshold: Target maximum word count for the final summary.

        Returns:
            The summarized text (or the original text when it is already
            short enough, cannot be sentence-split, or yields no usable chunks).
        """
        logger.info(f"Recursive summarization called on text with {len(text.split())} words.")
        if len(text.split()) <= threshold:
            logger.info("Text is below the threshold; returning original text.")
            return text

        sentences = self.split_into_sentences(text)
        if not sentences:
            logger.warning("No sentences found; returning original text.")
            return text  # Edge case if sentence splitting fails

        chunks = self.chunk_sentences(sentences)
        if not chunks:
            # Previously this fell through and returned an empty string;
            # return the original text instead when every chunk was discarded.
            logger.warning("No chunks met the minimum size; returning original text.")
            return text

        logger.info("Generating summaries for each chunk.")
        summaries = [self.model_summarize(chunk) for chunk in chunks]
        combined_summary = " ".join(summaries)
        logger.info(f"Combined summary length: {len(combined_summary.split())} words.")

        # A single-sentence summary cannot be usefully chunked further; stop.
        summary_sentences = self.split_into_sentences(combined_summary)
        if len(summary_sentences) == 1:
            logger.info("Combined summary consists of a single sentence; returning summary without further recursion.")
            return combined_summary

        if len(combined_summary.split()) > threshold:
            logger.info("Combined summary exceeds threshold; recursing further.")
            return self.recursive_summarize(combined_summary, threshold)
        logger.info("Combined summary meets threshold; summarization complete.")
        return combined_summary

    def iterative_summarization(self, text, threshold=50):
        """Alias for recursive_summarize to maintain compatibility with fetch_top_news.py."""
        logger.info("Starting iterative summarization.")
        return self.recursive_summarize(text, threshold)
148
+
149
+ # if __name__ == "__main__":
150
+ # # Example test block to verify functionality.
151
+ # text = """Your test text here."""
152
+ # summarizer = Summarizer("beta./model", "beta./model")
153
+ # final_summary = summarizer.iterative_summarization(text, threshold=50)
154
+ # print(final_summary)