# focal/app/summarizer.py
# Author: michaelkri — commit 76cc805 ("HDBSCAN bug fix")
import nltk
from nltk import tokenize
from sklearn.cluster import HDBSCAN
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import logging
class Summarizer:
    """Multi-document news summarizer.

    Pipeline: tokenize articles into sentences, embed them, cluster the
    embeddings with HDBSCAN to find recurring topics, pick the sentences
    closest to each cluster centroid, and run an abstractive summarization
    model over the combined key sentences. Also provides zero-shot
    headline categorization.
    """

    # Candidate labels for zero-shot article categorization.
    ARTICLE_CATEGORIES = [
        'world',
        'politics',
        'technology',
        'sports',
        'entertainment',
        'economy',
        'health',
    ]

    def __init__(self, embedding_model, summarization_model, categorization_model):
        """Store the three injected models and fetch NLTK tokenizer data.

        Args:
            embedding_model: sentence-embedding model exposing ``encode(sentences)``.
            summarization_model: HF-style summarization pipeline (callable,
                with ``.model.config.max_position_embeddings``).
            categorization_model: zero-shot classification pipeline.
        """
        nltk.download('punkt_tab')  # for tokenizing into sentences
        self.embedding_model = embedding_model
        self.summarization_model = summarization_model
        self.categorization_model = categorization_model

    def cluster_sentences(self, sentences, embeddings, min_cluster_size=2):
        """Group sentences into clusters of semantically similar content.

        Args:
            sentences: list of sentence strings.
            embeddings: per-sentence embedding vectors, aligned with ``sentences``.
            min_cluster_size: HDBSCAN's minimum cluster size.

        Returns:
            A list of clusters; each cluster is a list of
            ``(sentence, embedding)`` tuples. Noise points are dropped.
        """
        # HDBSCAN requires more than one sample, so short-circuit trivial inputs.
        if not sentences:
            return []
        if len(sentences) == 1:
            return [[(sentences[0], embeddings[0])]]
        hdb = HDBSCAN(min_cluster_size=min_cluster_size).fit(embeddings)
        clusters = {}
        for i, sentence in enumerate(sentences):
            cluster_id = hdb.labels_[i]
            if cluster_id == -1:  # label -1 marks noise; discard it
                continue
            clusters.setdefault(cluster_id, []).append((sentence, embeddings[i]))
        return list(clusters.values())

    def create_embeddings(self, sentences):
        """Return embedding vectors for ``sentences`` via the embedding model."""
        return self.embedding_model.encode(sentences)

    def summarize(self, content, min_length=30, max_length=200):
        """Summarize a single text with the summarization pipeline.

        Args:
            content: text to summarize.
            min_length: minimum summary length passed to the model.
            max_length: maximum summary length (clamped to the input length).

        Returns:
            The generated summary string.
        """
        # Truncate content that exceeds the model's capacity.
        # NOTE(review): max_position_embeddings counts tokens while
        # len(content) counts characters, so this is only a rough heuristic —
        # confirm against the model's tokenizer for very long inputs.
        max_model_length = self.summarization_model.model.config.max_position_embeddings
        if len(content) > max_model_length:
            content = content[:max_model_length]
        # Never request a summary longer than the input itself.
        max_length = max(min_length, min(len(content), max_length))
        return self.summarization_model(
            content,
            min_length=min_length,
            max_length=max_length,
            do_sample=False)[0]['summary_text']

    @staticmethod
    def rank_cluster_sentences(cluster):
        """Order a cluster's sentences by similarity to the cluster centroid.

        Args:
            cluster: list of ``(sentence, embedding)`` tuples.

        Returns:
            Sentences sorted from most to least central.
        """
        # separate sentences and embeddings
        sentences = [entry[0] for entry in cluster]
        embeddings = [entry[1] for entry in cluster]
        # find center of cluster
        center = np.mean(embeddings, axis=0)
        # score sentences by similarity to center
        scores = cosine_similarity([center], embeddings)[0]
        # Sort on the score only: the previous keyless sort compared tied
        # sentences lexicographically, which is arbitrary; a stable key-based
        # sort keeps input order on ties instead.
        order = sorted(range(len(sentences)), key=lambda i: scores[i], reverse=True)
        return [sentences[i] for i in order]

    def summarize_clusters(self, clusters, top_cluster_count=10, top_k_sentences=10):
        """Condense the largest clusters into one abstractive summary.

        Args:
            clusters: list of clusters as produced by ``cluster_sentences``.
            top_cluster_count: number of largest clusters to keep.
            top_k_sentences: number of top-ranked sentences taken per cluster.

        Returns:
            A single summary string covering the selected clusters.
        """
        # Sort clusters by size (descending): bigger clusters ~ more
        # frequently repeated topics across the source articles.
        clusters = sorted(clusters, key=len, reverse=True)
        clusters = clusters[:top_cluster_count]
        # combine key sentences from each cluster
        key_sentences = []
        for i, cluster in enumerate(clusters):
            logging.debug(f'Extracting from cluster {i + 1}...')
            top_sentences = Summarizer.rank_cluster_sentences(cluster)
            content = '\n'.join(top_sentences[:top_k_sentences])
            key_sentences.append(content)
        combined = ' '.join(key_sentences)
        # summarize all key sentences
        logging.debug('Creating response...')
        summary = self.summarize(
            combined,
            min_length=60,
            max_length=400
        )
        return summary

    def multisource_summary(self, articles, min_cluster_size=2):
        '''
        Create a single summary from multiple articles.

        Returns None when there is nothing to summarize (no articles, no
        sentences, or every sentence was classified as noise).
        '''
        if not articles:
            return None
        logging.debug('Tokenizing into sentences...')
        # create a list of all sentences from all articles
        sentences = []
        for article in articles:
            sentences.extend(tokenize.sent_tokenize(article.strip()))
        # Remove duplicates while preserving first-occurrence order.
        # dict.fromkeys is O(n); the previous sorted(set(...),
        # key=sentences.index) approach was O(n^2).
        sentences = list(dict.fromkeys(sentences))
        logging.debug(f'Found {len(sentences)} unique sentences')
        if not sentences:
            return None
        logging.debug('Creating sentence embeddings...')
        # create embeddings
        embeddings = self.create_embeddings(sentences)
        logging.debug('Grouping sentences into clusters...')
        # group (embeddings of) sentences by similarity
        clusters = self.cluster_sentences(sentences, embeddings, min_cluster_size=min_cluster_size)
        logging.debug(f'Created {len(clusters)} clusters')
        # If HDBSCAN discarded everything as noise there is nothing to
        # summarize; avoid feeding an empty string to the model.
        if not clusters:
            return None
        # summarize all clusters into a single summary
        return self.summarize_clusters(clusters)

    def categorize_article(self, headline):
        """Zero-shot classify a headline into one of ARTICLE_CATEGORIES.

        Returns the highest-scoring label.
        """
        result = self.categorization_model(
            headline,
            candidate_labels=Summarizer.ARTICLE_CATEGORIES
        )
        # return top result
        return result['labels'][0]