tonyassi commited on
Commit
229fe65
·
1 Parent(s): f47b968

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -0
app.py CHANGED
@@ -1,5 +1,57 @@
1
  import gradio as gr
2
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  def chat(img, question):
5
  img.save("img.jpg")
 
1
  import gradio as gr
2
  from PIL import Image
3
+ import argparse
4
+ import os
5
+ import random
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.backends.cudnn as cudnn
10
+
11
+ from omegaconf import OmegaConf
12
+ from cheetah.common.config import Config
13
+ from cheetah.common.registry import registry
14
+ from cheetah.conversation.conversation_llama2 import Chat, CONV_VISION
15
+
16
+ from cheetah.models import *
17
+ from cheetah.processors import *
18
+
19
+ def parse_args():
20
+ parser = argparse.ArgumentParser(description="Demo")
21
+ parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
22
+ parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
23
+ parser.add_argument(
24
+ "--options",
25
+ nargs="+",
26
+ help="override some settings in the used config, the key-value pair "
27
+ "in xxx=yyy format will be merged into config file (deprecate), "
28
+ "change to --cfg-options instead.",
29
+ )
30
+ args = parser.parse_args()
31
+ return args
32
+
33
+ def setup_seeds(seed = 50):
34
+ random.seed(seed)
35
+ np.random.seed(seed)
36
+ torch.manual_seed(seed)
37
+ torch.cuda.manual_seed_all(seed)
38
+ cudnn.benchmark = False
39
+ cudnn.deterministic = True
40
+
41
+ print('Initializing Chat')
42
+ #args = parse_args()
43
+ #args = Namespace(cfg_path='eval_configs/cheetah_eval_vicuna.yaml', gpu_id=0, options=None)
44
+
45
+ config = OmegaConf.load('eval_configs/cheetah_eval_llama2.yaml')
46
+ cfg = Config.build_model_config(config)
47
+ model_cls = registry.get_model_class(cfg.model.arch)
48
+ model = model_cls.from_config(cfg.model).to('cuda:{}'.format(0))
49
+
50
+ vis_processor_cfg = cfg.preprocess.vis_processor.eval
51
+ vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
52
+ chat = Chat(model, vis_processor, device='cuda:{}'.format(0))
53
+ print('Initialization Finished')
54
+
55
 
56
  def chat(img, question):
57
  img.save("img.jpg")