Honzus24 committed on
Commit
6886821
·
verified ·
1 Parent(s): c707c1f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -48
app.py CHANGED
@@ -92,14 +92,13 @@ def process_pdb_file(pdb_file, backbones, sequences, names):
92
  names.append(_name)
93
  return backbones, sequences, names
94
 
95
- @spaces.GPU
96
- def flex_seq(input_seq, input_file):
97
  if not input_seq:
98
  input_seq = ""
99
-
100
  if not input_seq.strip() and not input_file:
101
- return None, "Provide a file/s or a input sequence/s"
102
-
103
  if input_file:
104
  if len(input_file) == 1:
105
  input_file = input_file[0]
@@ -109,13 +108,7 @@ def flex_seq(input_seq, input_file):
109
 
110
  default_name = '{}'.format(datetime.now().strftime('%Y%m%d_%H%M%S'))
111
  output_name = default_name
112
-
113
- sequences = []
114
- names = []
115
- backbones = []
116
- flucts_list = []
117
- pdb_files = []
118
-
119
  datapoint_for_eval = 'all'
120
 
121
  if input_seq:
@@ -129,12 +122,10 @@ def flex_seq(input_seq, input_file):
129
  sequence = proteins[record+1]
130
  else:
131
  raise ValueError("You must adhere to the .fasta format")
132
-
133
  if datapoint_for_eval == 'all':
134
  names.append(name)
135
  sequences.append(sequence)
136
  backbones.append(None)
137
-
138
  elif suffix == ".fasta":
139
  for record in SeqIO.parse(input_file, "fasta"):
140
  name = record.name
@@ -142,59 +133,56 @@ def flex_seq(input_seq, input_file):
142
  names.append(name)
143
  sequences.append(str(record.seq))
144
  backbones.append(None)
145
-
146
  elif suffix == ".pdb":
147
  backbones, sequences, names = process_pdb_file(input_file, backbones, sequences, names)
148
  pdb_files.append(input_file)
149
-
150
  elif suffix == ".pdb_list":
151
  for i in input_file:
152
  backbones, sequences, names = process_pdb_file(i, backbones, sequences, names)
153
  pdb_files.append(i)
154
-
155
  env_config = yaml.load(open('configs/env_config.yaml', 'r'), Loader=yaml.FullLoader)
156
- # Set folder for huggingface cache
157
  os.environ['HF_HOME'] = env_config['huggingface']['HF_HOME']
158
- # Set gpu device
159
- os.environ["CUDA_VISIBLE_DEVICES"]= env_config['gpus']['cuda_visible_device']
160
-
161
  config = yaml.load(open('configs/train_config.yaml', 'r'), Loader=yaml.FullLoader)
162
- class_config=ClassConfig(config)
163
  class_config.adaptor_architecture = 'no-adaptor'
164
- config['inference_args']['device'] = config['inference_args']['device'] if torch.cuda.is_available() else 'cpu'
 
 
 
 
 
 
 
165
  model, tokenizer = PT5_classification_model(half_precision=config['mixed_precision'], class_config=class_config)
166
- model.to(config['inference_args']['device'])
 
167
  repo_id = "Honzus24/Flexpert_weights"
168
  file_weights = config['inference_args']['seq_model_path']
169
-
170
- # Get path (instant if cached)
171
  weights_path = get_weights_path(repo_id, file_weights)
172
 
173
- # Load weights
174
- state_dict = torch.load(weights_path, map_location=config['inference_args']['device'])
175
  model.load_state_dict(state_dict, strict=False)
176
  model.eval()
177
-
178
  data_to_collate = []
179
  for idx, (backbone, sequence) in enumerate(zip(backbones, sequences)):
180
- #Ensure that the missing residues in the sequence are not represented as '-' but as 'X'
181
- sequence = sequence.replace('-', 'X') #due to the tokenizer vocabulary
182
-
183
  tokenizer_out = tokenizer(' '.join(sequence), add_special_tokens=True, return_tensors='pt')
184
- tokenized_seq, attention_mask = tokenizer_out['input_ids'].to(config['inference_args']['device']), tokenizer_out['attention_mask'].to(config['inference_args']['device'])
185
-
186
  data_to_collate.append({'input_ids': tokenized_seq[0,:], 'attention_mask': attention_mask[0,:]})
187
-
188
  data_collator = DataCollatorForTokenRegression(tokenizer)
189
- batch = data_collator(data_to_collate) # Wrap in list since collator expects batch
190
- batch.to(model.device)
191
 
