import json
import re
from functools import lru_cache
from typing import Any, Dict, List

import numpy as np
import requests
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TextClassificationPipeline,
)
from youtube_transcript_api import (
    YouTubeTranscriptApi,
    TooManyRequests,
    YouTubeRequestFailed,
    CouldNotRetrieveTranscript,
)

CATEGORIES = [None, 'SPONSOR', 'SELFPROMO', 'INTERACTION']

# How YouTube renders censored profanity in transcripts, and our replacement
PROFANITY_RAW = '[ __ ]'
PROFANITY_CONVERTED = '*****'

NUM_DECIMALS = 3

# Regional English variants, in order of preference, with plain 'en' as the
# final fallback
LANGUAGE_PREFERENCE_LIST = ['en-GB', 'en-US', 'en-CA', 'en-AU', 'en-NZ', 'en-ZA',
                            'en-IE', 'en-IN', 'en-JM', 'en-BZ', 'en-TT', 'en-PH', 'en-ZW',
                            'en']
|
|
|
|
def parse_transcript_json(json_data, granularity):
    assert json_data['wireMagic'] == 'pb3'

    assert granularity in ('word', 'chunk')

    parsed_transcript = []

    events = json_data['events']

    for event_index, event in enumerate(events):
        segments = event.get('segs')
        if not segments:
            continue

        # This value is known (when the phrase appears on screen)
        start_ms = event['tStartMs']
        total_characters = 0

        new_segments = []
        for seg in segments:
            # Collapse \n, \t, etc. into single spaces
            text = ' '.join(seg['utf8'].split())

            # Remove zero-width characters and the byte-order mark, then strip
            text = text.replace('\u200b', '').replace('\u200c', '').replace(
                '\u200d', '').replace('\ufeff', '').strip()

            # Needed for auto-generated transcripts, which censor profanity
            text = text.replace(PROFANITY_RAW, PROFANITY_CONVERTED)

            if not text:
                continue

            offset_ms = seg.get('tOffsetMs', 0)

            new_segments.append({
                'text': text,
                'start': round((start_ms + offset_ms)/1000, NUM_DECIMALS)
            })

            total_characters += len(text)

        if not new_segments:
            continue

        # Segment end times are not provided, so bound each event's duration
        # by the start of the following event
        if event_index < len(events) - 1:
            next_start_ms = events[event_index + 1]['tStartMs']
            total_event_duration_ms = min(
                event.get('dDurationMs', float('inf')), next_start_ms - start_ms)
        else:
            total_event_duration_ms = event.get('dDurationMs', 0)

        # Ensure the duration is never negative
        total_event_duration_ms = max(total_event_duration_ms, 0)

        avg_seconds_per_character = (
            total_event_duration_ms/total_characters)/1000

        num_char_count = 0
        for seg_index, seg in enumerate(new_segments):
            num_char_count += len(seg['text'])

            # Estimate the segment's end time from its cumulative share of characters
            seg_end = seg['start'] + \
                (num_char_count * avg_seconds_per_character)

            if seg_index < len(new_segments) - 1:
                # Do not run past the start of the next segment
                seg_end = min(seg_end, new_segments[seg_index+1]['start'])

            seg['end'] = round(seg_end, NUM_DECIMALS)
            parsed_transcript.append(seg)

    final_parsed_transcript = []
    for i in range(len(parsed_transcript)):

        word_level = granularity == 'word'
        if word_level:
            split_text = parsed_transcript[i]['text'].split()
        elif granularity == 'chunk':
            # Split on whitespace that follows punctuation
            # ('-' is placed last in the class so it is treated literally,
            # not as a character range)
            split_text = re.split(
                r'(?<=[.!?,;-])\s+', parsed_transcript[i]['text'])
            if len(split_text) == 1:
                split_on_whitespace = parsed_transcript[i]['text'].split()

                if len(split_on_whitespace) >= 8:
                    # Too long for a single chunk; split on whitespace instead
                    split_text = split_on_whitespace
                else:
                    word_level = True
        else:
            raise ValueError('Unknown granularity')

        segment_end = parsed_transcript[i]['end']
        if i < len(parsed_transcript) - 1:
            segment_end = min(segment_end, parsed_transcript[i+1]['start'])

        segment_duration = segment_end - parsed_transcript[i]['start']

        num_chars_in_text = sum(map(len, split_text))

        num_char_count = 0
        current_offset = 0
        for s in split_text:
            num_char_count += len(s)

            next_offset = (num_char_count/num_chars_in_text) * segment_duration

            word_start = round(
                parsed_transcript[i]['start'] + current_offset, NUM_DECIMALS)
            word_end = round(
                parsed_transcript[i]['start'] + next_offset, NUM_DECIMALS)

            # Cap individual word durations at 1.5 seconds to avoid words that
            # linger far past when they were actually spoken
            final_parsed_transcript.append({
                'text': s,
                'start': word_start,
                'end': min(word_end, word_start + 1.5) if word_level else word_end
            })
            current_offset = next_offset

    return final_parsed_transcript
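# A minimal worked example (illustrative only; the payload below is a
# hypothetical json3 response, not real YouTube output):
#
#   example = {'wireMagic': 'pb3', 'events': [
#       {'tStartMs': 0, 'dDurationMs': 2000, 'segs': [
#           {'utf8': 'hello'}, {'utf8': ' world', 'tOffsetMs': 1000}]}]}
#   parse_transcript_json(example, 'word')
#   # -> [{'text': 'hello', 'start': 0.0, 'end': 1.0},
#   #     {'text': 'world', 'start': 1.0, 'end': 2.5}]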
|
|
|
|
def list_transcripts(video_id):
    try:
        return YouTubeTranscriptApi.list_transcripts(video_id)
    except json.decoder.JSONDecodeError:
        # Treat a malformed response as "no transcripts available"
        return None
|
|
|
|
WORDS_TO_REMOVE = [
    '[Music]',
    '[Applause]',
    '[Laughter]',
]
|
|
|
|
@lru_cache(maxsize=16)
def get_words(video_id, transcript_type='auto', fallback='manual', filter_words_to_remove=True, granularity='word'):
    """Get a parsed video transcript, cached per video.

    Tries `transcript_type` ('auto' or 'manual') first; if nothing could be
    retrieved, retries once with the `fallback` type.
    """

    raw_transcript_json = None
    try:
        transcript_list = list_transcripts(video_id)

        if transcript_list is not None:
            if transcript_type == 'manual':
                ts = transcript_list.find_manually_created_transcript(
                    LANGUAGE_PREFERENCE_LIST)
            else:
                ts = transcript_list.find_generated_transcript(
                    LANGUAGE_PREFERENCE_LIST)
            raw_transcript = ts._http_client.get(
                f'{ts._url}&fmt=json3').content
            if raw_transcript:
                raw_transcript_json = json.loads(raw_transcript)

    except (TooManyRequests, YouTubeRequestFailed):
        raise  # Rate limiting and server errors should propagate to the caller

    except requests.exceptions.RequestException:
        # Transient network error; retry
        return get_words(video_id, transcript_type, fallback,
                         filter_words_to_remove, granularity)

    except CouldNotRetrieveTranscript:
        # No transcript of this type; fall through to the fallback below
        pass

    except json.decoder.JSONDecodeError:
        # Malformed response; retry
        return get_words(video_id, transcript_type, fallback,
                         filter_words_to_remove, granularity)

    if not raw_transcript_json and fallback is not None:
        return get_words(video_id, transcript_type=fallback, fallback=None,
                         filter_words_to_remove=filter_words_to_remove,
                         granularity=granularity)

    if raw_transcript_json:
        processed_transcript = parse_transcript_json(
            raw_transcript_json, granularity)
        if filter_words_to_remove:
            processed_transcript = list(
                filter(lambda x: x['text'] not in WORDS_TO_REMOVE, processed_transcript))
    else:
        processed_transcript = raw_transcript_json

    return processed_transcript
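# Hedged usage sketch (requires network access and an available English
# transcript; the video ID is only an example):
#
#   words = get_words('dQw4w9WgXcQ', granularity='word')
#   # -> [{'text': ..., 'start': ..., 'end': ...}, ...]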
|
|
|
|
def word_start(word):
    return word['start']


def word_end(word):
    # Fall back to the start time if no end time was computed
    return word.get('end', word['start'])
|
|
|
|
def extract_segment(words, start, end, map_function=None):
    """Extracts all words whose midpoint time lies in [start, end]"""

    a = max(binary_search_below(words, 0, len(words), start), 0)
    b = min(binary_search_above(words, -1, len(words) - 1, end) + 1, len(words))

    to_transform = callable(map_function)

    return [
        map_function(words[i]) if to_transform else words[i] for i in range(a, b)
    ]
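# Worked example on a hand-made transcript (no network needed). Inclusion is
# decided by each word's midpoint, so 'c' (midpoint 1.25) falls outside [0, 1]:
#
#   words = [{'text': 'a', 'start': 0.0, 'end': 0.5},
#            {'text': 'b', 'start': 0.5, 'end': 1.0},
#            {'text': 'c', 'start': 1.0, 'end': 1.5}]
#   extract_segment(words, 0.0, 1.0, map_function=lambda w: w['text'])
#   # -> ['a', 'b']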
|
|
|
|
def avg(*items):
    return sum(items)/len(items)
|
|
|
|
def binary_search_below(transcript, start_index, end_index, time):
    """Lower bound: first index whose word midpoint is >= `time` (or `end_index` if none)"""
    if start_index >= end_index:
        return end_index

    middle_index = (start_index + end_index) // 2
    middle = transcript[middle_index]
    middle_time = avg(word_start(middle), word_end(middle))

    if time <= middle_time:
        return binary_search_below(transcript, start_index, middle_index, time)
    else:
        return binary_search_below(transcript, middle_index + 1, end_index, time)
|
|
|
|
def binary_search_above(transcript, start_index, end_index, time):
    """Upper bound: last index whose word midpoint is <= `time` (or the initial `start_index` if none)"""
    if start_index >= end_index:
        return end_index

    middle_index = (start_index + end_index + 1) // 2
    middle = transcript[middle_index]
    middle_time = avg(word_start(middle), word_end(middle))

    if time >= middle_time:
        return binary_search_above(transcript, middle_index, end_index, time)
    else:
        return binary_search_above(transcript, start_index, middle_index - 1, time)
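# Both searches bisect on word midpoints, e.g. for midpoints [0.25, 0.75, 1.25]:
#
#   binary_search_below(words, 0, 3, 0.7)    # -> 1 (first midpoint >= 0.7)
#   binary_search_above(words, -1, 2, 0.7)   # -> 0 (last midpoint <= 0.7)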
|
|
|
|
class PreTrainedPipeline:
    def __init__(self, path: str):
        self.model2 = AutoModelForSequenceClassification.from_pretrained(path)
        self.tokenizer2 = AutoTokenizer.from_pretrained(path)
        self.pipeline2 = SponsorBlockClassificationPipeline(
            model=self.model2, tokenizer=self.tokenizer2)

    def __call__(self, inputs: str) -> List[Dict[str, Any]]:
        # A space-free input of the form 'video_id,start1,end1[,start2,end2,...]'
        # is treated as a request to classify segments of that video; anything
        # else is classified as plain text
        if ' ' not in inputs and inputs.count(',') >= 2:
            split_info = inputs.split(',', 1)
            times = np.reshape(np.array(split_info[1].split(',')), (-1, 2))
            data = []
            for start, end in times:
                data.append({
                    'video_id': split_info[0],
                    'start': float(start),
                    'end': float(end)
                })
        else:
            data = inputs

        return self.pipeline2(data)
|
|
|
|
class SponsorBlockClassificationPipeline(TextClassificationPipeline):
    def __init__(self, model, tokenizer):
        super().__init__(model=model, tokenizer=tokenizer, return_all_scores=True)

    def preprocess(self, data, **tokenizer_kwargs):
        if isinstance(data, str):
            text = data
        else:
            # Fetch the video's transcript and keep only the words that fall
            # within [start, end]
            words = get_words(data['video_id'])
            segment_words = extract_segment(words, data['start'], data['end'])
            text = ' '.join(x['text'] for x in segment_words)

        return self.tokenizer(
            text, return_tensors=self.framework, **tokenizer_kwargs)
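# Hedged usage sketch (the checkpoint path is a placeholder, not a published
# model; kept commented so nothing runs on import):
#
#   if __name__ == '__main__':
#       pipe = PreTrainedPipeline('./path/to/sponsorblock-checkpoint')
#       # Classify plain text directly:
#       print(pipe('Use code EXAMPLE for 20% off your first order'))
#       # Or classify transcript segments of a video via the
#       # 'video_id,start1,end1' string form:
#       print(pipe('VIDEO_ID,5.0,30.0'))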
|
|