# Spaces: Runtime error  (Hugging Face Spaces status banner captured with the
# source; kept as a comment so the file remains valid Python)
import gradio as gr
import pandas as pd
import torch
from diffusers import DiffusionPipeline
from transformers import AutoTokenizer, CLIPTokenizerFast, T5TokenizerFast
def load_tokenizers(model_id: str) -> list[CLIPTokenizerFast | T5TokenizerFast | None]:
    """Load every tokenizer declared in a diffusers pipeline config.

    The result always has exactly three entries: the model's tokenizers in
    subfolder order (``tokenizer``, ``tokenizer_2``, ``tokenizer_3``), padded
    with ``None`` when the pipeline ships fewer than three.

    Raises:
        gr.Error: if the config declares no tokenizer, or more than three.
    """
    config = DiffusionPipeline.load_config(model_id)
    num_tokenizers = sum("tokenizer" in key for key in config)
    if not 1 <= num_tokenizers <= 3:
        raise gr.Error(f"Invalid number of tokenizers: {num_tokenizers}")

    tokenizers: list[CLIPTokenizerFast | T5TokenizerFast | None] = []
    for index in range(3):
        if index >= num_tokenizers:
            # Fewer than three tokenizers: pad the remaining slots with None.
            tokenizers.append(None)
            continue
        # diffusers convention: first subfolder is "tokenizer", then "tokenizer_2", ...
        subfolder = "tokenizer" if index == 0 else f"tokenizer_{index + 1}"
        tokenizers.append(AutoTokenizer.from_pretrained(model_id, subfolder=subfolder))
    return tokenizers
def _tokenize_pairs(tokenizer, text: str) -> list[tuple[str, str]]:
    """Return (decoded token, token id) string pairs for *text*."""
    input_ids = tokenizer(
        text=text,
        truncation=False,
        return_length=False,
        return_overflowing_tokens=False,
    ).input_ids
    return [(str(tokenizer.decode(id_)), str(id_)) for id_ in input_ids]


def _special_token_pairs(tokenizer) -> list[tuple[str, str]]:
    """Return (name, token) pairs from the tokenizer's special-token map.

    The list-valued ``additional_special_tokens`` entry is skipped because it
    does not fit the one-token-per-row display.
    """
    return [
        (str(name), str(token))
        for name, token in tokenizer.special_tokens_map.items()
        if name != "additional_special_tokens"
    ]


def _details_frame(tokenizer) -> pd.DataFrame:
    """Summarize key tokenizer attributes as an Attribute/Value table."""
    return pd.DataFrame(
        [
            ("Type", tokenizer.__class__.__name__),
            ("Vocab Size", tokenizer.vocab_size),
            ("Model Max Length", tokenizer.model_max_length),
            ("Padding Side", tokenizer.padding_side),
            ("Truncation Side", tokenizer.truncation_side),
        ],
        columns=["Attribute", "Value"],
    )


def inference(model_id: str, text: str):
    """Tokenize *text* with every tokenizer of *model_id* and build outputs.

    Returns nine Gradio components, in the order the UI expects them:
    three token-pair ``HighlightedText``s, three special-token
    ``HighlightedText``s, and three detail ``Dataframe``s. Slots beyond the
    model's tokenizer count are returned as invisible placeholders.
    """
    tokenizers = load_tokenizers(model_id)

    text_pairs_components = []
    special_tokens_components = []
    tokenizer_details_components = []
    for i, tokenizer in enumerate(tokenizers):
        # Explicit None check: HF tokenizers define __len__, so plain
        # truthiness would silently depend on vocab size.
        if tokenizer is None:
            text_pairs_components.append(gr.HighlightedText(visible=False))
            special_tokens_components.append(gr.HighlightedText(visible=False))
            tokenizer_details_components.append(gr.Dataframe(visible=False))
            continue

        label_text = f"Tokenizer {i + 1}: {tokenizer.__class__.__name__}"
        # Text / token-id pairs.
        text_pairs_components.append(
            gr.HighlightedText(
                label=label_text,
                value=_tokenize_pairs(tokenizer, text),
                visible=True,
            )
        )
        # Special tokens.
        special_tokens_components.append(
            gr.HighlightedText(
                label=label_text,
                value=_special_token_pairs(tokenizer),
                visible=True,
            )
        )
        # Tokenizer detail table.
        tokenizer_details_components.append(
            gr.Dataframe(
                headers=["Attribute", "Value"],
                value=_details_frame(tokenizer),
                label=label_text,
                visible=True,
            )
        )
    return text_pairs_components + special_tokens_components + tokenizer_details_components