192
- # Predict
193
  with torch.no_grad():
194
  output_logits = process_in_batches_and_combine(model, batch, config['inference_args']['batch_size'])
195
- predictions = output_logits[:,:,0] #includes the prediction for the added token
196
- # subselect the predictions using the attention mask
197
-
198
  output_filename = Path(config['inference_args']['prediction_output_dir'].format(output_name, "seq"))
199
  output_filename.parent.mkdir(parents=True, exist_ok=True)
200
  output_files = []
@@ -205,11 +193,7 @@ def flex_seq(input_seq, input_file):
205
  with open(output_filename_new.with_suffix('.txt'), 'w') as f:
206
  f.write("Residue Number\tResidue ID\tFlexibility\n")
207
  prediction = prediction[mask.bool()]
208
- if len(prediction) != len(sequence)+1:
209
- print("Prediction length {} is not equal to sequence length + 1 {}".format(len(prediction), len(sequence)+1))
210
-
211
  assert len(prediction) == len(sequence)+1, "Prediction length {} is not equal to sequence length + 1 {}".format(len(prediction), len(sequence)+1)
212
-
213
  p = prediction.tolist()[:-1]
214
  for i in range(len(p)):
215
  f.write(f"{i:<10}\t{sequence[i]:<20}\t{round(p[i], 4):<10}\n")
@@ -220,19 +204,32 @@ def flex_seq(input_seq, input_file):
220
  _prediction = prediction[:-1].reshape(1,-1)
221
  _outname = output_filename.with_name('{}_'.format(name.split("/")[-1]) + output_filename.stem + '.pdb')
222
  print("Saving prediction to {}.".format(_outname))
223
- modify_bfactor_biotite(pdb_file, None, _outname, _prediction) #writing the prediction without the last token
224
  output_files.append(str(_outname))
225
-
226
  _outname = output_filename.with_name(name.split("/")[-1] + output_filename.stem + '.fasta')
227
  with open(_outname, 'w') as f:
228
  print("Saving fasta to {}.".format(_outname))
229
  for name, sequence in zip(names, sequences):
230
  f.write('>' + name + '\n')
231
  f.write(sequence + '\n')
232
- output_files.append(str(_outname))
233
 
234
  return output_files, output_message
235
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  @spaces.GPU
237
  def flex_3d(input_file):
238
  if not input_file:
 
92
  names.append(_name)
93
  return backbones, sequences, names
94
 
95
+ def core_flex_seq(input_seq, input_file, force_cpu=False):
96
+ """Core logic decoupled from the GPU decorator."""
97
  if not input_seq:
98
  input_seq = ""
 
99
  if not input_seq.strip() and not input_file:
100
+ return None, "Provide a file/s or an input sequence/s"
101
+
102
  if input_file:
103
  if len(input_file) == 1:
104
  input_file = input_file[0]
 
108
 
109
  default_name = '{}'.format(datetime.now().strftime('%Y%m%d_%H%M%S'))
110
  output_name = default_name
111
+ sequences, names, backbones, flucts_list, pdb_files = [], [], [], [], []
 
 
 
 
 
 
112
  datapoint_for_eval = 'all'
113
 
114
  if input_seq:
 
122
  sequence = proteins[record+1]
123
  else:
124
  raise ValueError("You must adhere to the .fasta format")
 
125
  if datapoint_for_eval == 'all':
126
  names.append(name)
127
  sequences.append(sequence)
128
  backbones.append(None)
 
129
  elif suffix == ".fasta":
130
  for record in SeqIO.parse(input_file, "fasta"):
131
  name = record.name
 
133
  names.append(name)
134
  sequences.append(str(record.seq))
135
  backbones.append(None)
 
136
  elif suffix == ".pdb":
137
  backbones, sequences, names = process_pdb_file(input_file, backbones, sequences, names)
138
  pdb_files.append(input_file)
 
139
  elif suffix == ".pdb_list":
140
  for i in input_file:
141
  backbones, sequences, names = process_pdb_file(i, backbones, sequences, names)
142
  pdb_files.append(i)
143
+
144
  env_config = yaml.load(open('configs/env_config.yaml', 'r'), Loader=yaml.FullLoader)
 
145
  os.environ['HF_HOME'] = env_config['huggingface']['HF_HOME']
146
+ os.environ["CUDA_VISIBLE_DEVICES"] = env_config['gpus']['cuda_visible_device']
147
+
 
