| | import base64 |
| | import re |
| | import time |
| | from functools import partial |
| | from io import BytesIO |
| |
|
| | import gradio as gr |
| | import torch |
| |
|
| | from extensions.multimodal.multimodal_embedder import MultimodalEmbedder |
| | from modules import shared |
| | from modules.logging_colors import logger |
| |
|
# Extension settings, passed through to MultimodalEmbedder (see ui()).
params = {
    # When True, embed every image in the prompt; when False, only the last
    # one (grounded in the checkbox label wired up in ui()).
    "add_all_images_to_prompt": False,
    # Device for the vision encoder — presumably a torch device string or
    # None for the default; verify against MultimodalEmbedder. TODO confirm
    "vision_device": None,
    # Load precision (bits) for the vision encoder — assumed 32/16/8 style
    # quantization choice; verify against MultimodalEmbedder. TODO confirm
    "vision_bits": 32,
    # Device for the multimodal projector; same convention as vision_device.
    "projector_device": None,
    # Load precision (bits) for the multimodal projector.
    "projector_bits": 32
}
| |
|
| |
|
| | |
# One-shot hijack of the chat input. When 'state' is True,
# chat_input_modifier() flips it back to False and calls 'value' (a callable
# installed by ui() via partial(add_chat_picture, picture)) on the pending
# (text, visible_text) pair.
input_hijack = {
    'state': False,
    'value': ["", ""]
}
| |
|
| |
|
| | |
# Initialized in ui(); remains None until the extension UI has been built.
multimodal_embedder: "MultimodalEmbedder | None" = None
| |
|
| |
|
def chat_input_modifier(text, visible_text, state):
    """Intercept the chat input once when an image hijack is pending.

    If the hijack flag is armed, disarm it and delegate to the stored
    callable (which splices the uploaded picture into the prompt);
    otherwise pass the input through unchanged.
    """
    if not input_hijack['state']:
        return text, visible_text

    # One-shot: disarm before delegating so the hijack fires only once.
    input_hijack['state'] = False
    return input_hijack['value'](text, visible_text)
| |
|
| |
|
| | def add_chat_picture(picture, text, visible_text): |
| | |
| | |
| | max_hw, min_hw = max(picture.size), min(picture.size) |
| | aspect_ratio = max_hw / min_hw |
| | shortest_edge = int(max(336 / aspect_ratio, 336)) |
| | longest_edge = int(shortest_edge * aspect_ratio) |
| | w = shortest_edge if picture.width < picture.height else longest_edge |
| | h = shortest_edge if picture.width >= picture.height else longest_edge |
| | picture = picture.resize((w, h)) |
| |
|
| | buffer = BytesIO() |
| | picture.save(buffer, format="PNG") |
| | img_str = base64.b64encode(buffer.getvalue()).decode('utf-8') |
| | image = f'<img src="data:image/jpeg;base64,{img_str}">' |
| |
|
| | if '<image>' in text: |
| | text = text.replace('<image>', image) |
| | else: |
| | text = image + '\n' + text |
| |
|
| | if visible_text == '' or visible_text is None: |
| | visible_text = text |
| | elif '<image>' in visible_text: |
| | visible_text = visible_text.replace('<image>', image) |
| | else: |
| | visible_text = visible_text + '\n' + image |
| |
|
| | return text, visible_text |
| |
|
| |
|
def custom_tokenized_length(prompt):
    """Return the token length of *prompt* as computed by the multimodal
    embedder (which accounts for embedded images)."""
    embedder = multimodal_embedder
    return embedder.len_in_tokens(prompt)
| |
|
| |
|
def tokenizer_modifier(state, prompt, input_ids, input_embeds):
    """Replace inline base64 image tags in *prompt* with multimodal
    embeddings.

    Returns ``(prompt, input_ids, input_embeds)``. When the prompt contains
    no image tag, the inputs are returned untouched; otherwise the embedder
    output is batched (unsqueeze) and moved onto the model's device/dtype.
    """
    started = time.time()

    # Fast path: no embedded image in the prompt.
    if re.search(r'<img src="data:image/jpeg;base64,[A-Za-z0-9+/=]+">', prompt) is None:
        return prompt, input_ids, input_embeds

    prompt, input_ids, input_embeds, total_embedded = multimodal_embedder.forward(prompt, state, params)
    logger.info(f'Embedded {total_embedded} image(s) in {time.time()-started:.2f}s')

    device = shared.model.device
    return (
        prompt,
        input_ids.unsqueeze(0).to(device, dtype=torch.int64),
        input_embeds.unsqueeze(0).to(device, dtype=shared.model.dtype),
    )
| |
|
| |
|
def ui():
    """Build the extension's Gradio controls and wire the image-upload
    hijack into the chat pipeline."""
    global multimodal_embedder
    # The embedder is constructed when the UI is built, not at import time.
    multimodal_embedder = MultimodalEmbedder(params)
    with gr.Column():
        picture_select = gr.Image(label='Send a picture', type='pil')
        # NOTE(review): despite the name, this checkbox enables embedding
        # *all* images (see its label); it drives 'add_all_images_to_prompt'.
        single_image_checkbox = gr.Checkbox(False, label='Embed all images, not only the last one')
    # Arm the one-shot input hijack with the uploaded picture...
    picture_select.upload(
        lambda picture: input_hijack.update({"state": True, "value": partial(add_chat_picture, picture)}),
        [picture_select],
        None
    )
    # ...and disarm it when the picture is cleared.
    picture_select.clear(lambda: input_hijack.update({"state": False, "value": ["", ""]}), None, None)
    single_image_checkbox.change(lambda x: params.update({"add_all_images_to_prompt": x}), single_image_checkbox, None)
    # Reset the picture widget after generate/submit — presumably so the
    # image is not re-sent with the next message. TODO confirm
    shared.gradio['Generate'].click(lambda: None, None, picture_select)
    shared.gradio['textbox'].submit(lambda: None, None, picture_select)
| |
|