tonyassi committed on
Commit
9b4c813
·
1 Parent(s): 9feffc3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -28
app.py CHANGED
@@ -17,33 +17,8 @@ from cheetah.conversation.conversation_llama2 import Chat, CONV_VISION
17
  from cheetah.models import *
18
  from cheetah.processors import *
19
 
20
- '''
21
- def parse_args():
22
- parser = argparse.ArgumentParser(description="Demo")
23
- parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
24
- parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
25
- parser.add_argument(
26
- "--options",
27
- nargs="+",
28
- help="override some settings in the used config, the key-value pair "
29
- "in xxx=yyy format will be merged into config file (deprecate), "
30
- "change to --cfg-options instead.",
31
- )
32
- args = parser.parse_args()
33
- return args
34
-
35
- def setup_seeds(seed = 50):
36
- random.seed(seed)
37
- np.random.seed(seed)
38
- torch.manual_seed(seed)
39
- torch.cuda.manual_seed_all(seed)
40
- cudnn.benchmark = False
41
- cudnn.deterministic = True
42
- '''
43
-
44
  print('Initializing Chat')
45
- #args = parse_args()
46
- #args = Namespace(cfg_path='eval_configs/cheetah_eval_vicuna.yaml', gpu_id=0, options=None)
47
 
48
  config = OmegaConf.load('eval_configs/cheetah_eval_llama2.yaml')
49
  cfg = Config.build_model_config(config)
@@ -56,28 +31,42 @@ chat = Chat(model, vis_processor, device='cuda:{}'.format(0))
56
  print('Initialization Finished')
57
 
58
 
 
59
  def respond(imgs, question):
60
 
 
61
  raw_img_list = []
62
-
 
63
  for i, img in enumerate(imgs):
 
64
  temp_img = Image.open(io.BytesIO(img))
 
 
65
  temp_img = temp_img.convert(mode='RGB')
 
 
66
  save_path = "./img" + str(i) + ".jpg"
67
  temp_img.save(save_path)
 
 
68
  raw_img_list.append(save_path)
69
 
70
  context = question
71
 
 
72
  for i in raw_img_list:
73
  context = "<Img><HereForImage></Img> " + context
74
 
75
  print("Question: ", context)
 
 
76
  llm_message = chat.answer(raw_img_list, context)
 
77
  print("Answer: ", llm_message)
78
 
79
  return raw_img_list,llm_message
80
 
81
- #iface = gr.Interface(fn=respond, title="Cheetah", inputs=[gr.Image(type="pil"), "text"], outputs="text")
82
  iface = gr.Interface(fn=respond, title="Cheetah", inputs=[gr.File(file_count='multiple',file_types=['image'],type='binary'), "text"], outputs=[gr.Gallery(),"text"])
83
  iface.launch()
 
17
  from cheetah.models import *
18
  from cheetah.processors import *
19
 
20
+ # Initialize Cheetah
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  print('Initializing Chat')
 
 
22
 
23
  config = OmegaConf.load('eval_configs/cheetah_eval_llama2.yaml')
24
  cfg = Config.build_model_config(config)
 
31
  print('Initialization Finished')
32
 
33
 
34
+ # Respond to user
35
  def respond(imgs, question):
36
 
37
+ # List of image paths
38
  raw_img_list = []
39
+
40
+ # Go through each image
41
  for i, img in enumerate(imgs):
42
+ # Open image
43
  temp_img = Image.open(io.BytesIO(img))
44
+
45
+ # Convert to RGB
46
  temp_img = temp_img.convert(mode='RGB')
47
+
48
+ # Save image
49
  save_path = "./img" + str(i) + ".jpg"
50
  temp_img.save(save_path)
51
+
52
+ # Add to image path list
53
  raw_img_list.append(save_path)
54
 
55
  context = question
56
 
57
+ # Format: <Img><HereForImage></Img> <Img><HereForImage></Img> Question
58
  for i in raw_img_list:
59
  context = "<Img><HereForImage></Img> " + context
60
 
61
  print("Question: ", context)
62
+
63
+ # Get response from Cheetah
64
  llm_message = chat.answer(raw_img_list, context)
65
+
66
  print("Answer: ", llm_message)
67
 
68
  return raw_img_list,llm_message
69
 
70
+
71
# Gradio demo: multiple image files + a question string in,
# a gallery of the saved images + Cheetah's text answer out.
iface = gr.Interface(
    fn=respond,
    title="Cheetah",
    inputs=[gr.File(file_count='multiple', file_types=['image'], type='binary'), "text"],
    outputs=[gr.Gallery(), "text"],
)
iface.launch()