| if __name__ == "__main__": | |
| theme = gr.themes.Soft( | |
| primary_hue=gr.themes.colors.emerald, | |
| secondary_hue=gr.themes.colors.emerald, | |
| ) | |
| with gr.Blocks(theme=theme) as demo: | |
| with gr.Column(): | |
| input_model_id = gr.Dropdown( | |
| label="Model ID", | |
| choices=[ | |
| "black-forest-labs/FLUX.1-dev", | |
| "black-forest-labs/FLUX.1-schnell", | |
| "stabilityai/stable-diffusion-3-medium-diffusers", | |
| "stabilityai/stable-diffusion-xl-base-1.0", | |
| "stable-diffusion-v1-5/stable-diffusion-v1-5", | |
| "stabilityai/japanese-stable-diffusion-xl", | |
| "rinna/japanese-stable-diffusion", | |
| ], | |
| value="black-forest-labs/FLUX.1-dev", | |
| ) | |
| input_text = gr.Textbox( | |
| label="Input Text", | |
| placeholder="Enter text here", | |
| ) | |
| with gr.Tab(label="Tokenization Outputs"): | |
| with gr.Column(): | |
| output_highlighted_text_1 = gr.HighlightedText() | |
| output_highlighted_text_2 = gr.HighlightedText() | |
| output_highlighted_text_3 = gr.HighlightedText() | |
| with gr.Tab(label="Special Tokens"): | |
| with gr.Column(): | |
| output_special_tokens_1 = gr.HighlightedText() | |
| output_special_tokens_2 = gr.HighlightedText() | |
| output_special_tokens_3 = gr.HighlightedText() | |
| with gr.Tab(label="Tokenizer Details"): | |
| with gr.Column(): | |
| output_tokenizer_details_1 = gr.Dataframe(headers=["Attribute", "Value"]) | |
| output_tokenizer_details_2 = gr.Dataframe(headers=["Attribute", "Value"]) | |
| output_tokenizer_details_3 = gr.Dataframe(headers=["Attribute", "Value"]) | |
| with gr.Row(): | |
| clear_button = gr.ClearButton(components=[input_text]) | |
| submit_button = gr.Button("Run", variant="primary") | |
| all_inputs = [input_model_id, input_text] | |
| all_output = [ | |
| output_highlighted_text_1, | |
| output_highlighted_text_2, | |
| output_highlighted_text_3, | |
| output_special_tokens_1, | |
| output_special_tokens_2, | |
| output_special_tokens_3, | |
| output_tokenizer_details_1, | |
| output_tokenizer_details_2, | |
| output_tokenizer_details_3, | |
| ] | |
| submit_button.click(fn=inference, inputs=all_inputs, outputs=all_output) | |
| examples = gr.Examples( | |
| fn=inference, | |
| inputs=all_inputs, | |
| outputs=all_output, | |
| examples=[ | |
| ["black-forest-labs/FLUX.1-dev", "a photo of cat"], | |
| [ | |
| "stabilityai/stable-diffusion-3-medium-diffusers", | |
| 'cat holding sign saying "I am a cat"', | |
| ], | |
| ["rinna/japanese-stable-diffusion", "空を飛んでいるネコの写真 油絵"], | |
| ], | |
| cache_examples=True, | |
| ) | |
| demo.queue().launch() |