| import base64 |
| import re |
| import time |
| from functools import partial |
| from io import BytesIO |
|
|
| import gradio as gr |
| import torch |
|
|
| from extensions.multimodal.multimodal_embedder import MultimodalEmbedder |
| from modules import shared |
| from modules.logging_colors import logger |
|
|
# Extension settings. This dict is handed to MultimodalEmbedder when the UI is
# built and again on every embedding call, and is mutated in place by the
# "Embed all images" checkbox.
params = dict(
    # Toggled by the 'Embed all images, not only the last one' checkbox.
    add_all_images_to_prompt=False,
    # Presumably the device for the vision model (None = embedder's default) — confirm in MultimodalEmbedder.
    vision_device=None,
    # Presumably the precision, in bits, used for the vision model — confirm in MultimodalEmbedder.
    vision_bits=32,
    # Presumably the device for the projector (None = embedder's default) — confirm in MultimodalEmbedder.
    projector_device=None,
    # Presumably the precision, in bits, used for the projector — confirm in MultimodalEmbedder.
    projector_bits=32,
)
|
|
|
|
| |
# One-shot hook for the next chat message. While 'state' is True,
# chat_input_modifier clears the flag and passes (text, visible_text) through
# the callable stored in 'value'. ui() arms it when a picture is uploaded and
# resets it to this idle shape when the picture is cleared.
input_hijack = dict(
    state=False,
    value=["", ""],
)
|
|
|
|
| |
# Shared embedder instance. It stays None until ui() constructs it from
# `params`, so the annotation is optional; it is quoted so that no extra
# typing import is needed.
multimodal_embedder: "MultimodalEmbedder | None" = None
|
|
|
|
def chat_input_modifier(text, visible_text, state):
    """Apply a pending input hijack to the outgoing chat message, if one is armed.

    Args:
        text: The message text.
        visible_text: The text displayed for the message.
        state: Unused here; accepted as part of the hook signature.

    Returns:
        ``(text, visible_text)`` unchanged when no hijack is pending;
        otherwise whatever the stored hijack callable returns for them.
    """
    if not input_hijack['state']:
        return text, visible_text

    # The hijack is one-shot: disarm it before running the stored callable.
    input_hijack['state'] = False
    hijack = input_hijack['value']
    return hijack(text, visible_text)
|
|
|
|
# Image modes Pillow's JPEG writer accepts as-is. Any other mode (for example
# RGBA, LA or P) makes `save(..., format="JPEG")` raise OSError, so such
# pictures are converted to RGB before encoding.
_JPEG_WRITABLE_MODES = frozenset({"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"})


def add_chat_picture(picture, text, visible_text):
    """Embed a resized JPEG copy of ``picture`` into the chat message.

    The picture is encoded as a base64 ``<img>`` data-URI tag. In both
    ``text`` and ``visible_text`` the tag replaces every ``<image>``
    placeholder, or is appended on a new line when there is no placeholder.

    Args:
        picture: A PIL image (needs ``size``, ``width``, ``height``, ``mode``,
            ``convert``, ``resize`` and ``save``).
        text: The message text.
        visible_text: The text displayed for the message; when empty or
            ``None`` it is set to the final ``text``.

    Returns:
        A ``(text, visible_text)`` tuple with the image tag embedded.
    """
    # JPEG cannot store alpha or palette data; normalise those modes up front
    # so the save below does not fail on e.g. transparent PNG uploads.
    if picture.mode not in _JPEG_WRITABLE_MODES:
        picture = picture.convert("RGB")

    # Scale so the shortest edge lands between 224 and 300 px while keeping
    # the aspect ratio.
    long_side, short_side = max(picture.size), min(picture.size)
    aspect_ratio = long_side / short_side
    shortest_edge = int(max(300 / aspect_ratio, 224))
    longest_edge = int(shortest_edge * aspect_ratio)
    if picture.width < picture.height:
        new_size = (shortest_edge, longest_edge)
    else:
        new_size = (longest_edge, shortest_edge)
    picture = picture.resize(new_size)

    # Encode as JPEG and wrap it in an inline data-URI <img> tag.
    buffer = BytesIO()
    picture.save(buffer, format="JPEG")
    img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
    image = f'<img src="data:image/jpeg;base64,{img_str}">'

    if '<image>' in text:
        text = text.replace('<image>', image)
    else:
        text = text + '\n' + image

    if visible_text == '' or visible_text is None:
        # Nothing separate to display: show the same text, image included.
        visible_text = text
    elif '<image>' in visible_text:
        visible_text = visible_text.replace('<image>', image)
    else:
        visible_text = visible_text + '\n' + image

    return text, visible_text
