import gradio as gr
from PIL import Image
import io
import argparse
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from omegaconf import OmegaConf
from cheetah.common.config import Config
from cheetah.common.registry import registry
from cheetah.conversation.conversation_llama2 import Chat, CONV_VISION
from cheetah.models import *
from cheetah.processors import *
# --- Module-level initialization: build the Cheetah model and chat wrapper ---
# Runs once at import time; the resulting `chat` global is used by respond().
print('Initializing Chat')
# Load the eval YAML and derive the structured model config from it.
config = OmegaConf.load('eval_configs/cheetah_eval_llama2.yaml')
cfg = Config.build_model_config(config)
# Look up the model class by its registered architecture name and
# instantiate it on GPU 0 (device is hard-coded here).
model_cls = registry.get_model_class(cfg.model.arch)
model = model_cls.from_config(cfg.model).to('cuda:{}'.format(0))
# Build the eval-time visual preprocessor named in the config.
vis_processor_cfg = cfg.preprocess.vis_processor.eval
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
# Chat bundles model + processor; device must match the model's device above.
chat = Chat(model, vis_processor, device='cuda:{}'.format(0))
print('Initialization Finished')
# Respond to user
# Respond to user
def respond(imgs, question):
    """Answer `question` about the uploaded images via the Cheetah chat model.

    Args:
        imgs: list of raw image payloads (bytes) from the Gradio File input
            (``type='binary'``), or None when nothing was uploaded.
        question: the user's text question.

    Returns:
        A tuple ``(raw_img_list, llm_message)``: the saved image paths (shown
        in the Gallery output) and the model's textual answer.
    """
    # Paths of the images after re-encoding to RGB JPEG on disk;
    # chat.answer() consumes file paths, not raw bytes.
    raw_img_list = []
    for i, img in enumerate(imgs or []):
        # Decode the uploaded bytes and normalize to RGB (drops alpha/palette).
        temp_img = Image.open(io.BytesIO(img)).convert(mode='RGB')
        save_path = "./img" + str(i) + ".jpg"
        temp_img.save(save_path)
        raw_img_list.append(save_path)
    context = question
    # Prepend one image-placeholder token per uploaded image so the prompt
    # becomes: <Img><ImageHere></Img> ... <Img><ImageHere></Img> Question
    # NOTE(review): original placeholder text was lost to HTML stripping;
    # confirm this token matches what chat.answer()/CONV_VISION expects.
    for _ in raw_img_list:
        context = "<Img><ImageHere></Img> " + context
    print("Question: ", context)
    # Get response from Cheetah
    llm_message = chat.answer(raw_img_list, context)
    print("Answer: ", llm_message)
    return raw_img_list, llm_message
# Wire the Gradio UI: multiple binary image files + a text question in,
# a gallery of the saved images + the model's text answer out.
iface = gr.Interface(fn=respond, title="Cheetah", inputs=[gr.File(file_count='multiple',file_types=['image'],type='binary'), "text"], outputs=[gr.Gallery(),"text"])
# Start the web server (blocks until shut down).
iface.launch()