148
  config = yaml.load(open('configs/train_config.yaml', 'r'), Loader=yaml.FullLoader)
149
+ class_config = ClassConfig(config)
150
  class_config.adaptor_architecture = 'no-adaptor'
151
+
152
+ # --- DEVICE OVERRIDE LOGIC ---
153
+ if force_cpu:
154
+ target_device = 'cpu'
155
+ else:
156
+ target_device = config['inference_args']['device'] if torch.cuda.is_available() else 'cpu'
157
+ config['inference_args']['device'] = target_device
158
+
159
  model, tokenizer = PT5_classification_model(half_precision=config['mixed_precision'], class_config=class_config)
160
+ model.to(target_device)
161
+
162
  repo_id = "Honzus24/Flexpert_weights"
163
  file_weights = config['inference_args']['seq_model_path']
 
 
164
  weights_path = get_weights_path(repo_id, file_weights)
165
 
166
+ state_dict = torch.load(weights_path, map_location=target_device)
 
167
  model.load_state_dict(state_dict, strict=False)
168
  model.eval()
169
+
170
  data_to_collate = []
171
  for idx, (backbone, sequence) in enumerate(zip(backbones, sequences)):
172
+ sequence = sequence.replace('-', 'X')
 
 
173
  tokenizer_out = tokenizer(' '.join(sequence), add_special_tokens=True, return_tensors='pt')
174
+ tokenized_seq = tokenizer_out['input_ids'].to(target_device)
175
+ attention_mask = tokenizer_out['attention_mask'].to(target_device)
176
  data_to_collate.append({'input_ids': tokenized_seq[0,:], 'attention_mask': attention_mask[0,:]})
177
+
178
  data_collator = DataCollatorForTokenRegression(tokenizer)
179
+ batch = data_collator(data_to_collate)
180
+ batch.to(target_device)
181
 
 
182
  with torch.no_grad():
183
  output_logits = process_in_batches_and_combine(model, batch, config['inference_args']['batch_size'])
184
+ predictions = output_logits[:,:,0]
185
+
 
186
  output_filename = Path(config['inference_args']['prediction_output_dir'].format(output_name, "seq"))
187
  output_filename.parent.mkdir(parents=True, exist_ok=True)
188
  output_files = []
 
193
  with open(output_filename_new.with_suffix('.txt'), 'w') as f:
194
  f.write("Residue Number\tResidue ID\tFlexibility\n")
195
  prediction = prediction[mask.bool()]
 
 
 
196
  assert len(prediction) == len(sequence)+1, "Prediction length {} is not equal to sequence length + 1 {}".format(len(prediction), len(sequence)+1)
 
197
  p = prediction.tolist()[:-1]
198
  for i in range(len(p)):
199
  f.write(f"{i:<10}\t{sequence[i]:<20}\t{round(p[i], 4):<10}\n")
 
204
  _prediction = prediction[:-1].reshape(1,-1)
205
  _outname = output_filename.with_name('{}_'.format(name.split("/")[-1]) + output_filename.stem + '.pdb')
206
  print("Saving prediction to {}.".format(_outname))
207
+ modify_bfactor_biotite(pdb_file, None, _outname, _prediction)
208
  output_files.append(str(_outname))
209
+
210
  _outname = output_filename.with_name(name.split("/")[-1] + output_filename.stem + '.fasta')
211
  with open(_outname, 'w') as f:
212
  print("Saving fasta to {}.".format(_outname))
213
  for name, sequence in zip(names, sequences):
214
  f.write('>' + name + '\n')
215
  f.write(sequence + '\n')
216
+ output_files.append(str(_outname))
217
 
218
  return output_files, output_message
219
 
220
+ @spaces.GPU
221
+ def flex_seq_gpu(input_seq, input_file):
222
+ return core_flex_seq(input_seq, input_file, force_cpu=False)
223
+
224
+ def flex_seq(input_seq, input_file):
225
+ try:
226
+ return flex_seq_gpu(input_seq, input_file)
227
+ except Exception as e:
228
+ # ZeroGPU exceptions (like SpaceTaskError or timeouts) are caught here
229
+ print(f"ZeroGPU failed or timed out. Reason: {e}")
230
+ print("Falling back to CPU execution. This may take a while...")
231
+ return core_flex_seq(input_seq, input_file, force_cpu=True)
232
+
233
  @spaces.GPU
234
  def flex_3d(input_file):
235
  if not input_file: