import argparse
import json
from functools import partial
from pathlib import Path

from pysbd import Segmenter
from tiktoken import Encoding

from .knowledge_graph import PROMPT_FILE_PATH
from .openai_api import (RESPONSES_DIRECTORY_PATH,
                         get_max_chapter_segment_token_count,
                         get_openai_model_encoding, save_openai_api_response)
from .utils import (execute_function_in_parallel, set_up_logging,
                    strip_and_remove_empty_strings)

logger = set_up_logging('openai-api-scripts.log')
|
|
def get_paragraphs(text):
    """Split a text into paragraphs."""
    paragraphs = strip_and_remove_empty_strings(text.split('\n\n'))
    # Collapse runs of whitespace (including line breaks and tabs) within
    # each paragraph into single spaces.
    paragraphs = [' '.join(paragraph.split()) for paragraph in paragraphs]
    return paragraphs
|
|
def combine_text_subunits_into_segments(subunits, join_string,
                                        encoding: Encoding,
                                        max_token_count):
    """
    Combine subunits of text into segments that do not exceed a maximum number
    of tokens.
    """
    # Count the tokens in each subunit and in the string used to join them.
    subunit_token_counts = [len(tokens) for tokens
                            in encoding.encode_ordinary_batch(subunits)]
    join_string_token_count = len(encoding.encode_ordinary(join_string))
    total_token_count = (sum(subunit_token_counts) + join_string_token_count
                         * (len(subunits) - 1))
    # If all subunits fit within the limit, return them as a single segment.
    if total_token_count <= max_token_count:
        return [join_string.join(subunits)]
    # Aim for segments of roughly equal token counts, rather than greedily
    # filling each segment to the maximum.
    approximate_segment_count = total_token_count // max_token_count + 1
    approximate_segment_token_count = round(total_token_count
                                            / approximate_segment_count)
    segments = []
    current_segment_subunits = []
    current_segment_token_count = 0
    for i, (subunit, subunit_token_count) in enumerate(
            zip(subunits, subunit_token_counts)):
        # The token count of the current segment if this subunit were added.
        extended_segment_token_count = (current_segment_token_count
                                        + join_string_token_count
                                        + subunit_token_count)
        # Add the subunit to the current segment if it fits within the limit
        # and keeps the segment at least as close to the target size;
        # otherwise, close the current segment and start a new one.
        if (extended_segment_token_count <= max_token_count
                and abs(extended_segment_token_count
                        - approximate_segment_token_count)
                <= abs(current_segment_token_count
                       - approximate_segment_token_count)):
            current_segment_subunits.append(subunit)
            current_segment_token_count = extended_segment_token_count
        else:
            segment = join_string.join(current_segment_subunits)
            segments.append(segment)
            # If the remaining subunits fit in a single segment, or this is
            # the last subunit, combine the remainder into a final segment
            # and stop.
            if (sum(subunit_token_counts[i:]) + join_string_token_count
                    * (len(subunits) - i - 1) <= max_token_count
                    or i == len(subunits) - 1):
                segment = join_string.join(subunits[i:])
                segments.append(segment)
                break
            current_segment_subunits = [subunit]
            current_segment_token_count = subunit_token_count
    return segments
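
# Illustrative sizing example: 30 sentences totaling roughly 2,500 tokens
# with max_token_count=1000 are combined into three segments of about 830
# tokens each, rather than two 1,000-token segments plus a 500-token
# remainder.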
|
|
def split_long_sentences(sentences, encoding: Encoding,
                         max_token_count):
    """
    Given a list of sentences, split sentences that exceed a maximum number of
    tokens into multiple segments.
    """
    token_counts = [len(tokens) for tokens
                    in encoding.encode_ordinary_batch(sentences)]
    split_sentences = []
    for sentence, token_count in zip(sentences, token_counts):
        if token_count > max_token_count:
            # Fall back to splitting the sentence on whitespace and
            # recombining the words into segments under the limit.
            words = sentence.split()
            segments = combine_text_subunits_into_segments(
                words, ' ', encoding, max_token_count)
            split_sentences.extend(segments)
        else:
            split_sentences.append(sentence)
    return split_sentences
|
|
def split_long_paragraphs(paragraphs, encoding: Encoding,
                          max_token_count):
    """
    Given a list of paragraphs, split paragraphs that exceed a maximum number
    of tokens into multiple segments.
    """
    token_counts = [len(tokens) for tokens
                    in encoding.encode_ordinary_batch(paragraphs)]
    # Reuse a single pysbd segmenter for sentence boundary detection instead
    # of constructing one per paragraph.
    segmenter = Segmenter()
    split_paragraphs = []
    for paragraph, token_count in zip(paragraphs, token_counts):
        if token_count > max_token_count:
            sentences = segmenter.segment(paragraph)
            sentences = split_long_sentences(sentences, encoding,
                                             max_token_count)
            segments = combine_text_subunits_into_segments(
                sentences, ' ', encoding, max_token_count)
            split_paragraphs.extend(segments)
        else:
            split_paragraphs.append(paragraph)
    return split_paragraphs
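
# Note the two-level fallback above: a paragraph that exceeds the limit is
# split into sentences, and any sentence that still exceeds the limit is
# split into words, so every resulting piece fits within max_token_count.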
|
|
def get_chapter_segments(chapter_text, encoding: Encoding,
                         max_token_count):
    """
    Split a chapter text into segments that do not exceed a maximum number of
    tokens.
    """
    paragraphs = get_paragraphs(chapter_text)
    paragraphs = split_long_paragraphs(paragraphs, encoding, max_token_count)
    chapter_segments = combine_text_subunits_into_segments(
        paragraphs, '\n', encoding, max_token_count)
    return chapter_segments
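
# Usage sketch, mirroring save_triples_for_scripts below:
#     encoding = get_openai_model_encoding(model_id)
#     max_token_count = get_max_chapter_segment_token_count(prompt, model_id)
#     segments = get_chapter_segments(chapter_text, encoding, max_token_count)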
|
|
def get_response_save_path(idx, save_path, project_gutenberg_id,
                           chapter_index=None,
                           chapter_segment_index=None,
                           chapter_segment_count=None):
    """
    Get the path to the JSON file(s) containing response data from the OpenAI
    API.
    """
    # Responses are grouped by the script's Project Gutenberg ID, with one
    # subdirectory per chapter.
    save_path = Path(save_path) / str(project_gutenberg_id)
    if chapter_index is not None:
        save_path /= str(chapter_index)
    if chapter_segment_index is not None:
        save_path /= (f'{chapter_segment_index + 1}-of-'
                      f'{chapter_segment_count}.json')
    # Create the directory portion of the path so the caller can write to it.
    directory_path = (save_path.parent if save_path.suffix == '.json'
                      else save_path)
    directory_path.mkdir(parents=True, exist_ok=True)
    return save_path
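
# For example, with the layout above, get_response_save_path(
#     0, RESPONSES_DIRECTORY_PATH, 84, chapter_index=3,
#     chapter_segment_index=0, chapter_segment_count=5)
# returns RESPONSES_DIRECTORY_PATH / '84' / '3' / '1-of-5.json'.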
|
|
def save_openai_api_responses_for_script(script, prompt, encoding,
                                         max_chapter_segment_token_count,
                                         idx, api_key, model_id):
    """
    Call the OpenAI API for each chapter segment in a script and return the
    responses as a list.
    """
    project_gutenberg_id = script['id']
    chapter_count = len(script['chapters'])
    logger.info(f'Starting to call OpenAI API and process responses for '
                f'script {project_gutenberg_id} ({chapter_count} chapters).')
    prompt_message_lists = []
    for chapter in script['chapters']:
        # Split the chapter text (assumed to be stored under the 'text' key
        # as a single string) into segments that leave room for the prompt.
        chapter_segments = get_chapter_segments(
            chapter['text'], encoding, max_chapter_segment_token_count)
        for chapter_segment in chapter_segments:
            # Insert the chapter segment into the prompt template at the
            # {STORY} placeholder.
            prompt_with_story = prompt.replace('{STORY}', chapter_segment)
            prompt_message_lists.append(
                [{'role': 'user', 'content': prompt_with_story}])
    # Call the API for all chapter segments in parallel. Binding api_key and
    # model_id with functools.partial assumes save_openai_api_response
    # accepts them as keyword arguments; they are not part of the chat
    # message format, so they do not belong inside the message dicts.
    responses = execute_function_in_parallel(
        partial(save_openai_api_response, api_key=api_key,
                model_id=model_id),
        prompt_message_lists)
    response_list = list(responses)
    logger.info(
        f'Finished processing responses for script {project_gutenberg_id}.')
    return response_list
|
|
def save_triples_for_scripts(script, idx, api_key, model_id):
    """
    Call the OpenAI API to generate knowledge graph nodes and edges, and
    return the responses as a list.
    """
    # Read the prompt template and work out how many tokens of each chapter
    # segment can fit alongside it in the model's context window.
    prompt = PROMPT_FILE_PATH.read_text()
    max_chapter_segment_token_count = get_max_chapter_segment_token_count(
        prompt, model_id)
    encoding = get_openai_model_encoding(model_id)
    responses = save_openai_api_responses_for_script(
        script, prompt, encoding, max_chapter_segment_token_count, idx,
        api_key, model_id)
    return responses
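

if __name__ == '__main__':
    # Minimal command-line driver sketch; this block is an assumption rather
    # than part of the original pipeline. It loads a JSON file containing a
    # list of scripts (objects with 'id' and 'chapters' keys, as consumed by
    # the functions above) and generates responses for each script.
    parser = argparse.ArgumentParser(
        description='Generate knowledge graph triples for scripts with the '
                    'OpenAI API.')
    parser.add_argument('scripts_file', type=Path,
                        help='Path to a JSON file containing a list of '
                             'scripts.')
    parser.add_argument('--api-key', required=True, help='OpenAI API key.')
    parser.add_argument('--model-id', required=True,
                        help='ID of the OpenAI model to use.')
    args = parser.parse_args()
    scripts = json.loads(args.scripts_file.read_text())
    for idx, script in enumerate(scripts):
        save_triples_for_scripts(script, idx, args.api_key, args.model_id)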