asdfasdfdsafdsa commited on
Commit
6f74e93
·
verified ·
1 Parent(s): 93ac856

Fix missing config attributes and imports

Browse files
app.py CHANGED
@@ -56,6 +56,7 @@ class Config:
56
  # Visual backbone
57
  self.visual_backbone = "ResNet10"
58
  self.diagram_size = 128
 
59
 
60
  # Encoder
61
  self.encoder_type = "gru"
@@ -80,6 +81,9 @@ class Config:
80
 
81
  # Dataset
82
  self.without_stru = False
 
 
 
83
 
84
  # Initialize model
85
  def load_model():
 
56
  # Visual backbone
57
  self.visual_backbone = "ResNet10"
58
  self.diagram_size = 128
59
+ self.pretrain_vis_path = '' # Added missing attribute
60
 
61
  # Encoder
62
  self.encoder_type = "gru"
 
81
 
82
  # Dataset
83
  self.without_stru = False
84
+
85
+ # Logger (dummy for compatibility)
86
+ self.logger = type('obj', (object,), {'info': lambda x: print(x)})
87
 
88
  # Initialize model
89
  def load_model():
model/backbone/__init__.py CHANGED
@@ -1,16 +1,18 @@
1
  from .resnet import *
2
  from .mobilenet_v2 import *
3
- from config import visual_backbone_list
4
 
 
5
 
6
  def get_visual_backbone(args):
7
  if args.visual_backbone in visual_backbone_list:
8
  model = eval(args.visual_backbone)()
9
- if args.pretrain_vis_path !="":
10
  model.load_model(pretrain=args.pretrain_vis_path)
11
- args.logger.info("Visual backbone has been loaded...")
 
12
  else:
13
- args.logger.info("Visual backbone choose to train from scratch")
 
14
  return model
15
  else:
16
  raise NotImplementedError("Unsupported Backbone: {}".format(args.visual_backbone))
 
1
  from .resnet import *
2
  from .mobilenet_v2 import *
 
3
 
4
+ visual_backbone_list = ['ResNet10', 'mobilenet_v2']
5
 
6
  def get_visual_backbone(args):
7
  if args.visual_backbone in visual_backbone_list:
8
  model = eval(args.visual_backbone)()
9
+ if hasattr(args, 'pretrain_vis_path') and args.pretrain_vis_path != "":
10
  model.load_model(pretrain=args.pretrain_vis_path)
11
+ if hasattr(args, 'logger'):
12
+ args.logger.info("Visual backbone has been loaded...")
13
  else:
14
+ if hasattr(args, 'logger'):
15
+ args.logger.info("Visual backbone choose to train from scratch")
16
  return model
17
  else:
18
  raise NotImplementedError("Unsupported Backbone: {}".format(args.visual_backbone))
model/decoder/__init__.py CHANGED
@@ -1,9 +1,10 @@
1
  # from .transformer import TransformerModel
2
- from config import decoder_list
3
  from .rnn_decoder import DecoderRNN
4
  from .tree_decoder import TreeDecoder
5
  from .transformer import TransformerDecoder
6
 
 
 
7
  def get_decoder(params, *args):
8
 
9
  if not params.decoder_type in decoder_list:
 
1
  # from .transformer import TransformerModel
 
2
  from .rnn_decoder import DecoderRNN
3
  from .tree_decoder import TreeDecoder
4
  from .transformer import TransformerDecoder
5
 
6
+ decoder_list = ["rnn_decoder", "tree_decoder", "transformer"]
7
+
8
  def get_decoder(params, *args):
9
 
10
  if not params.decoder_type in decoder_list:
model/encoder/__init__.py CHANGED
@@ -1,8 +1,9 @@
1
  from .lstm import LSTM
2
  from .gru import GRU
3
- from config import encoder_list
4
  from .transformer import TransformerEncoder
5
 
 
 
6
  def get_encoder(params, *args):
7
 
8
  if not params.encoder_type in encoder_list:
 
1
  from .lstm import LSTM
2
  from .gru import GRU
 
3
  from .transformer import TransformerEncoder
4
 
5
+ encoder_list = ['lstm', 'gru', 'transformer']
6
+
7
  def get_encoder(params, *args):
8
 
9
  if not params.encoder_type in encoder_list: