KangLiao committed on
Commit
748bbd2
·
1 Parent(s): 2d1f86e
app.py CHANGED
@@ -7,9 +7,6 @@ import math
7
  import re
8
  from einops import rearrange
9
  from mmengine.config import Config
10
- from xtuner.registry import BUILDER
11
- from xtuner.model.utils import guess_load_checkpoint
12
-
13
 
14
  import matplotlib
15
  matplotlib.use("Agg")
@@ -18,6 +15,10 @@ import matplotlib.pyplot as plt
18
  from scripts.camera.cam_dataset import Cam_Generator
19
  from scripts.camera.visualization.visualize_batch import make_perspective_figures
20
 
 
 
 
 
21
 
22
  NUM = r"[+-]?(?:\d+(?:\.\d+)?|\.\d+)(?:[eE][+-]?\d+)?"
23
  CAM_PATTERN = re.compile(r"(?:camera parameters.*?:|roll.*?:)\s*("+NUM+r")\s*,\s*("+NUM+r")\s*,\s*("+NUM+r")", re.IGNORECASE|re.DOTALL)
@@ -35,8 +36,8 @@ config = "configs/pipelines/stage_2_base.py"
35
  config = Config.fromfile(config)
36
  model = BUILDER.build(config.model).eval()
37
  checkpoint_path = "checkpoints/Puffin-Base.pth"
38
- state_dict = guess_load_checkpoint(checkpoint_path)
39
- model.load_state_dict(state_dict, strict=False)
40
 
41
  if torch.cuda.is_available():
42
  model = model.to(torch.bfloat16).cuda()
 
7
  import re
8
  from einops import rearrange
9
  from mmengine.config import Config
 
 
 
10
 
11
  import matplotlib
12
  matplotlib.use("Agg")
 
15
  from scripts.camera.cam_dataset import Cam_Generator
16
  from scripts.camera.visualization.visualize_batch import make_perspective_figures
17
 
18
+ from mmengine.registry import Registry
19
+ __all__ = ['BUILDER']
20
+ BUILDER = Registry('builder')
21
+
22
 
23
  NUM = r"[+-]?(?:\d+(?:\.\d+)?|\.\d+)(?:[eE][+-]?\d+)?"
24
  CAM_PATTERN = re.compile(r"(?:camera parameters.*?:|roll.*?:)\s*("+NUM+r")\s*,\s*("+NUM+r")\s*,\s*("+NUM+r")", re.IGNORECASE|re.DOTALL)
 
36
  config = Config.fromfile(config)
37
  model = BUILDER.build(config.model).eval()
38
  checkpoint_path = "checkpoints/Puffin-Base.pth"
39
+ checkpoint = torch.load(checkpoint_path)
40
+ info = model.load_state_dict(checkpoint, strict=False)
41
 
42
  if torch.cuda.is_available():
43
  model = model.to(torch.bfloat16).cuda()
configs/models/qwen2_5_1_5b_radio_sd3_dynamic_puffin.py CHANGED
@@ -29,13 +29,13 @@ model = dict(type=Qwen2p5RadioStableDiffusion3HFDynamic,
29
  hidden_size=1024,
30
  intermediate_size=4096,
31
  num_hidden_layers=6,
32
- _attn_implementation='flash_attention_2',
33
  num_attention_heads=16, ),
34
  connector_2=dict(
35
  hidden_size=1024,
36
  intermediate_size=4096,
37
  num_hidden_layers=6,
38
- _attn_implementation='flash_attention_2',
39
  num_attention_heads=16, ),
40
  transformer=dict(
41
  type=SD3Transformer2DModel.from_pretrained,
@@ -61,7 +61,7 @@ model = dict(type=Qwen2p5RadioStableDiffusion3HFDynamic,
61
  type=AutoModelForCausalLM.from_pretrained,
62
  pretrained_model_name_or_path=llm_name_or_path,
63
  torch_dtype=torch.bfloat16,
64
- attn_implementation='flash_attention_2',
65
  ),
66
  tokenizer=dict(
67
  type=AutoTokenizer.from_pretrained,
 
29
  hidden_size=1024,
30
  intermediate_size=4096,
31
  num_hidden_layers=6,
32
+ #_attn_implementation='flash_attention_2',
33
  num_attention_heads=16, ),
34
  connector_2=dict(
35
  hidden_size=1024,
36
  intermediate_size=4096,
37
  num_hidden_layers=6,
38
+ #_attn_implementation='flash_attention_2',
39
  num_attention_heads=16, ),
40
  transformer=dict(
41
  type=SD3Transformer2DModel.from_pretrained,
 
61
  type=AutoModelForCausalLM.from_pretrained,
62
  pretrained_model_name_or_path=llm_name_or_path,
63
  torch_dtype=torch.bfloat16,
64
+ #attn_implementation='flash_attention_2',
65
  ),
66
  tokenizer=dict(
67
  type=AutoTokenizer.from_pretrained,