import gradio as gr
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path
from llava.model.builder import load_pretrained_model

# Quilt-LLaVA: a LLaVA-v1.5-7B checkpoint fine-tuned for histopathology images.
model_path = "wisdomik/Quilt-Llava-v1.5-7b"

# Loads the tokenizer, model weights, vision/image processor, and the
# model's maximum context length in one call.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path,
    model_base=None,
    model_name=get_model_name_from_path(model_path),
)
def predict(image, prompt, history):
    """Run one inference turn of Quilt-LLaVA.

    Args:
        image: PIL image for this turn, or None for a text-only question.
        prompt: the user's question as plain text.
        history: list of (user, assistant) tuples; appended to in place.

    Returns:
        (history, "") — the updated history and an empty string (to clear
        the input textbox in the UI).
    """
    if image is not None:
        # Prepend the image placeholder so tokenizer_image_token can splice
        # the visual embedding position into the token stream.
        prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt

    # Build the llava_v1 conversation. The template supplies the
    # "ASSISTANT:" turn marker itself, so we must not append one manually
    # (the original did both, double-formatting the prompt).
    conv = conv_templates["llava_v1"].copy()
    conv.append_message(conv.roles[0], prompt)
    conv.append_message(conv.roles[1], None)
    full_prompt = conv.get_prompt()

    input_ids = tokenizer_image_token(
        full_prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
    ).unsqueeze(0).cuda()

    # Encode the image; without passing images= to generate(), the <image>
    # token has no pixel features to attend to and multimodal turns fail.
    image_tensor = None
    if image is not None:
        image_tensor = image_processor.preprocess(
            image, return_tensors='pt'
        )['pixel_values'].half().cuda()

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            max_new_tokens=512,
            do_sample=False,
        )

    # NOTE(review): slicing off the prompt assumes generate() echoes the
    # input ids; some llava versions return only new tokens — confirm.
    response = tokenizer.decode(
        output_ids[0][input_ids.shape[1]:], skip_special_tokens=True
    ).strip()

    # Store the user-visible prompt, not the full templated conversation.
    history.append((prompt, response))
    return history, ""
def _chat_fn(message, history):
    """Adapt gradio's multimodal ChatInterface payload to predict().

    With multimodal=True, gradio calls fn(message, history) where message is
    a dict {"text": str, "files": [paths]}; predict() instead expects
    (image, prompt, history) and returns (history, ""). Bridge the two here.
    """
    files = message.get("files") or []
    image = Image.open(files[0]) if files else None
    updated_history, _ = predict(image, message.get("text", ""), list(history or []))
    # ChatInterface expects the assistant's reply string, not the history.
    return updated_history[-1][1] if updated_history else ""


iface = gr.ChatInterface(_chat_fn, multimodal=True)
iface.launch()