Update app.py
Browse files
app.py
CHANGED
|
@@ -176,40 +176,112 @@ def parse_json_input(json_data: List[Dict]) -> Dict:
|
|
| 176 |
})
|
| 177 |
return components
|
| 178 |
|
|
|
|
| 179 |
def create_protenix_json(input_data: Dict) -> List[Dict]:
|
| 180 |
-
"""Convert UI inputs to Protenix JSON format"""
|
| 181 |
sequences = []
|
| 182 |
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
|
|
|
|
|
|
|
|
|
| 190 |
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
"
|
| 196 |
-
|
| 197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
"
|
| 204 |
-
|
| 205 |
-
|
|
|
|
|
|
|
| 206 |
|
| 207 |
return [{
|
| 208 |
"sequences": sequences,
|
| 209 |
-
"name": input_data
|
| 210 |
}]
|
| 211 |
|
| 212 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
#@torch.inference_mode()
|
| 214 |
@spaces.GPU(duration=120) # Specify a duration to avoid timeout
|
| 215 |
def predict_structure(input_collector: dict):
|
|
@@ -225,7 +297,7 @@ def predict_structure(input_collector: dict):
|
|
| 225 |
print(input_collector)
|
| 226 |
|
| 227 |
# Handle JSON input
|
| 228 |
-
if
|
| 229 |
# Handle different input types
|
| 230 |
if isinstance(input_collector["json"], str): # Example JSON case (file path)
|
| 231 |
input_data = json.load(open(input_collector["json"]))
|
|
@@ -406,31 +478,44 @@ with gr.Blocks(title="FoldMark", css=custom_css) as demo:
|
|
| 406 |
headers=["Sequence", "Count"],
|
| 407 |
datatype=["str", "number"],
|
| 408 |
row_count=1,
|
| 409 |
-
col_count=(2, "fixed")
|
|
|
|
| 410 |
)
|
| 411 |
|
| 412 |
# Repeat for other groups
|
| 413 |
-
with gr.Accordion(label="DNA Sequences", open=True):
|
| 414 |
dna_sequences = gr.Dataframe(
|
| 415 |
headers=["Sequence", "Count"],
|
| 416 |
datatype=["str", "number"],
|
| 417 |
-
row_count=1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 418 |
)
|
| 419 |
|
| 420 |
with gr.Accordion(label="Ligands", open=True):
|
| 421 |
ligands = gr.Dataframe(
|
| 422 |
headers=["Ligand Type", "Count"],
|
| 423 |
datatype=["str", "number"],
|
| 424 |
-
row_count=1
|
|
|
|
| 425 |
)
|
| 426 |
|
| 427 |
manual_output = gr.JSON(label="Generated JSON")
|
| 428 |
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
|
|
|
|
|
|
| 434 |
|
| 435 |
# Shared prediction components
|
| 436 |
with gr.Row():
|
|
@@ -450,8 +535,8 @@ with gr.Blocks(title="FoldMark", css=custom_css) as demo:
|
|
| 450 |
|
| 451 |
# Map inputs to a dictionary
|
| 452 |
submit_btn.click(
|
| 453 |
-
fn=lambda c, p, d, l, w: {"data": {"complex_name": c, "protein_chains": p, "dna_sequences": d, "ligands": l}, "watermark": w},
|
| 454 |
-
inputs=[complex_name, protein_chains, dna_sequences, ligands, add_watermark1],
|
| 455 |
outputs=input_collector
|
| 456 |
).then(
|
| 457 |
fn=predict_structure,
|
|
|
|
| 176 |
})
|
| 177 |
return components
|
| 178 |
|
| 179 |
+
|
| 180 |
def create_protenix_json(input_data: Dict) -> List[Dict]:
    """Convert UI inputs to the Protenix JSON input format.

    Args:
        input_data: Dict with optional keys ``complex_name`` (str) and
            ``protein_chains`` / ``dna_sequences`` / ``rna_sequences`` /
            ``ligands``, each a list of 2-column Dataframe rows
            ``[value, count]``.

    Returns:
        A single-element list holding one Protenix job dict with
        ``sequences`` and a unique ``name``.
    """
    sequences: List[Dict] = []

    def _append_rows(rows, wrapper_key: str, value_key: str) -> None:
        # One Dataframe group -> Protenix entries. Rows may be padded with
        # None/"" cells by Gradio, so guard before calling .strip().
        for row in rows or []:
            if len(row) >= 2 and row[0] and str(row[0]).strip():
                sequences.append({
                    wrapper_key: {
                        value_key: str(row[0]).strip(),
                        # Empty count cell defaults to 1 copy.
                        "count": int(row[1]) if row[1] else 1,
                    }
                })

    _append_rows(input_data.get("protein_chains", []), "proteinChain", "sequence")
    _append_rows(input_data.get("dna_sequences", []), "dnaSequence", "sequence")
    _append_rows(input_data.get("rna_sequences", []), "rnaSequence", "sequence")
    # NOTE: the ligand entry nests "ligand" inside "ligand" — that is the
    # Protenix schema, not a typo.
    _append_rows(input_data.get("ligands", []), "ligand", "ligand")

    # Unique job name: user prefix + timestamp + short random suffix.
    # Fall back to "complex" so a missing name never raises on concatenation.
    prefix = input_data.get("complex_name") or "complex"
    name = f"{prefix}{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:3]}"

    return [{
        "sequences": sequences,
        "name": name,
    }]
