"""Gradio demo for the Cheetah multimodal chat model.

Loads the model and visual processor once at import time, then serves a
web UI that accepts multiple images plus a text question and returns the
model's answer alongside a gallery of the uploaded images.
"""
import argparse
import io
import os
import random

import gradio as gr
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from omegaconf import OmegaConf
from PIL import Image

from cheetah.common.config import Config
from cheetah.common.registry import registry
from cheetah.conversation.conversation_llama2 import Chat, CONV_VISION
from cheetah.models import *
from cheetah.processors import *

# ---------------------------------------------------------------------------
# One-time model initialization (runs at import).
# ---------------------------------------------------------------------------
print('Initializing Chat')
config = OmegaConf.load('eval_configs/cheetah_eval_llama2.yaml')
cfg = Config.build_model_config(config)

# Instantiate the model class named in the config and place it on GPU 0.
model_cls = registry.get_model_class(cfg.model.arch)
model = model_cls.from_config(cfg.model).to('cuda:{}'.format(0))

# Build the evaluation-time visual preprocessor from the same config.
vis_processor_cfg = cfg.preprocess.vis_processor.eval
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)

chat = Chat(model, vis_processor, device='cuda:{}'.format(0))
print('Initialization Finished')


def respond(imgs, question):
    """Answer *question* about the uploaded images.

    Args:
        imgs: list of raw image byte strings (from ``gr.File`` with
            ``type='binary'``).
        question: the user's text question.

    Returns:
        Tuple ``(raw_img_list, llm_message)``: the saved image file paths
        (rendered in the Gallery output) and the model's text answer.
    """
    raw_img_list = []
    for i, img in enumerate(imgs):
        # Decode the uploaded bytes and normalize to RGB in one step.
        temp_img = Image.open(io.BytesIO(img)).convert(mode='RGB')
        # Persist to disk because chat.answer consumes file paths.
        # NOTE(review): fixed names like ./img0.jpg collide between
        # concurrent requests — consider tempfile; kept for compatibility.
        save_path = "./img" + str(i) + ".jpg"
        temp_img.save(save_path)
        raw_img_list.append(save_path)

    # Prepend one leading space per uploaded image to the question.
    # NOTE(review): this looks like it was meant to prepend an image
    # placeholder token per image (e.g. "<Img><ImageHere></Img>") —
    # confirm against the prompt format that Chat.answer expects.
    context = question
    for _ in raw_img_list:
        context = " " + context
    print("Question: ", context)

    # Query the model with the saved image paths and the built prompt.
    llm_message = chat.answer(raw_img_list, context)
    print("Answer: ", llm_message)
    return raw_img_list, llm_message


iface = gr.Interface(
    fn=respond,
    title="Cheetah",
    inputs=[gr.File(file_count='multiple', file_types=['image'], type='binary'), "text"],
    outputs=[gr.Gallery(), "text"],
)

if __name__ == "__main__":
    iface.launch()