File size: 1,983 Bytes
dfe8075
57c3c27
2d1dca7
229fe65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b4c813
229fe65
 
 
 
 
 
 
 
 
 
 
 
dfe8075
9b4c813
2d1dca7
dfe8075
9b4c813
2d1dca7
9b4c813
 
2d1dca7
9b4c813
2d1dca7
9b4c813
 
5d8219d
9b4c813
 
2d1dca7
 
9b4c813
 
2d1dca7
7b7a1c6
 
 
9b4c813
7b7a1c6
 
 
 
9b4c813
 
7b7a1c6
9b4c813
7b7a1c6
 
e33a302
6fd7103
9b4c813
e33a302
6fd7103
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import gradio as gr
from PIL import Image 
import io
import argparse
import os
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn

from omegaconf import OmegaConf
from cheetah.common.config import Config
from cheetah.common.registry import registry
from cheetah.conversation.conversation_llama2 import Chat, CONV_VISION

from cheetah.models import *
from cheetah.processors import *

# Initialize Cheetah
# One-time, import-time setup: load the eval config, build the model and its
# visual pre-processor, and wrap both in a Chat session used by respond().
print('Initializing Chat')

# Single source of truth for device placement (was duplicated as
# 'cuda:{}'.format(0) in two places). Change here to relocate the model.
DEVICE = 'cuda:0'

config = OmegaConf.load('eval_configs/cheetah_eval_llama2.yaml')
cfg = Config.build_model_config(config)

# Look up the model class registered under the configured architecture name
# and instantiate it on the GPU.
model_cls = registry.get_model_class(cfg.model.arch)
model = model_cls.from_config(cfg.model).to(DEVICE)

# Build the evaluation-time image pre-processor from its config entry.
vis_processor_cfg = cfg.preprocess.vis_processor.eval
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
chat = Chat(model, vis_processor, device=DEVICE)
print('Initialization Finished')


# Respond to user
def respond(imgs, question):
    """Gradio callback: save uploaded images to disk, query Cheetah, and
    return the saved image paths plus the model's answer.

    Args:
        imgs: list of raw image byte strings (from gr.File with type='binary').
        question: the user's text question.

    Returns:
        Tuple of (list of saved image file paths, answer string).
    """
    # Paths of the images written to disk; also fed to the output gallery.
    raw_img_list = []

    for i, img_bytes in enumerate(imgs):
        # Decode the upload and normalize to RGB so the JPEG save below
        # always succeeds (e.g. PNGs with an alpha channel would fail).
        image = Image.open(io.BytesIO(img_bytes)).convert(mode='RGB')

        # NOTE(review): fixed names in the CWD are overwritten on every
        # request — concurrent gradio sessions will clobber each other's
        # images. Consider per-request unique names; verify downstream use.
        save_path = f"./img{i}.jpg"
        image.save(save_path)
        raw_img_list.append(save_path)

    # Prepend one image placeholder per uploaded image, e.g.:
    # <Img><HereForImage></Img> <Img><HereForImage></Img> Question
    context = "<Img><HereForImage></Img> " * len(raw_img_list) + question

    print("Question: ", context)

    # Get response from Cheetah
    llm_message = chat.answer(raw_img_list, context)

    print("Answer: ", llm_message)

    return raw_img_list, llm_message


# Wire the callback into a simple web UI: a multi-file image upload plus a
# text box go in; an image gallery and a text answer come out.
multi_image_upload = gr.File(file_count='multiple', file_types=['image'], type='binary')
iface = gr.Interface(
    fn=respond,
    title="Cheetah",
    inputs=[multi_image_upload, "text"],
    outputs=[gr.Gallery(), "text"],
)
iface.launch()