|
| 228 |
|
| 229 |
|
| 230 |
+
def update_json(complex_name, protein_chains, dna_sequences, rna_sequences, ligands):
    """Build the live JSON preview shown in the UI from the current widget values.

    Each sequence argument is a 2-column Dataframe value (list of
    ``[value, count]`` rows); ``complex_name`` is a plain string.
    Returns a dict with ``sequences`` and ``name``.
    """
    # (rows, wrapper key, field name for the first column) per input group.
    groups = (
        (protein_chains, "proteinChain", "sequence"),
        (dna_sequences, "dnaSequence", "sequence"),
        (rna_sequences, "rnaSequence", "sequence"),
        (ligands, "ligand", "ligand"),
    )

    entries = []
    for rows, wrapper, field in groups:
        if not rows:
            continue
        for row in rows:
            # Skip padding rows: need both columns and a non-empty first cell.
            if row and len(row) >= 2 and row[0]:
                entries.append({wrapper: {field: row[0], "count": row[1]}})

    return {
        "sequences": entries,
        "name": complex_name,
    }
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
|
| 285 |
#@torch.inference_mode()
|
| 286 |
@spaces.GPU(duration=120) # Specify a duration to avoid timeout
|
| 287 |
def predict_structure(input_collector: dict):
|
|
|
|
| 297 |
print(input_collector)
|
| 298 |
|
| 299 |
# Handle JSON input
|
| 300 |
+
if "json" in input_collector:
|
| 301 |
# Handle different input types
|
| 302 |
if isinstance(input_collector["json"], str): # Example JSON case (file path)
|
| 303 |
input_data = json.load(open(input_collector["json"]))
|
|
|
|
| 478 |
headers=["Sequence", "Count"],
|
| 479 |
datatype=["str", "number"],
|
| 480 |
row_count=1,
|
| 481 |
+
col_count=(2, "fixed"),
|
| 482 |
+
type="array"
|
| 483 |
)
|
| 484 |
|
| 485 |
# Repeat for other groups
|
| 486 |
+
with gr.Accordion(label="DNA Sequences (A T G C)", open=True):
|
| 487 |
dna_sequences = gr.Dataframe(
|
| 488 |
headers=["Sequence", "Count"],
|
| 489 |
datatype=["str", "number"],
|
| 490 |
+
row_count=1,
|
| 491 |
+
type="array"
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
with gr.Accordion(label="RNA Sequences (A U G C)", open=True):
|
| 495 |
+
rna_sequences = gr.Dataframe(
|
| 496 |
+
headers=["Sequence", "Count"],
|
| 497 |
+
datatype=["str", "number"],
|
| 498 |
+
row_count=1,
|
| 499 |
+
type="array"
|
| 500 |
)
|
| 501 |
|
| 502 |
with gr.Accordion(label="Ligands", open=True):
|
| 503 |
ligands = gr.Dataframe(
|
| 504 |
headers=["Ligand Type", "Count"],
|
| 505 |
datatype=["str", "number"],
|
| 506 |
+
row_count=1,
|
| 507 |
+
type="array"
|
| 508 |
)
|
| 509 |
|
| 510 |
manual_output = gr.JSON(label="Generated JSON")
|
| 511 |
|
| 512 |
+
# Attach a change event to all widgets so that any change updates the JSON output.
|
| 513 |
+
for widget in [complex_name, protein_chains, dna_sequences, rna_sequences, ligands]:
|
| 514 |
+
widget.change(
|
| 515 |
+
fn=update_json,
|
| 516 |
+
inputs=[complex_name, protein_chains, dna_sequences, rna_sequences, ligands],
|
| 517 |
+
outputs=manual_output
|
| 518 |
+
)
|
| 519 |
|
| 520 |
# Shared prediction components
|
| 521 |
with gr.Row():
|
|
|
|
| 535 |
|
| 536 |
# Map inputs to a dictionary
|
| 537 |
submit_btn.click(
|
| 538 |
+
fn=lambda c, p, d, r, l, w: {"data": {"complex_name": c, "protein_chains": p, "dna_sequences": d, "rna_sequences": r, "ligands": l}, "watermark": w},
|
| 539 |
+
inputs=[complex_name, protein_chains, dna_sequences, rna_sequences, ligands, add_watermark1],
|
| 540 |
outputs=input_collector
|
| 541 |
).then(
|
| 542 |
fn=predict_structure,
|