|
|
import gradio as gr |
|
|
from PIL import Image |
|
|
import io |
|
|
import argparse |
|
|
import os |
|
|
import random |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.backends.cudnn as cudnn |
|
|
|
|
|
from omegaconf import OmegaConf |
|
|
from cheetah.common.config import Config |
|
|
from cheetah.common.registry import registry |
|
|
from cheetah.conversation.conversation_llama2 import Chat, CONV_VISION |
|
|
|
|
|
from cheetah.models import * |
|
|
from cheetah.processors import * |
|
|
|
|
|
|
|
|
print('Initializing Chat')

# Every component is placed on the first CUDA device.
_device = 'cuda:{}'.format(0)

# Load the evaluation YAML and derive the full model configuration from it.
config = OmegaConf.load('eval_configs/cheetah_eval_llama2.yaml')
cfg = Config.build_model_config(config)

# Resolve the registered model class by architecture name, build it from the
# config, and move it to the GPU.
model_cls = registry.get_model_class(cfg.model.arch)
model = model_cls.from_config(cfg.model).to(_device)

# Build the eval-time visual preprocessor from its registered class.
vis_processor_cfg = cfg.preprocess.vis_processor.eval
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)

# Chat bundles the model and the visual processor for multimodal Q&A.
chat = Chat(model, vis_processor, device=_device)

print('Initialization Finished')
|
|
|
|
|
|
|
|
|
|
|
def respond(imgs, question):
    """Answer a text question about one or more uploaded images.

    Args:
        imgs: list of raw image file contents (``bytes``), as delivered by
            ``gr.File(type='binary', file_count='multiple')``.
        question: the user's question text.

    Returns:
        A tuple ``(raw_img_list, llm_message)``: the paths the images were
        saved to (displayed in the gallery) and the model's answer string.
    """
    # Decode each upload, force 3-channel RGB, and persist it to disk,
    # because chat.answer consumes file paths rather than in-memory images.
    raw_img_list = []
    for i, img in enumerate(imgs):
        temp_img = Image.open(io.BytesIO(img)).convert(mode='RGB')
        save_path = "./img" + str(i) + ".jpg"
        temp_img.save(save_path)
        raw_img_list.append(save_path)

    # Prepend one placeholder tag per image; the conversation code replaces
    # each <HereForImage> with the corresponding image's features.
    # (Equivalent to the prepend-in-a-loop form, without rebuilding the
    # string once per image.)
    context = "<Img><HereForImage></Img> " * len(raw_img_list) + question

    print("Question: ", context)

    llm_message = chat.answer(raw_img_list, context)

    print("Answer: ", llm_message)

    return raw_img_list, llm_message
|
|
|
|
|
|
|
|
# Wire up the demo UI: multiple binary image uploads plus a text question in;
# a gallery of the saved images plus the model's answer out.
image_input = gr.File(file_count='multiple', file_types=['image'], type='binary')
iface = gr.Interface(
    fn=respond,
    title="Cheetah",
    inputs=[image_input, "text"],
    outputs=[gr.Gallery(), "text"],
)
iface.launch()