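# Gradio demo for the weblab-GENIAC/Tanuki-8B-vision multimodal model
# (LLaVA-JP architecture): loads the language model 8-bit quantized and
# serves an image + prompt inference UI.
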
import gradio as gr
import torch
import transformers
from transformers import BitsAndBytesConfig

from llavajp.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llavajp.conversation import conv_templates
from llavajp.model.llava_llama import LlavaLlamaForCausalLM
from llavajp.train.dataset import tokenizer_image_token

import spaces  # Hugging Face Spaces GPU helper (ZeroGPU)

# import subprocess
# import sys

# def install_package():
#     subprocess.check_call([sys.executable, "-m", "pip", "install", "git+https://github.com/hibikaze-git/LLaVA-JP@feature/tanuki-moe"])

model_path = "weblab-GENIAC/Tanuki-8B-vision"

# load model
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32
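
# Quantize the language model weights to 8-bit with bitsandbytes; the
# multimodal projector and vision tower are listed in llm_int8_skip_modules
# so they stay in full precision.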
bnb_model_from_pretrained_args = {}
bnb_model_from_pretrained_args.update(dict(
    device_map="auto",
    quantization_config=BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_skip_modules=["mm_projector", "vision_tower"],
        llm_int8_threshold=6.0,
        llm_int8_has_fp16_weight=False,
    ),
))

model = LlavaLlamaForCausalLM.from_pretrained(
    model_path,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    **bnb_model_from_pretrained_args,
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_path,
    model_max_length=8192,
    padding_side="right",
    use_fast=False,
)
model.eval()

conv_mode = "v1"
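

# Run a single image + text inference: resize the input image to the vision
# tower's expected resolution, build the conversation prompt, generate, and
# return only the model's reply.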
def inference_fn(
    image,
    prompt,
    max_len,
    temperature,
    top_p,
    no_repeat_ngram_size,
):
    # prepare inputs
    # image pre-process
    image_size = model.get_model().vision_tower.image_processor.size["height"]
    if model.get_model().vision_tower.scales is not None:
        image_size = model.get_model().vision_tower.image_processor.size[
            "height"
        ] * len(model.get_model().vision_tower.scales)
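
    # With a multi-scale vision tower, image_size is the base resolution times
    # the number of scales, so the processor below receives the enlarged size.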
    image_tensor = (
        model.get_model()
        .vision_tower.image_processor(
            image,
            return_tensors="pt",
            size={"height": image_size, "width": image_size},
        )["pixel_values"]
        .to(torch_dtype)
    )
    if device == "cuda":
        image_tensor = image_tensor.cuda()

    # create prompt
    inp = DEFAULT_IMAGE_TOKEN + "\n" + prompt
    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], inp)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
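
    # tokenizer_image_token tokenizes the text around DEFAULT_IMAGE_TOKEN and
    # splices the sentinel IMAGE_TOKEN_INDEX in its place, marking where the
    # model should inject the image features.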
    input_ids = tokenizer_image_token(
        prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
    ).unsqueeze(0)

    if device == "cuda":
        input_ids = input_ids.to(device)

    input_ids = input_ids[:, :-1]  # drop the trailing </sep> the template appends to the input

    # generate
    output_ids = model.generate(
        inputs=input_ids,
        images=image_tensor,
        do_sample=temperature != 0.0,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_len,
        repetition_penalty=1.0,
        use_cache=False,
        no_repeat_ngram_size=no_repeat_ngram_size,
    )
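
    # The returned sequence still contains the IMAGE_TOKEN_INDEX sentinel,
    # which is not a real vocabulary id, so filter it out before decoding.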
    output_ids = [
        token_id for token_id in output_ids.tolist()[0] if token_id != IMAGE_TOKEN_INDEX
    ]
    print(output_ids)

    output = tokenizer.decode(output_ids, skip_special_tokens=True)
    print(output)

    # keep only the text after the assistant marker ("システム: " = "system: ")
    target = "システム: "
    idx = output.find(target)
    if idx != -1:
        output = output[idx + len(target):]
    return output
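

# Minimal local usage sketch (hypothetical: assumes a sample image on disk
# and enough GPU memory for the 8-bit model):
#
#     from PIL import Image
#     reply = inference_fn(Image.open("sample.jpg"), "猫の隣には何がありますか?",
#                          128, 0.0, 1.0, 3)
#     print(reply)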

with gr.Blocks() as demo:
    gr.Markdown("# Tanuki-8B-vision Demo")
    with gr.Row():
        with gr.Column():
            # input_instruction = gr.TextArea(label="instruction", value=DEFAULT_INSTRUCTION)
            input_image = gr.Image(type="pil", label="image")
            prompt = gr.Textbox(label="prompt (optional)", value="")
            with gr.Accordion(label="Configs", open=False):
                max_len = gr.Slider(
                    minimum=10,
                    maximum=256,
                    value=200,
                    step=5,
                    interactive=True,
                    label="Max New Tokens",
                )
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.0,
                    step=0.1,
                    interactive=True,
                    label="Temperature",
                )
                top_p = gr.Slider(
                    minimum=0.5,
                    maximum=1.0,
                    value=1.0,
                    step=0.1,
                    interactive=True,
                    label="Top p",
                )
                no_repeat_ngram_size = gr.Slider(
                    minimum=0,
                    maximum=4,
                    value=3,
                    step=1,
                    interactive=True,
                    label="No Repeat Ngram Size (1 or 2 garbles the output)",
                )
            # button
            input_button = gr.Button(value="Submit")
        with gr.Column():
            output = gr.Textbox(label="Output")

    inputs = [input_image, prompt, max_len, temperature, top_p, no_repeat_ngram_size]
    input_button.click(inference_fn, inputs=inputs, outputs=[output])
    prompt.submit(inference_fn, inputs=inputs, outputs=[output])
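
    # Pre-filled example: a remote sample image and the Japanese prompt
    # "猫の隣には何がありますか?" ("What is next to the cat?").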
    img2txt_examples = gr.Examples(
        examples=[
            [
                "https://raw.githubusercontent.com/hibikaze-git/LLaVA-JP/feature/package/imgs/sample1.jpg",
                "猫の隣には何がありますか?",
                128,
                0.0,
                1.0,
                3,
            ],
        ],
        inputs=inputs,
    )

if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0")