Spaces:
Sleeping
Sleeping
Fix missing config attributes and imports
Browse files- app.py +4 -0
- model/backbone/__init__.py +6 -4
- model/decoder/__init__.py +2 -1
- model/encoder/__init__.py +2 -1
app.py
CHANGED
|
@@ -56,6 +56,7 @@ class Config:
|
|
| 56 |
# Visual backbone
|
| 57 |
self.visual_backbone = "ResNet10"
|
| 58 |
self.diagram_size = 128
|
|
|
|
| 59 |
|
| 60 |
# Encoder
|
| 61 |
self.encoder_type = "gru"
|
|
@@ -80,6 +81,9 @@ class Config:
|
|
| 80 |
|
| 81 |
# Dataset
|
| 82 |
self.without_stru = False
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
# Initialize model
|
| 85 |
def load_model():
|
|
|
|
| 56 |
# Visual backbone
|
| 57 |
self.visual_backbone = "ResNet10"
|
| 58 |
self.diagram_size = 128
|
| 59 |
+
self.pretrain_vis_path = '' # Added missing attribute
|
| 60 |
|
| 61 |
# Encoder
|
| 62 |
self.encoder_type = "gru"
|
|
|
|
| 81 |
|
| 82 |
# Dataset
|
| 83 |
self.without_stru = False
|
| 84 |
+
|
| 85 |
+
# Logger (dummy for compatibility)
|
| 86 |
+
self.logger = type('obj', (object,), {'info': lambda x: print(x)})
|
| 87 |
|
| 88 |
# Initialize model
|
| 89 |
def load_model():
|
model/backbone/__init__.py
CHANGED
|
@@ -1,16 +1,18 @@
|
|
| 1 |
from .resnet import *
|
| 2 |
from .mobilenet_v2 import *
|
| 3 |
-
from config import visual_backbone_list
|
| 4 |
|
|
|
|
| 5 |
|
| 6 |
def get_visual_backbone(args):
|
| 7 |
if args.visual_backbone in visual_backbone_list:
|
| 8 |
model = eval(args.visual_backbone)()
|
| 9 |
-
if args.pretrain_vis_path !="":
|
| 10 |
model.load_model(pretrain=args.pretrain_vis_path)
|
| 11 |
-
|
|
|
|
| 12 |
else:
|
| 13 |
-
|
|
|
|
| 14 |
return model
|
| 15 |
else:
|
| 16 |
raise NotImplementedError("Unsupported Backbone: {}".format(args.visual_backbone))
|
|
|
|
| 1 |
from .resnet import *
|
| 2 |
from .mobilenet_v2 import *
|
|
|
|
| 3 |
|
| 4 |
+
visual_backbone_list = ['ResNet10', 'mobilenet_v2']
|
| 5 |
|
| 6 |
def get_visual_backbone(args):
|
| 7 |
if args.visual_backbone in visual_backbone_list:
|
| 8 |
model = eval(args.visual_backbone)()
|
| 9 |
+
if hasattr(args, 'pretrain_vis_path') and args.pretrain_vis_path != "":
|
| 10 |
model.load_model(pretrain=args.pretrain_vis_path)
|
| 11 |
+
if hasattr(args, 'logger'):
|
| 12 |
+
args.logger.info("Visual backbone has been loaded...")
|
| 13 |
else:
|
| 14 |
+
if hasattr(args, 'logger'):
|
| 15 |
+
args.logger.info("Visual backbone choose to train from scratch")
|
| 16 |
return model
|
| 17 |
else:
|
| 18 |
raise NotImplementedError("Unsupported Backbone: {}".format(args.visual_backbone))
|
model/decoder/__init__.py
CHANGED
|
@@ -1,9 +1,10 @@
|
|
| 1 |
# from .transformer import TransformerModel
|
| 2 |
-
from config import decoder_list
|
| 3 |
from .rnn_decoder import DecoderRNN
|
| 4 |
from .tree_decoder import TreeDecoder
|
| 5 |
from .transformer import TransformerDecoder
|
| 6 |
|
|
|
|
|
|
|
| 7 |
def get_decoder(params, *args):
|
| 8 |
|
| 9 |
if not params.decoder_type in decoder_list:
|
|
|
|
| 1 |
# from .transformer import TransformerModel
|
|
|
|
| 2 |
from .rnn_decoder import DecoderRNN
|
| 3 |
from .tree_decoder import TreeDecoder
|
| 4 |
from .transformer import TransformerDecoder
|
| 5 |
|
| 6 |
+
decoder_list = ["rnn_decoder", "tree_decoder", "transformer"]
|
| 7 |
+
|
| 8 |
def get_decoder(params, *args):
|
| 9 |
|
| 10 |
if not params.decoder_type in decoder_list:
|
model/encoder/__init__.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
| 1 |
from .lstm import LSTM
|
| 2 |
from .gru import GRU
|
| 3 |
-
from config import encoder_list
|
| 4 |
from .transformer import TransformerEncoder
|
| 5 |
|
|
|
|
|
|
|
| 6 |
def get_encoder(params, *args):
|
| 7 |
|
| 8 |
if not params.encoder_type in encoder_list:
|
|
|
|
| 1 |
from .lstm import LSTM
|
| 2 |
from .gru import GRU
|
|
|
|
| 3 |
from .transformer import TransformerEncoder
|
| 4 |
|
| 5 |
+
encoder_list = ['lstm', 'gru', 'transformer']
|
| 6 |
+
|
| 7 |
def get_encoder(params, *args):
|
| 8 |
|
| 9 |
if not params.encoder_type in encoder_list:
|