Spaces:
Running on Zero
Running on Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -92,14 +92,13 @@ def process_pdb_file(pdb_file, backbones, sequences, names):
|
|
| 92 |
names.append(_name)
|
| 93 |
return backbones, sequences, names
|
| 94 |
|
| 95 |
-
|
| 96 |
-
|
| 97 |
if not input_seq:
|
| 98 |
input_seq = ""
|
| 99 |
-
|
| 100 |
if not input_seq.strip() and not input_file:
|
| 101 |
-
return None, "Provide a file/s or
|
| 102 |
-
|
| 103 |
if input_file:
|
| 104 |
if len(input_file) == 1:
|
| 105 |
input_file = input_file[0]
|
|
@@ -109,13 +108,7 @@ def flex_seq(input_seq, input_file):
|
|
| 109 |
|
| 110 |
default_name = '{}'.format(datetime.now().strftime('%Y%m%d_%H%M%S'))
|
| 111 |
output_name = default_name
|
| 112 |
-
|
| 113 |
-
sequences = []
|
| 114 |
-
names = []
|
| 115 |
-
backbones = []
|
| 116 |
-
flucts_list = []
|
| 117 |
-
pdb_files = []
|
| 118 |
-
|
| 119 |
datapoint_for_eval = 'all'
|
| 120 |
|
| 121 |
if input_seq:
|
|
@@ -129,12 +122,10 @@ def flex_seq(input_seq, input_file):
|
|
| 129 |
sequence = proteins[record+1]
|
| 130 |
else:
|
| 131 |
raise ValueError("You must adhere to the .fasta format")
|
| 132 |
-
|
| 133 |
if datapoint_for_eval == 'all':
|
| 134 |
names.append(name)
|
| 135 |
sequences.append(sequence)
|
| 136 |
backbones.append(None)
|
| 137 |
-
|
| 138 |
elif suffix == ".fasta":
|
| 139 |
for record in SeqIO.parse(input_file, "fasta"):
|
| 140 |
name = record.name
|
|
@@ -142,59 +133,56 @@ def flex_seq(input_seq, input_file):
|
|
| 142 |
names.append(name)
|
| 143 |
sequences.append(str(record.seq))
|
| 144 |
backbones.append(None)
|
| 145 |
-
|
| 146 |
elif suffix == ".pdb":
|
| 147 |
backbones, sequences, names = process_pdb_file(input_file, backbones, sequences, names)
|
| 148 |
pdb_files.append(input_file)
|
| 149 |
-
|
| 150 |
elif suffix == ".pdb_list":
|
| 151 |
for i in input_file:
|
| 152 |
backbones, sequences, names = process_pdb_file(i, backbones, sequences, names)
|
| 153 |
pdb_files.append(i)
|
| 154 |
-
|
| 155 |
env_config = yaml.load(open('configs/env_config.yaml', 'r'), Loader=yaml.FullLoader)
|
| 156 |
-
# Set folder for huggingface cache
|
| 157 |
os.environ['HF_HOME'] = env_config['huggingface']['HF_HOME']
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
config = yaml.load(open('configs/train_config.yaml', 'r'), Loader=yaml.FullLoader)
|
| 162 |
-
class_config=ClassConfig(config)
|
| 163 |
class_config.adaptor_architecture = 'no-adaptor'
|
| 164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
model, tokenizer = PT5_classification_model(half_precision=config['mixed_precision'], class_config=class_config)
|
| 166 |
-
model.to(
|
|
|
|
| 167 |
repo_id = "Honzus24/Flexpert_weights"
|
| 168 |
file_weights = config['inference_args']['seq_model_path']
|
| 169 |
-
|
| 170 |
-
# Get path (instant if cached)
|
| 171 |
weights_path = get_weights_path(repo_id, file_weights)
|
| 172 |
|
| 173 |
-
|
| 174 |
-
state_dict = torch.load(weights_path, map_location=config['inference_args']['device'])
|
| 175 |
model.load_state_dict(state_dict, strict=False)
|
| 176 |
model.eval()
|
| 177 |
-
|
| 178 |
data_to_collate = []
|
| 179 |
for idx, (backbone, sequence) in enumerate(zip(backbones, sequences)):
|
| 180 |
-
|
| 181 |
-
sequence = sequence.replace('-', 'X') #due to the tokenizer vocabulary
|
| 182 |
-
|
| 183 |
tokenizer_out = tokenizer(' '.join(sequence), add_special_tokens=True, return_tensors='pt')
|
| 184 |
-
tokenized_seq
|
| 185 |
-
|
| 186 |
data_to_collate.append({'input_ids': tokenized_seq[0,:], 'attention_mask': attention_mask[0,:]})
|
| 187 |
-
|
| 188 |
data_collator = DataCollatorForTokenRegression(tokenizer)
|
| 189 |
-
batch = data_collator(data_to_collate)
|
| 190 |
-
batch.to(
|
| 191 |
|
| 192 |
-
# Predict
|
| 193 |
with torch.no_grad():
|
| 194 |
output_logits = process_in_batches_and_combine(model, batch, config['inference_args']['batch_size'])
|
| 195 |
-
predictions = output_logits[:,:,0]
|
| 196 |
-
|
| 197 |
-
|
| 198 |
output_filename = Path(config['inference_args']['prediction_output_dir'].format(output_name, "seq"))
|
| 199 |
output_filename.parent.mkdir(parents=True, exist_ok=True)
|
| 200 |
output_files = []
|
|
@@ -205,11 +193,7 @@ def flex_seq(input_seq, input_file):
|
|
| 205 |
with open(output_filename_new.with_suffix('.txt'), 'w') as f:
|
| 206 |
f.write("Residue Number\tResidue ID\tFlexibility\n")
|
| 207 |
prediction = prediction[mask.bool()]
|
| 208 |
-
if len(prediction) != len(sequence)+1:
|
| 209 |
-
print("Prediction length {} is not equal to sequence length + 1 {}".format(len(prediction), len(sequence)+1))
|
| 210 |
-
|
| 211 |
assert len(prediction) == len(sequence)+1, "Prediction length {} is not equal to sequence length + 1 {}".format(len(prediction), len(sequence)+1)
|
| 212 |
-
|
| 213 |
p = prediction.tolist()[:-1]
|
| 214 |
for i in range(len(p)):
|
| 215 |
f.write(f"{i:<10}\t{sequence[i]:<20}\t{round(p[i], 4):<10}\n")
|
|
@@ -220,19 +204,32 @@ def flex_seq(input_seq, input_file):
|
|
| 220 |
_prediction = prediction[:-1].reshape(1,-1)
|
| 221 |
_outname = output_filename.with_name('{}_'.format(name.split("/")[-1]) + output_filename.stem + '.pdb')
|
| 222 |
print("Saving prediction to {}.".format(_outname))
|
| 223 |
-
modify_bfactor_biotite(pdb_file, None, _outname, _prediction)
|
| 224 |
output_files.append(str(_outname))
|
| 225 |
-
|
| 226 |
_outname = output_filename.with_name(name.split("/")[-1] + output_filename.stem + '.fasta')
|
| 227 |
with open(_outname, 'w') as f:
|
| 228 |
print("Saving fasta to {}.".format(_outname))
|
| 229 |
for name, sequence in zip(names, sequences):
|
| 230 |
f.write('>' + name + '\n')
|
| 231 |
f.write(sequence + '\n')
|
| 232 |
-
|
| 233 |
|
| 234 |
return output_files, output_message
|
| 235 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
@spaces.GPU
|
| 237 |
def flex_3d(input_file):
|
| 238 |
if not input_file:
|
|
|
|
| 92 |
names.append(_name)
|
| 93 |
return backbones, sequences, names
|
| 94 |
|
| 95 |
+
def core_flex_seq(input_seq, input_file, force_cpu=False):
|
| 96 |
+
"""Core logic decoupled from the GPU decorator."""
|
| 97 |
if not input_seq:
|
| 98 |
input_seq = ""
|
|
|
|
| 99 |
if not input_seq.strip() and not input_file:
|
| 100 |
+
return None, "Provide a file/s or an input sequence/s"
|
| 101 |
+
|
| 102 |
if input_file:
|
| 103 |
if len(input_file) == 1:
|
| 104 |
input_file = input_file[0]
|
|
|
|
| 108 |
|
| 109 |
default_name = '{}'.format(datetime.now().strftime('%Y%m%d_%H%M%S'))
|
| 110 |
output_name = default_name
|
| 111 |
+
sequences, names, backbones, flucts_list, pdb_files = [], [], [], [], []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
datapoint_for_eval = 'all'
|
| 113 |
|
| 114 |
if input_seq:
|
|
|
|
| 122 |
sequence = proteins[record+1]
|
| 123 |
else:
|
| 124 |
raise ValueError("You must adhere to the .fasta format")
|
|
|
|
| 125 |
if datapoint_for_eval == 'all':
|
| 126 |
names.append(name)
|
| 127 |
sequences.append(sequence)
|
| 128 |
backbones.append(None)
|
|
|
|
| 129 |
elif suffix == ".fasta":
|
| 130 |
for record in SeqIO.parse(input_file, "fasta"):
|
| 131 |
name = record.name
|
|
|
|
| 133 |
names.append(name)
|
| 134 |
sequences.append(str(record.seq))
|
| 135 |
backbones.append(None)
|
|
|
|
| 136 |
elif suffix == ".pdb":
|
| 137 |
backbones, sequences, names = process_pdb_file(input_file, backbones, sequences, names)
|
| 138 |
pdb_files.append(input_file)
|
|
|
|
| 139 |
elif suffix == ".pdb_list":
|
| 140 |
for i in input_file:
|
| 141 |
backbones, sequences, names = process_pdb_file(i, backbones, sequences, names)
|
| 142 |
pdb_files.append(i)
|
| 143 |
+
|
| 144 |
env_config = yaml.load(open('configs/env_config.yaml', 'r'), Loader=yaml.FullLoader)
|
|
|
|
| 145 |
os.environ['HF_HOME'] = env_config['huggingface']['HF_HOME']
|
| 146 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = env_config['gpus']['cuda_visible_device']
|
| 147 |
+
|
|
|
|
| 148 |
config = yaml.load(open('configs/train_config.yaml', 'r'), Loader=yaml.FullLoader)
|
| 149 |
+
class_config = ClassConfig(config)
|
| 150 |
class_config.adaptor_architecture = 'no-adaptor'
|
| 151 |
+
|
| 152 |
+
# --- DEVICE OVERRIDE LOGIC ---
|
| 153 |
+
if force_cpu:
|
| 154 |
+
target_device = 'cpu'
|
| 155 |
+
else:
|
| 156 |
+
target_device = config['inference_args']['device'] if torch.cuda.is_available() else 'cpu'
|
| 157 |
+
config['inference_args']['device'] = target_device
|
| 158 |
+
|
| 159 |
model, tokenizer = PT5_classification_model(half_precision=config['mixed_precision'], class_config=class_config)
|
| 160 |
+
model.to(target_device)
|
| 161 |
+
|
| 162 |
repo_id = "Honzus24/Flexpert_weights"
|
| 163 |
file_weights = config['inference_args']['seq_model_path']
|
|
|
|
|
|
|
| 164 |
weights_path = get_weights_path(repo_id, file_weights)
|
| 165 |
|
| 166 |
+
state_dict = torch.load(weights_path, map_location=target_device)
|
|
|
|
| 167 |
model.load_state_dict(state_dict, strict=False)
|
| 168 |
model.eval()
|
| 169 |
+
|
| 170 |
data_to_collate = []
|
| 171 |
for idx, (backbone, sequence) in enumerate(zip(backbones, sequences)):
|
| 172 |
+
sequence = sequence.replace('-', 'X')
|
|
|
|
|
|
|
| 173 |
tokenizer_out = tokenizer(' '.join(sequence), add_special_tokens=True, return_tensors='pt')
|
| 174 |
+
tokenized_seq = tokenizer_out['input_ids'].to(target_device)
|
| 175 |
+
attention_mask = tokenizer_out['attention_mask'].to(target_device)
|
| 176 |
data_to_collate.append({'input_ids': tokenized_seq[0,:], 'attention_mask': attention_mask[0,:]})
|
| 177 |
+
|
| 178 |
data_collator = DataCollatorForTokenRegression(tokenizer)
|
| 179 |
+
batch = data_collator(data_to_collate)
|
| 180 |
+
batch.to(target_device)
|
| 181 |
|
|
|
|
| 182 |
with torch.no_grad():
|
| 183 |
output_logits = process_in_batches_and_combine(model, batch, config['inference_args']['batch_size'])
|
| 184 |
+
predictions = output_logits[:,:,0]
|
| 185 |
+
|
|
|
|
| 186 |
output_filename = Path(config['inference_args']['prediction_output_dir'].format(output_name, "seq"))
|
| 187 |
output_filename.parent.mkdir(parents=True, exist_ok=True)
|
| 188 |
output_files = []
|
|
|
|
| 193 |
with open(output_filename_new.with_suffix('.txt'), 'w') as f:
|
| 194 |
f.write("Residue Number\tResidue ID\tFlexibility\n")
|
| 195 |
prediction = prediction[mask.bool()]
|
|
|
|
|
|
|
|
|
|
| 196 |
assert len(prediction) == len(sequence)+1, "Prediction length {} is not equal to sequence length + 1 {}".format(len(prediction), len(sequence)+1)
|
|
|
|
| 197 |
p = prediction.tolist()[:-1]
|
| 198 |
for i in range(len(p)):
|
| 199 |
f.write(f"{i:<10}\t{sequence[i]:<20}\t{round(p[i], 4):<10}\n")
|
|
|
|
| 204 |
_prediction = prediction[:-1].reshape(1,-1)
|
| 205 |
_outname = output_filename.with_name('{}_'.format(name.split("/")[-1]) + output_filename.stem + '.pdb')
|
| 206 |
print("Saving prediction to {}.".format(_outname))
|
| 207 |
+
modify_bfactor_biotite(pdb_file, None, _outname, _prediction)
|
| 208 |
output_files.append(str(_outname))
|
| 209 |
+
|
| 210 |
_outname = output_filename.with_name(name.split("/")[-1] + output_filename.stem + '.fasta')
|
| 211 |
with open(_outname, 'w') as f:
|
| 212 |
print("Saving fasta to {}.".format(_outname))
|
| 213 |
for name, sequence in zip(names, sequences):
|
| 214 |
f.write('>' + name + '\n')
|
| 215 |
f.write(sequence + '\n')
|
| 216 |
+
output_files.append(str(_outname))
|
| 217 |
|
| 218 |
return output_files, output_message
|
| 219 |
|
| 220 |
+
@spaces.GPU
def flex_seq_gpu(input_seq, input_file):
    """Delegate to ``core_flex_seq`` without forcing CPU execution.

    This is the entry point wrapped by the ``spaces.GPU`` decorator; the
    actual work lives in ``core_flex_seq`` so it can also be called without
    the decorator.

    Args:
        input_seq: Forwarded unchanged to ``core_flex_seq``.
        input_file: Forwarded unchanged to ``core_flex_seq``.

    Returns:
        The value returned by ``core_flex_seq``.
    """
    result = core_flex_seq(input_seq, input_file, force_cpu=False)
    return result
|
| 223 |
+
|
| 224 |
+
def flex_seq(input_seq, input_file):
    """Run the sequence prediction, retrying on CPU if the GPU path fails.

    The call is first attempted through ``flex_seq_gpu``. If that raises,
    the reason is printed and ``core_flex_seq`` is called again with
    ``force_cpu=True``.

    Args:
        input_seq: Forwarded unchanged to the underlying call.
        input_file: Forwarded unchanged to the underlying call.

    Returns:
        The value returned by whichever call succeeds.

    Raises:
        ValueError: Propagated directly from the GPU attempt, without a
            CPU retry.
        Exception: Anything raised by the CPU fallback itself.
    """
    try:
        return flex_seq_gpu(input_seq, input_file)
    except ValueError:
        # The core logic raises ValueError for input that is not valid
        # FASTA. That outcome does not depend on the device, so repeating
        # the whole call on CPU would only reproduce the same error.
        # NOTE(review): assumes the GPU wrapper re-raises with the original
        # exception type -- confirm; if it wraps the error, this clause is
        # never hit and the broad handler below applies as before.
        raise
    except Exception as e:
        # ZeroGPU exceptions (like SpaceTaskError or timeouts) are caught here
        print(f"ZeroGPU failed or timed out. Reason: {e}")
        print("Falling back to CPU execution. This may take a while...")
        return core_flex_seq(input_seq, input_file, force_cpu=True)
|
| 232 |
+
|
| 233 |
@spaces.GPU
|
| 234 |
def flex_3d(input_file):
|
| 235 |
if not input_file:
|