Spaces: Sleeping

Commit: Fix device handling - properly support both CPU and CUDA
Browse files — changed: app.py (+19 −12), core/network.py (+4 −2)
app.py
CHANGED
|
@@ -102,10 +102,13 @@ def load_model():
|
|
| 102 |
model = Network(cfg, src_lang, tgt_lang)
|
| 103 |
|
| 104 |
# Load pretrained weights if available
|
|
|
|
|
|
|
|
|
|
| 105 |
if os.path.exists('./LM_MODEL.pth'):
|
| 106 |
try:
|
| 107 |
-
#
|
| 108 |
-
checkpoint = torch.load('./LM_MODEL.pth', map_location=…)  [NOTE: original argument truncated in page extraction]
|
| 109 |
if 'state_dict' in checkpoint:
|
| 110 |
state_dict = checkpoint['state_dict']
|
| 111 |
else:
|
|
@@ -122,11 +125,15 @@ def load_model():
|
|
| 122 |
print(f"Warning: Could not load full model weights: {e}")
|
| 123 |
print("Continuing with randomly initialized weights")
|
| 124 |
|
|
|
|
| 125 |
model.eval()
|
| 126 |
return model, src_lang, tgt_lang, cfg
|
| 127 |
|
| 128 |
# Process image and text
|
| 129 |
def process_input(image, text_input, model, src_lang, tgt_lang, cfg):
|
|
|
|
|
|
|
|
|
|
| 130 |
# Transform image
|
| 131 |
diagram_transform = T_diagram.Compose([
|
| 132 |
T_diagram.Resize(cfg.diagram_size),
|
|
@@ -135,7 +142,7 @@ def process_input(image, text_input, model, src_lang, tgt_lang, cfg):
|
|
| 135 |
T_diagram.Normalize()
|
| 136 |
])
|
| 137 |
|
| 138 |
-
diagram = diagram_transform(image).unsqueeze(0)
|
| 139 |
|
| 140 |
# Process text input
|
| 141 |
# Create a simple text structure
|
|
@@ -159,28 +166,28 @@ def process_input(image, text_input, model, src_lang, tgt_lang, cfg):
|
|
| 159 |
# Get text indices
|
| 160 |
text_token, text_sect_tag, text_class_tag = get_text_index(combine_text, src_lang)
|
| 161 |
|
| 162 |
-
# Convert to tensors
|
| 163 |
text_dict = {
|
| 164 |
-
'token': torch.LongTensor([text_token]),
|
| 165 |
-
'sect_tag': torch.LongTensor([text_sect_tag]),
|
| 166 |
-
'class_tag': torch.LongTensor([text_class_tag]),
|
| 167 |
-
'len': torch.LongTensor([len(text_token)])
|
| 168 |
}
|
| 169 |
|
| 170 |
# Get variables and arguments
|
| 171 |
var_arg_positions, var_values, arg_values = get_var_arg(combine_text, cfg)
|
| 172 |
|
| 173 |
var_dict = {
|
| 174 |
-
'pos': torch.LongTensor([var_arg_positions]),
|
| 175 |
-
'len': torch.LongTensor([len(var_arg_positions)]),
|
| 176 |
'var_value': var_values,
|
| 177 |
'arg_value': arg_values
|
| 178 |
}
|
| 179 |
|
| 180 |
# Create dummy expression dict for inference
|
| 181 |
exp_dict = {
|
| 182 |
-
'exp': torch.LongTensor([[1]]), # SOS token
|
| 183 |
-
'len': torch.LongTensor([1]),
|
| 184 |
'answer': 0
|
| 185 |
}
|
| 186 |
|
|
|
|
| 102 |
model = Network(cfg, src_lang, tgt_lang)
|
| 103 |
|
| 104 |
# Load pretrained weights if available
|
| 105 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 106 |
+
print(f"Using device: {device}")
|
| 107 |
+
|
| 108 |
if os.path.exists('./LM_MODEL.pth'):
|
| 109 |
try:
|
| 110 |
+
# Load with proper device mapping
|
| 111 |
+
checkpoint = torch.load('./LM_MODEL.pth', map_location=device)
|
| 112 |
if 'state_dict' in checkpoint:
|
| 113 |
state_dict = checkpoint['state_dict']
|
| 114 |
else:
|
|
|
|
| 125 |
print(f"Warning: Could not load full model weights: {e}")
|
| 126 |
print("Continuing with randomly initialized weights")
|
| 127 |
|
| 128 |
+
model = model.to(device)
|
| 129 |
model.eval()
|
| 130 |
return model, src_lang, tgt_lang, cfg
|
| 131 |
|
| 132 |
# Process image and text
|
| 133 |
def process_input(image, text_input, model, src_lang, tgt_lang, cfg):
|
| 134 |
+
# Get device
|
| 135 |
+
device = next(model.parameters()).device
|
| 136 |
+
|
| 137 |
# Transform image
|
| 138 |
diagram_transform = T_diagram.Compose([
|
| 139 |
T_diagram.Resize(cfg.diagram_size),
|
|
|
|
| 142 |
T_diagram.Normalize()
|
| 143 |
])
|
| 144 |
|
| 145 |
+
diagram = diagram_transform(image).unsqueeze(0).to(device)
|
| 146 |
|
| 147 |
# Process text input
|
| 148 |
# Create a simple text structure
|
|
|
|
| 166 |
# Get text indices
|
| 167 |
text_token, text_sect_tag, text_class_tag = get_text_index(combine_text, src_lang)
|
| 168 |
|
| 169 |
+
# Convert to tensors and move to device
|
| 170 |
text_dict = {
|
| 171 |
+
'token': torch.LongTensor([text_token]).to(device),
|
| 172 |
+
'sect_tag': torch.LongTensor([text_sect_tag]).to(device),
|
| 173 |
+
'class_tag': torch.LongTensor([text_class_tag]).to(device),
|
| 174 |
+
'len': torch.LongTensor([len(text_token)]).to(device)
|
| 175 |
}
|
| 176 |
|
| 177 |
# Get variables and arguments
|
| 178 |
var_arg_positions, var_values, arg_values = get_var_arg(combine_text, cfg)
|
| 179 |
|
| 180 |
var_dict = {
|
| 181 |
+
'pos': torch.LongTensor([var_arg_positions]).to(device),
|
| 182 |
+
'len': torch.LongTensor([len(var_arg_positions)]).to(device),
|
| 183 |
'var_value': var_values,
|
| 184 |
'arg_value': arg_values
|
| 185 |
}
|
| 186 |
|
| 187 |
# Create dummy expression dict for inference
|
| 188 |
exp_dict = {
|
| 189 |
+
'exp': torch.LongTensor([[1]]).to(device), # SOS token
|
| 190 |
+
'len': torch.LongTensor([1]).to(device),
|
| 191 |
'answer': 0
|
| 192 |
}
|
| 193 |
|
core/network.py
CHANGED
|
@@ -43,8 +43,9 @@ class MLMTransformerPretrain(nn.Module):
|
|
| 43 |
return transformer_outputs
|
| 44 |
|
| 45 |
def load_model(self, model_path):
|
|
|
|
| 46 |
pretrain_dict = torch.load(
|
| 47 |
-
model_path, map_location=…  [NOTE: original argument truncated in page extraction]
|
| 48 |
)
|
| 49 |
pretrain_dict_model = pretrain_dict['state_dict'] \
|
| 50 |
if 'state_dict' in pretrain_dict else pretrain_dict
|
|
@@ -156,8 +157,9 @@ class Network(nn.Module):
|
|
| 156 |
p.requires_grad = False
|
| 157 |
|
| 158 |
def load_model(self, model_path):
|
|
|
|
| 159 |
pretrain_dict = torch.load(
|
| 160 |
-
model_path, map_location=…  [NOTE: original argument truncated in page extraction]
|
| 161 |
)
|
| 162 |
pretrain_dict_model = pretrain_dict['state_dict'] \
|
| 163 |
if 'state_dict' in pretrain_dict else pretrain_dict
|
|
|
|
| 43 |
return transformer_outputs
|
| 44 |
|
| 45 |
def load_model(self, model_path):
|
| 46 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 47 |
pretrain_dict = torch.load(
|
| 48 |
+
model_path, map_location=device
|
| 49 |
)
|
| 50 |
pretrain_dict_model = pretrain_dict['state_dict'] \
|
| 51 |
if 'state_dict' in pretrain_dict else pretrain_dict
|
|
|
|
| 157 |
p.requires_grad = False
|
| 158 |
|
| 159 |
def load_model(self, model_path):
|
| 160 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 161 |
pretrain_dict = torch.load(
|
| 162 |
+
model_path, map_location=device
|
| 163 |
)
|
| 164 |
pretrain_dict_model = pretrain_dict['state_dict'] \
|
| 165 |
if 'state_dict' in pretrain_dict else pretrain_dict
|