asdfasdfdsafdsa committed on
Commit
9658342
·
verified ·
1 Parent(s): 51e6305

Fix device handling - properly support both CPU and CUDA

Browse files
Files changed (2) hide show
  1. app.py +19 -12
  2. core/network.py +4 -2
app.py CHANGED
@@ -102,10 +102,13 @@ def load_model():
102
  model = Network(cfg, src_lang, tgt_lang)
103
 
104
  # Load pretrained weights if available
 
 
 
105
  if os.path.exists('./LM_MODEL.pth'):
106
  try:
107
- # Try loading with map_location for CPU compatibility
108
- checkpoint = torch.load('./LM_MODEL.pth', map_location='cpu')
109
  if 'state_dict' in checkpoint:
110
  state_dict = checkpoint['state_dict']
111
  else:
@@ -122,11 +125,15 @@ def load_model():
122
  print(f"Warning: Could not load full model weights: {e}")
123
  print("Continuing with randomly initialized weights")
124
 
 
125
  model.eval()
126
  return model, src_lang, tgt_lang, cfg
127
 
128
  # Process image and text
129
  def process_input(image, text_input, model, src_lang, tgt_lang, cfg):
 
 
 
130
  # Transform image
131
  diagram_transform = T_diagram.Compose([
132
  T_diagram.Resize(cfg.diagram_size),
@@ -135,7 +142,7 @@ def process_input(image, text_input, model, src_lang, tgt_lang, cfg):
135
  T_diagram.Normalize()
136
  ])
137
 
138
- diagram = diagram_transform(image).unsqueeze(0)
139
 
140
  # Process text input
141
  # Create a simple text structure
@@ -159,28 +166,28 @@ def process_input(image, text_input, model, src_lang, tgt_lang, cfg):
159
  # Get text indices
160
  text_token, text_sect_tag, text_class_tag = get_text_index(combine_text, src_lang)
161
 
162
- # Convert to tensors
163
  text_dict = {
164
- 'token': torch.LongTensor([text_token]),
165
- 'sect_tag': torch.LongTensor([text_sect_tag]),
166
- 'class_tag': torch.LongTensor([text_class_tag]),
167
- 'len': torch.LongTensor([len(text_token)])
168
  }
169
 
170
  # Get variables and arguments
171
  var_arg_positions, var_values, arg_values = get_var_arg(combine_text, cfg)
172
 
173
  var_dict = {
174
- 'pos': torch.LongTensor([var_arg_positions]),
175
- 'len': torch.LongTensor([len(var_arg_positions)]),
176
  'var_value': var_values,
177
  'arg_value': arg_values
178
  }
179
 
180
  # Create dummy expression dict for inference
181
  exp_dict = {
182
- 'exp': torch.LongTensor([[1]]), # SOS token
183
- 'len': torch.LongTensor([1]),
184
  'answer': 0
185
  }
186
 
 
102
  model = Network(cfg, src_lang, tgt_lang)
103
 
104
  # Load pretrained weights if available
105
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
106
+ print(f"Using device: {device}")
107
+
108
  if os.path.exists('./LM_MODEL.pth'):
109
  try:
110
+ # Load with proper device mapping
111
+ checkpoint = torch.load('./LM_MODEL.pth', map_location=device)
112
  if 'state_dict' in checkpoint:
113
  state_dict = checkpoint['state_dict']
114
  else:
 
125
  print(f"Warning: Could not load full model weights: {e}")
126
  print("Continuing with randomly initialized weights")
127
 
128
+ model = model.to(device)
129
  model.eval()
130
  return model, src_lang, tgt_lang, cfg
131
 
132
  # Process image and text
133
  def process_input(image, text_input, model, src_lang, tgt_lang, cfg):
134
+ # Get device
135
+ device = next(model.parameters()).device
136
+
137
  # Transform image
138
  diagram_transform = T_diagram.Compose([
139
  T_diagram.Resize(cfg.diagram_size),
 
142
  T_diagram.Normalize()
143
  ])
144
 
145
+ diagram = diagram_transform(image).unsqueeze(0).to(device)
146
 
147
  # Process text input
148
  # Create a simple text structure
 
166
  # Get text indices
167
  text_token, text_sect_tag, text_class_tag = get_text_index(combine_text, src_lang)
168
 
169
+ # Convert to tensors and move to device
170
  text_dict = {
171
+ 'token': torch.LongTensor([text_token]).to(device),
172
+ 'sect_tag': torch.LongTensor([text_sect_tag]).to(device),
173
+ 'class_tag': torch.LongTensor([text_class_tag]).to(device),
174
+ 'len': torch.LongTensor([len(text_token)]).to(device)
175
  }
176
 
177
  # Get variables and arguments
178
  var_arg_positions, var_values, arg_values = get_var_arg(combine_text, cfg)
179
 
180
  var_dict = {
181
+ 'pos': torch.LongTensor([var_arg_positions]).to(device),
182
+ 'len': torch.LongTensor([len(var_arg_positions)]).to(device),
183
  'var_value': var_values,
184
  'arg_value': arg_values
185
  }
186
 
187
  # Create dummy expression dict for inference
188
  exp_dict = {
189
+ 'exp': torch.LongTensor([[1]]).to(device), # SOS token
190
+ 'len': torch.LongTensor([1]).to(device),
191
  'answer': 0
192
  }
193
 
core/network.py CHANGED
@@ -43,8 +43,9 @@ class MLMTransformerPretrain(nn.Module):
43
  return transformer_outputs
44
 
45
  def load_model(self, model_path):
 
46
  pretrain_dict = torch.load(
47
- model_path, map_location="cpu"
48
  )
49
  pretrain_dict_model = pretrain_dict['state_dict'] \
50
  if 'state_dict' in pretrain_dict else pretrain_dict
@@ -156,8 +157,9 @@ class Network(nn.Module):
156
  p.requires_grad = False
157
 
158
  def load_model(self, model_path):
 
159
  pretrain_dict = torch.load(
160
- model_path, map_location="cpu"
161
  )
162
  pretrain_dict_model = pretrain_dict['state_dict'] \
163
  if 'state_dict' in pretrain_dict else pretrain_dict
 
43
  return transformer_outputs
44
 
45
  def load_model(self, model_path):
46
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
47
  pretrain_dict = torch.load(
48
+ model_path, map_location=device
49
  )
50
  pretrain_dict_model = pretrain_dict['state_dict'] \
51
  if 'state_dict' in pretrain_dict else pretrain_dict
 
157
  p.requires_grad = False
158
 
159
  def load_model(self, model_path):
160
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
161
  pretrain_dict = torch.load(
162
+ model_path, map_location=device
163
  )
164
  pretrain_dict_model = pretrain_dict['state_dict'] \
165
  if 'state_dict' in pretrain_dict else pretrain_dict