import os
from functools import lru_cache
from time import time

import streamlit as st
from grouped_sampling import GroupedSamplingPipeLine

from download_repo import download_pytorch_model


def is_downloaded(model_name: str) -> bool:
    """
    Checks whether the model is already present in the local Hugging Face cache.
    :param model_name: The name of the model to check.
    :return: True if the model is downloaded, False otherwise.
    """
    models_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub")
    model_dir = os.path.join(models_dir, f"models--{model_name.replace('/', '--')}")
    return os.path.isdir(model_dir)
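

# Illustrative example (the model id is hypothetical): the Hugging Face hub
# cache flattens "org/name" ids into "models--org--name" directories, so
# is_downloaded("facebook/opt-125m") checks for
# ~/.cache/huggingface/hub/models--facebook--opt-125m
# Note that this assumes the default cache location; a custom HF_HOME or
# HUGGINGFACE_HUB_CACHE is not taken into account.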


@lru_cache(maxsize=10)
def create_pipeline(model_name: str) -> GroupedSamplingPipeLine:
    """
    Creates a pipeline for the given model, downloading the model first if needed.
    :param model_name: The name of the model to use.
    :return: A GroupedSamplingPipeLine for the given model, with a fixed group size of 512.
    """
    if not is_downloaded(model_name):
        download_repository_start_time = time()
        st.write(f"Downloading model: {model_name}...")
        download_pytorch_model(model_name)
        download_repository_end_time = time()
        download_time = download_repository_end_time - download_repository_start_time
        st.write(f"Finished downloading model: {model_name} in {download_time:,.2f} seconds.")
    st.write(f"Creating pipeline with model: {model_name}...")
    pipeline_start_time = time()
    pipeline = GroupedSamplingPipeLine(
        model_name=model_name,
        group_size=512,
        end_of_sentence_stop=False,
        top_k=50,
        load_in_8bit=False,
    )
    pipeline_end_time = time()
    pipeline_time = pipeline_end_time - pipeline_start_time
    st.write(f"Finished creating pipeline with model: {model_name} in {pipeline_time:,.2f} seconds.")
    return pipeline
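

# Note: thanks to @lru_cache, repeated calls with the same model name reuse
# the already-built pipeline instead of constructing a new one, e.g.:
#   pipe = create_pipeline("gpt2")           # "gpt2" is only an illustrative model id
#   assert pipe is create_pipeline("gpt2")   # second call is served from the cache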


def generate_text(
    pipeline: GroupedSamplingPipeLine,
    prompt: str,
    output_length: int,
) -> str:
    """
    Generates text using the given pipeline.
    :param pipeline: The GroupedSamplingPipeLine to use.
    :param prompt: The prompt to complete. str.
    :param output_length: The length of the text to generate, in tokens. int > 0.
    :return: The generated text. str.
    """
    return pipeline(
        prompt_s=prompt,
        max_new_tokens=output_length,
        return_text=True,
        return_full_text=False,
    )["generated_text"]


def on_form_submit(
    model_name: str,
    output_length: int,
    prompt: str,
) -> str:
    """
    Called when the user submits the form.
    :param model_name: The name of the model to use.
    :param output_length: The length of the text to generate, in tokens.
    :param prompt: The prompt to use.
    :return: The output of the model.
    :raises ValueError: If the model name is not supported, the output length
        is <= 0, or the prompt is empty or longer than 16384 characters.
    :raises TypeError: If the output length is not an integer or the prompt is
        not a string.
    :raises RuntimeError: If the model is not found or does not generate any
        text.
    """
    if len(prompt) == 0:
        raise ValueError("The prompt must not be empty.")
    st.write(f"Loading model: {model_name}...")
    loading_start_time = time()
    pipeline = create_pipeline(
        model_name=model_name,
    )
    loading_end_time = time()
    loading_time = loading_end_time - loading_start_time
    st.write(f"Finished loading model: {model_name} in {loading_time:,.2f} seconds.")
    st.write("Generating text...")
    generation_start_time = time()
    generated_text = generate_text(
        pipeline=pipeline,
        prompt=prompt,
        output_length=output_length,
    )
    generation_end_time = time()
    generation_time = generation_end_time - generation_start_time
    st.write(f"Finished generating text in {generation_time:,.2f} seconds.")
    if not isinstance(generated_text, str) or len(generated_text) == 0:
        raise RuntimeError(f"The model {model_name} did not generate any text.")
    return generated_text
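

# Minimal manual smoke test, not part of the Streamlit app; the model id is an
# assumption and any small causal language model on the Hugging Face hub
# should work in its place.
if __name__ == "__main__":
    print(on_form_submit("gpt2", output_length=10, prompt="Hello, my name is"))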