|
|
|
|
def custom_tokenized_length(prompt):
    """Return the length of ``prompt`` in tokens, as reported by the shared embedder.

    Delegates to ``multimodal_embedder.len_in_tokens``; the embedder is
    created in ``ui()``.
    """
    embedder = multimodal_embedder
    return embedder.len_in_tokens(prompt)
|
|
|
|
def tokenizer_modifier(state, prompt, input_ids, input_embeds):
    """Replace inline images in ``prompt`` with embeddings when any are present.

    If the prompt contains no base64 JPEG ``<img>`` data-URI tag, the three
    inputs are returned untouched. Otherwise the prompt is passed to
    ``multimodal_embedder.forward`` together with ``state`` and ``params``,
    and the elapsed time is logged.

    Returns:
        ``(prompt, input_ids, input_embeds)``. On the embedding path the ids
        and embeddings get a new leading dimension and are moved to
        ``shared.model.device`` (ids as int64, embeddings in the model dtype).
    """
    start_ts = time.time()
    has_image = re.search(r'<img src="data:image/jpeg;base64,[A-Za-z0-9+/=]+">', prompt) is not None
    if not has_image:
        return prompt, input_ids, input_embeds

    prompt, input_ids, input_embeds, total_embedded = multimodal_embedder.forward(prompt, state, params)
    logger.info(f'Embedded {total_embedded} image(s) in {time.time()-start_ts:.2f}s')

    model = shared.model
    batched_ids = input_ids.unsqueeze(0).to(model.device, dtype=torch.int64)
    batched_embeds = input_embeds.unsqueeze(0).to(model.device, dtype=model.dtype)
    return prompt, batched_ids, batched_embeds
|
|
|
|
def ui():
    """Create the shared embedder and build the extension's Gradio controls.

    Adds an image input and an "embed all images" checkbox, then wires them:
    uploading a picture arms ``input_hijack`` with ``add_chat_picture`` bound
    to that picture, clearing the picture resets the hijack, the checkbox
    updates ``params``, and the image input is emptied whenever the Generate
    button is clicked or the textbox is submitted.
    """
    global multimodal_embedder
    # Built here rather than at import time, from the current `params`.
    multimodal_embedder = MultimodalEmbedder(params)

    with gr.Column():
        image_input = gr.Image(label='Send a picture', type='pil')
        embed_all_checkbox = gr.Checkbox(value=False, label='Embed all images, not only the last one')

    # Arm the one-shot hijack so the next message carries the uploaded picture.
    image_input.upload(
        fn=lambda picture: input_hijack.update(state=True, value=partial(add_chat_picture, picture)),
        inputs=[image_input],
        outputs=None,
    )
    # Removing the picture puts the hijack back into its idle shape.
    image_input.clear(
        fn=lambda: input_hijack.update(state=False, value=["", ""]),
        inputs=None,
        outputs=None,
    )
    embed_all_checkbox.change(
        fn=lambda checked: params.update(add_all_images_to_prompt=checked),
        inputs=embed_all_checkbox,
        outputs=None,
    )

    # Returning None into the image component empties it once a message is sent.
    shared.gradio['Generate'].click(fn=lambda: None, inputs=None, outputs=image_input)
    shared.gradio['textbox'].submit(fn=lambda: None, inputs=None, outputs=image_input)
|
|