import json
from collections import namedtuple
from dataclasses import dataclass, field
from datetime import datetime as dt
from typing import List

import requests
import streamlit as st
from codetiming import Timer
from transformers import AutoTokenizer

from source import Source, Summary
from scrape_sources import stub as stb


# Digest-level metadata container, defined at module scope so it can be used
# as a real type annotation below instead of an inline namedtuple expression.
DigestMeta = namedtuple(
    "DigestMeta",
    [
        "digest_time",
        "number_articles",
        "digest_length",
        "articles_per_cluster",
    ],
)


@dataclass
class Digestor:
    timer: Timer
    cache: bool = True
    text: str = field(default="no_digest")
    stubs: List = field(default_factory=list)
    user_choices: List = field(default_factory=list)
    summaries: List = field(default_factory=list)
    digest_meta: DigestMeta = None

    token_limit: int = 1024
    word_limit: int = 400

    # Class-level constant: `cache` is read at class-definition time, so this
    # always reflects the class default (True), not a per-instance setting.
    SUMMARIZATION_PARAMETERS = {
        "do_sample": False,
        "use_cache": cache,
    }

    # Hugging Face Inference API endpoint for the summarization model, plus an
    # auth header built from the Streamlit secret token.
    API_URL = "https://api-inference.huggingface.co/models/sshleifer/distilbart-cnn-12-6"
    headers = {"Authorization": f"Bearer {st.secrets['ato']}"}

    def relevance(self, summary):
        """Score a summary by overlap between its clusters and the user's choices."""
        return len(set(self.user_choices) & set(summary.cluster_list))
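
    # For example, user_choices=["politics", "tech"] against a summary whose
    # cluster_list is ["tech", "sports"] scores 1. list.sort is stable, so
    # summaries with equal scores keep their original order.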

    def digest(self):
        """Retrieve all data for user-chosen articles and build the summary object list."""
        # Start each run from a clean slate of timings.
        self.timer.timers.clear()

        with Timer(name="digest_time", text="Total digest time: {seconds:.4f} seconds"):
            for stub in self.stubs:
                # Items that are not stubs are passed through as-is,
                # e.g. already-summarized entries.
                if not isinstance(stub, stb):
                    self.summaries.append(stub)
                    print(f"type(stub): {type(stub)}")
                else:
                    summary_data: List
                    text, summary_data = stub.source.retrieve_article(stub)

                    # Null fields mean the article could not be retrieved.
                    if text is not None and summary_data is not None:
                        with Timer(name=f"{stub.hed}_chunk_time", logger=None):
                            chunk_list = self.chunk_piece(
                                text,
                                self.word_limit,
                                stub.source.source_summarization_checkpoint,
                            )

                        with Timer(
                            name=f"{stub.hed}_summary_time",
                            text="Whole article summarization time: {:.4f} seconds",
                        ):
                            summary = self.perform_summarization(
                                stub.hed,
                                chunk_list,
                                self.API_URL,
                                self.headers,
                                cache=self.cache,
                            )

                        self.summaries.append(
                            Summary(
                                source=summary_data[0],
                                cluster_list=summary_data[1],
                                link_ext=summary_data[2],
                                hed=summary_data[3],
                                dek=summary_data[4],
                                date=summary_data[5],
                                authors=summary_data[6],
                                original_length=summary_data[7],
                                summary_text=summary,
                                summary_length=len(' '.join(summary).split(' ')),
                                chunk_time=self.timer.timers[f"{stub.hed}_chunk_time"],
                                query_time=self.timer.timers[f"{stub.hed}_query_time"],
                                mean_query_time=self.timer.timers.mean(f"{stub.hed}_query_time"),
                                summary_time=self.timer.timers[f"{stub.hed}_summary_time"],
                            )
                        )
                    else:
                        print("Null article")

        # Most relevant summaries first.
        self.summaries.sort(key=self.relevance, reverse=True)

    def query(self, payload, API_URL, headers):
        """Perform one summarization inference API call."""
        data = json.dumps(payload)
        response = requests.request("POST", API_URL, headers=headers, data=data)
        return json.loads(response.content.decode("utf-8"))
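
    # A sketch of the round trip, assuming the Inference API's usual
    # summarization contract (the response is a JSON list of dicts):
    #
    #   payload  = {"inputs": "<article chunk>",
    #               "parameters": {"do_sample": False, "use_cache": True}}
    #   response = [{"summary_text": "<condensed chunk>"}]
    #
    # perform_summarization below relies on response[0]["summary_text"].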

    def chunk_piece(self, piece, limit, tokenizer_checkpoint, include_tail=False):
        """Break an article into word chunks that fit the model's token limit."""
        words = len(piece.split(' '))

        # Chunk start offsets at multiples of `limit` words.
        base_range = [i * limit for i in range(words // limit + 1)]

        # Keep the remainder either on request or when the piece is shorter
        # than one full chunk (base_range == [0]).
        if include_tail or base_range == [0]:
            base_range.append(base_range[-1] + words % limit)

        # Consecutive offset pairs become (start, end) word slices.
        range_list = list(zip(base_range, base_range[1:]))

        fractured = piece.split(' ')
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)
        chunk_list = []

        for i, j in range_list:
            chunk = ' '.join(fractured[i:j])
            tokenized_len = len(tokenizer(chunk)["input_ids"])
            if tokenized_len <= self.token_limit:
                chunk_list.append(chunk)
            else:
                # Over the token limit: crudely drop one trailing word per
                # excess token rather than re-tokenizing.
                chunk_list.append(' '.join(chunk.split(' ')[: self.token_limit - tokenized_len]))

        return chunk_list
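
    # Worked example of the arithmetic above: a 950-word piece with limit=400
    # gives base_range = [0, 400, 800] and range_list = [(0, 400), (400, 800)],
    # so the 150-word tail is dropped unless include_tail=True, which appends
    # 800 + 950 % 400 = 950 and adds the (800, 950) slice.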

    def perform_summarization(self, stubhead, chunklist: List[str], API_URL: str, headers: dict, cache=True) -> List[str]:
        """For each chunk in chunklist, append the result of query(chunk) to collection_bin."""
        collection_bin = []
        repeat = 0

        for chunk in chunklist:
            print(f"Chunk:\n\t{chunk}")
            safe = False
            summarized_chunk = None
            with Timer(name=f"{stubhead}_query_time", logger=None):
                # Retry failed queries until the attempt budget runs out.
                while not safe and repeat < 4:
                    try:
                        summarized_chunk = self.query(
                            {
                                "inputs": str(chunk),
                                "parameters": self.SUMMARIZATION_PARAMETERS,
                            },
                            API_URL,
                            headers,
                        )[0]['summary_text']
                        safe = True
                    except Exception as e:
                        print("Summarization error, repeating...")
                        print(e)
                        repeat += 1
            print(summarized_chunk)
            if summarized_chunk is not None:
                collection_bin.append(summarized_chunk)
        return collection_bin
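
    # Note on the retry logic: `repeat` is shared across all chunks of one
    # article, so at most four failed queries are retried per article, not per
    # chunk; once the budget is exhausted, remaining chunks yield None and are
    # filtered out. The `cache` argument is accepted but unused here, since the
    # parameters sent are the class-level SUMMARIZATION_PARAMETERS.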

    def build_digest(self) -> dict:
        """Assemble the digest text shown to the user and return a data dict
        describing the digest and its summaries."""
        # Join each article's chunk summaries, then join articles into the
        # digest body.
        digest = []
        for each in self.summaries:
            digest.append(' '.join(each.summary_text))

        self.text = '\n\n'.join(digest)

        out_data = {}
        t = dt.now()
        datetime_str = f"{t.hour:02d}:{t.minute:02d}:{t.second:02d}"
        digest_str = '\n\t'.join(digest)

        # Map each Summary namedtuple to {field: value}, expanding the
        # 'source' field into its own dict of source attributes.
        summaries = {
            c: {
                k._fields[i]: p if k._fields[i] != 'source'
                else {
                    'name': k.source.source_name,
                    'source_url': k.source.source_url,
                    'Summarization Checkpoint': k.source.source_summarization_checkpoint,
                    'NER Checkpoint': k.source.source_ner_checkpoint,
                }
                for i, p in enumerate(k)
            }
            for c, k in enumerate(self.summaries)
        }

        out_data['timestamp'] = datetime_str
        out_data['article_count'] = len(self.summaries)
        out_data['digest_length'] = len(digest_str.split(" "))
        out_data['sum_params'] = {
            'token_limit': self.token_limit,
            'word_limit': self.word_limit,
            'params': self.SUMMARIZATION_PARAMETERS,
        }
        out_data['summaries'] = summaries

        return out_data
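

if __name__ == "__main__":
    # Minimal usage sketch. Assumes the Streamlit secret read above is
    # available; the user choices here are hypothetical placeholders, and the
    # empty stub list makes digest() a no-op so the sketch runs standalone.
    # In the app, stubs collected by scrape_sources go in `stubs`.
    digestor = Digestor(
        timer=Timer("digest_sketch", logger=None),
        stubs=[],
        user_choices=["politics", "technology"],
    )
    digestor.digest()                # fetch, chunk, and summarize each article
    data = digestor.build_digest()   # digest text now lives in digestor.text
    print(data['timestamp'], data['article_count'])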