jiehou committed on
Commit
91fbffc
·
verified ·
1 Parent(s): 938fc1a

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +302 -17
  2. image_annotator.py +221 -0
app.py CHANGED
@@ -44,6 +44,7 @@ st.set_page_config(
44
  initial_sidebar_state="expanded"
45
  )
46
 
 
47
 
48
 
49
  # Custom CSS - IMPROVED VERSION with larger fonts
@@ -497,27 +498,53 @@ def main():
497
 
498
  # Residue trimming controls - add early so they're available when needed
499
  st.sidebar.markdown("---")
500
- st.sidebar.markdown("**🔧 Terminal Residue Trimming**")
501
  col1, col2 = st.sidebar.columns(2)
502
  with col1:
503
- n_term_trim = st.number_input(
504
- "N-term trim",
505
  min_value=0,
506
  max_value=10,
507
  value=2,
508
  step=1,
509
  help="Number of residues to remove from 5' end",
510
- key="n_term_trim"
511
  )
512
  with col2:
513
- c_term_trim = st.number_input(
514
- "C-term trim",
515
  min_value=0,
516
  max_value=10,
517
  value=2,
518
  step=1,
519
  help="Number of residues to remove from 3' end",
520
- key="c_term_trim"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
521
  )
522
 
523
  # Load structure data
@@ -624,8 +651,8 @@ def main():
624
 
625
  if selection_method == "Select by range":
626
  current_selection = st.session_state['ref_selections'].get(selected_ref['name'], [])
627
- default_start = current_selection[0] + n_term_trim if current_selection else 3
628
- default_end = current_selection[-1] + 1 if current_selection else max(2, len(structure_info) - 2)
629
 
630
  c1, c2 = st.columns(2)
631
  with c1:
@@ -653,13 +680,13 @@ def main():
653
 
654
  elif selection_method == "Select specific residues":
655
  # Always use current trim values for default selection (updates when trim values change)
656
- default_names = [structure_info[i]['full_name'] for i in range(n_term_trim, len(structure_info)-c_term_trim)]
657
 
658
  selected_names = st.multiselect(
659
  "Select residues",
660
  options=[info['full_name'] for info in structure_info],
661
  default=default_names,
662
- key=f"specific_ref_{selected_ref['name']}_n{n_term_trim}_c{c_term_trim}"
663
  )
664
 
665
 
@@ -724,8 +751,8 @@ def main():
724
 
725
  if selection_method == "Select by range":
726
  current_selection = st.session_state['query_selections'].get(selected_query['name'], [])
727
- default_start = current_selection[0] + n_term_trim if current_selection else 3
728
- default_end = current_selection[-1] + 1 if current_selection else max(2, len(structure_info) - c_term_trim)
729
 
730
  c1, c2 = st.columns(2)
731
  with c1:
@@ -753,13 +780,13 @@ def main():
753
 
754
  elif selection_method == "Select specific residues":
755
  # Always use current trim values for default selection (updates when trim values change)
756
- default_names = [structure_info[i]['full_name'] for i in range(n_term_trim, len(structure_info)-c_term_trim)]
757
 
758
  selected_names = st.multiselect(
759
  "Select residues",
760
  options=[info['full_name'] for info in structure_info],
761
  default=default_names,
762
- key=f"specific_query_{selected_query['name']}_n{n_term_trim}_c{c_term_trim}"
763
  )
764
 
765
  name_to_idx = {info['full_name']: info['index'] for info in structure_info}
@@ -993,7 +1020,11 @@ def main():
993
  selected_row['Rotation_Matrix'],
994
  selected_row['Ref_COM'],
995
  selected_row['Query_COM'],
996
- selected_row['RMSD']
 
 
 
 
997
  )
998
  st.components.v1.html(viz_html, width=1400, height=750, scrolling=False)
999
  except Exception as e:
@@ -1001,6 +1032,210 @@ def main():
1001
  import traceback
1002
  st.code(traceback.format_exc())
1003
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1004
 
1005
  # Show transformation details
1006
  with st.expander("🔧 Transformation Details"):
@@ -1079,7 +1314,57 @@ def main():
1079
  help="Query structure aligned to reference"
1080
  )
1081
 
1082
- st.info("💡 **Tip:** Load reference and aligned query together in PyMOL/Chimera to examine the superposition")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1083
 
1084
  else:
1085
  st.warning("No comparisons below RMSD threshold to visualize")
 
44
  initial_sidebar_state="expanded"
45
  )
46
 
47
+ from image_annotator import annotate_alignment_image
48
 
49
 
50
  # Custom CSS - IMPROVED VERSION with larger fonts
 
498
 
499
  # Residue trimming controls - add early so they're available when needed
500
  st.sidebar.markdown("---")
501
+ st.sidebar.markdown("**🔧 Terminal Residue Trimming (Reference) **")
502
  col1, col2 = st.sidebar.columns(2)
503
  with col1:
504
+ n_term_trim_ref = st.number_input(
505
+ "N-term trim_ref",
506
  min_value=0,
507
  max_value=10,
508
  value=2,
509
  step=1,
510
  help="Number of residues to remove from 5' end",
511
+ key="n_term_trim_ref"
512
  )
513
  with col2:
514
+ c_term_trim_ref = st.number_input(
515
+ "C-term trim_ref",
516
  min_value=0,
517
  max_value=10,
518
  value=2,
519
  step=1,
520
  help="Number of residues to remove from 3' end",
521
+ key="c_term_trim_ref"
522
+ )
523
+
524
+
525
+ # Residue trimming controls - add early so they're available when needed
526
+ st.sidebar.markdown("---")
527
+ st.sidebar.markdown("**🔧 Terminal Residue Trimming (Query) **")
528
+ col1, col2 = st.sidebar.columns(2)
529
+ with col1:
530
+ n_term_trim_query = st.number_input(
531
+ "N-term trim_query",
532
+ min_value=0,
533
+ max_value=10,
534
+ value=2,
535
+ step=1,
536
+ help="Number of residues to remove from 5' end",
537
+ key="n_term_trim_query"
538
+ )
539
+ with col2:
540
+ c_term_trim_query = st.number_input(
541
+ "C-term trim_query",
542
+ min_value=0,
543
+ max_value=10,
544
+ value=2,
545
+ step=1,
546
+ help="Number of residues to remove from 3' end",
547
+ key="c_term_trim_query"
548
  )
549
 
550
  # Load structure data
 
651
 
652
  if selection_method == "Select by range":
653
  current_selection = st.session_state['ref_selections'].get(selected_ref['name'], [])
654
+ default_start = current_selection[0] + n_term_trim_ref if current_selection else n_term_trim_ref
655
+ default_end = current_selection[-1] + 1 if current_selection else max(n_term_trim_ref, len(structure_info) - c_term_trim_ref)
656
 
657
  c1, c2 = st.columns(2)
658
  with c1:
 
680
 
681
  elif selection_method == "Select specific residues":
682
  # Always use current trim values for default selection (updates when trim values change)
683
+ default_names = [structure_info[i]['full_name'] for i in range(n_term_trim_ref, len(structure_info)-c_term_trim_ref)]
684
 
685
  selected_names = st.multiselect(
686
  "Select residues",
687
  options=[info['full_name'] for info in structure_info],
688
  default=default_names,
689
+ key=f"specific_ref_{selected_ref['name']}_n{n_term_trim_ref}_c{c_term_trim_ref}"
690
  )
691
 
692
 
 
751
 
752
  if selection_method == "Select by range":
753
  current_selection = st.session_state['query_selections'].get(selected_query['name'], [])
754
+ default_start = current_selection[0] + n_term_trim_query if current_selection else 3
755
+ default_end = current_selection[-1] + 1 if current_selection else max(2, len(structure_info) - c_term_trim_query)
756
 
757
  c1, c2 = st.columns(2)
758
  with c1:
 
780
 
781
  elif selection_method == "Select specific residues":
782
  # Always use current trim values for default selection (updates when trim values change)
783
+ default_names = [structure_info[i]['full_name'] for i in range(n_term_trim_query, len(structure_info)-c_term_trim_query)]
784
 
785
  selected_names = st.multiselect(
786
  "Select residues",
787
  options=[info['full_name'] for info in structure_info],
788
  default=default_names,
789
+ key=f"specific_query_{selected_query['name']}_n{n_term_trim_query}_c{c_term_trim_query}"
790
  )
791
 
792
  name_to_idx = {info['full_name']: info['index'] for info in structure_info}
 
1020
  selected_row['Rotation_Matrix'],
1021
  selected_row['Ref_COM'],
1022
  selected_row['Query_COM'],
1023
+ selected_row['RMSD'],
1024
+ ref_name=selected_row['Reference'],
1025
+ query_name=selected_row['Query'],
1026
+ ref_sequence=selected_row['Ref_Sequence'],
1027
+ query_sequence=selected_row['Query_Sequence']
1028
  )
1029
  st.components.v1.html(viz_html, width=1400, height=750, scrolling=False)
1030
  except Exception as e:
 
1032
  import traceback
1033
  st.code(traceback.format_exc())
1034
 
1035
+ # Direct Download Annotated Image Button
1036
+ st.markdown("---")
1037
+ st.markdown("### 📸 Download Annotated Structure Image")
1038
+
1039
+ col1, col2 = st.columns([3, 1])
1040
+
1041
+ with col1:
1042
+ st.info("💡 Generate a structure visualization with RMSD and sequence information embedded")
1043
+
1044
+ with col2:
1045
+ if st.button("🖼️ Generate Annotated Image", use_container_width=True, type="primary"):
1046
+ with st.spinner("Generating annotated image..."):
1047
+ try:
1048
+ from visualization import extract_window_pdb, transform_pdb_string
1049
+
1050
+ # Create a simple structure visualization using matplotlib
1051
+ # Since we can't capture py3Dmol directly, we'll create a matplotlib-based view
1052
+
1053
+ # Extract structures
1054
+ ref_pdb = extract_window_pdb(
1055
+ selected_row['Ref_Path'],
1056
+ selected_row['Ref_Window']
1057
+ )
1058
+
1059
+ query_pdb = extract_window_pdb(
1060
+ selected_row['Query_Path'],
1061
+ selected_row['Query_Window']
1062
+ )
1063
+
1064
+ query_aligned_pdb = transform_pdb_string(
1065
+ query_pdb,
1066
+ selected_row['Rotation_Matrix'],
1067
+ selected_row['Query_COM'],
1068
+ selected_row['Ref_COM']
1069
+ )
1070
+
1071
+ # Parse coordinates for visualization
1072
+ from rmsd_utils import parse_residue_atoms
1073
+
1074
+ # Create a matplotlib-based 3D visualization
1075
+ import matplotlib.pyplot as plt
1076
+ from mpl_toolkits.mplot3d import Axes3D
1077
+
1078
+ fig = plt.figure(figsize=(12, 9), dpi=150)
1079
+ ax = fig.add_subplot(111, projection='3d')
1080
+
1081
+ # Function to extract coordinates from PDB string
1082
+ def get_coords_from_pdb_string(pdb_string):
1083
+ coords = []
1084
+ for line in pdb_string.split('\n'):
1085
+ if line.startswith(('ATOM', 'HETATM')):
1086
+ try:
1087
+ x = float(line[30:38].strip())
1088
+ y = float(line[38:46].strip())
1089
+ z = float(line[46:54].strip())
1090
+ atom_name = line[12:16].strip()
1091
+ coords.append((x, y, z, atom_name))
1092
+ except:
1093
+ continue
1094
+ return coords
1095
+
1096
+ # Get coordinates
1097
+ ref_coords = get_coords_from_pdb_string(ref_pdb)
1098
+ query_coords = get_coords_from_pdb_string(query_aligned_pdb)
1099
+
1100
+ # Plot reference structure (blue)
1101
+ if ref_coords:
1102
+ ref_x = [c[0] for c in ref_coords]
1103
+ ref_y = [c[1] for c in ref_coords]
1104
+ ref_z = [c[2] for c in ref_coords]
1105
+ ax.scatter(ref_x, ref_y, ref_z, c='#4A90E2', s=40, alpha=0.8, label='Reference')
1106
+
1107
+ # Connect backbone atoms
1108
+ backbone_atoms = ['P', "C4'", "C3'", "O3'"]
1109
+ ref_backbone = [(c[0], c[1], c[2]) for c in ref_coords if c[3] in backbone_atoms]
1110
+ if len(ref_backbone) > 1:
1111
+ bb_x = [c[0] for c in ref_backbone]
1112
+ bb_y = [c[1] for c in ref_backbone]
1113
+ bb_z = [c[2] for c in ref_backbone]
1114
+ ax.plot(bb_x, bb_y, bb_z, c='#4A90E2', linewidth=2, alpha=0.6)
1115
+
1116
+ # Plot query structure (red)
1117
+ if query_coords:
1118
+ query_x = [c[0] for c in query_coords]
1119
+ query_y = [c[1] for c in query_coords]
1120
+ query_z = [c[2] for c in query_coords]
1121
+ ax.scatter(query_x, query_y, query_z, c='#E94B3C', s=40, alpha=0.8, label='Query (Aligned)')
1122
+
1123
+ # Connect backbone atoms
1124
+ query_backbone = [(c[0], c[1], c[2]) for c in query_coords if c[3] in backbone_atoms]
1125
+ if len(query_backbone) > 1:
1126
+ bb_x = [c[0] for c in query_backbone]
1127
+ bb_y = [c[1] for c in query_backbone]
1128
+ bb_z = [c[2] for c in query_backbone]
1129
+ ax.plot(bb_x, bb_y, bb_z, c='#E94B3C', linewidth=2, alpha=0.6)
1130
+
1131
+ # Set labels and title
1132
+ ax.set_xlabel('X (Å)', fontsize=10)
1133
+ ax.set_ylabel('Y (Å)', fontsize=10)
1134
+ ax.set_zlabel('Z (Å)', fontsize=10)
1135
+ ax.legend(fontsize=10, loc='upper right')
1136
+
1137
+ # Set viewing angle
1138
+ ax.view_init(elev=20, azim=45)
1139
+
1140
+ # Equal aspect ratio
1141
+ if ref_coords or query_coords:
1142
+ all_coords = ref_coords + query_coords
1143
+ all_x = [c[0] for c in all_coords]
1144
+ all_y = [c[1] for c in all_coords]
1145
+ all_z = [c[2] for c in all_coords]
1146
+
1147
+ max_range = max(
1148
+ max(all_x) - min(all_x),
1149
+ max(all_y) - min(all_y),
1150
+ max(all_z) - min(all_z)
1151
+ ) / 2.0
1152
+
1153
+ mid_x = (max(all_x) + min(all_x)) / 2
1154
+ mid_y = (max(all_y) + min(all_y)) / 2
1155
+ mid_z = (max(all_z) + min(all_z)) / 2
1156
+
1157
+ ax.set_xlim(mid_x - max_range, mid_x + max_range)
1158
+ ax.set_ylim(mid_y - max_range, mid_y + max_range)
1159
+ ax.set_zlim(mid_z - max_range, mid_z + max_range)
1160
+
1161
+ plt.tight_layout()
1162
+
1163
+ # Save to temporary buffer
1164
+ buf = io.BytesIO()
1165
+ plt.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
1166
+ plt.close()
1167
+ buf.seek(0)
1168
+
1169
+ # Now annotate this image
1170
+ annotated_img = annotate_alignment_image(
1171
+ image_data=buf.read(),
1172
+ rmsd=selected_row['RMSD'],
1173
+ ref_name=selected_row['Reference'],
1174
+ query_name=selected_row['Query'],
1175
+ ref_sequence=selected_row['Ref_Sequence'],
1176
+ query_sequence=selected_row['Query_Sequence'],
1177
+ output_format='JPEG'
1178
+ )
1179
+
1180
+ # Generate filename
1181
+ ref_clean = selected_row['Reference'].replace('.pdb', '')
1182
+ query_clean = selected_row['Query'].replace('.pdb', '')
1183
+ filename = f"annotated_{ref_clean}_{query_clean}_RMSD_{selected_row['RMSD']:.3f}.jpg"
1184
+
1185
+ # Show preview and download button
1186
+ st.success("✅ Annotated image generated!")
1187
+ st.image(annotated_img, caption="Annotated Structure Alignment", use_column_width=True)
1188
+
1189
+ st.download_button(
1190
+ label="📥 Download Annotated JPEG",
1191
+ data=annotated_img.getvalue(),
1192
+ file_name=filename,
1193
+ mime="image/jpeg",
1194
+ use_container_width=True,
1195
+ help="Download JPEG with RMSD and sequence information"
1196
+ )
1197
+
1198
+ except Exception as e:
1199
+ st.error(f"Error generating annotated image: {str(e)}")
1200
+ import traceback
1201
+ st.code(traceback.format_exc())
1202
+
1203
+ # Fallback: offer the upload option
1204
+ st.info("💡 Alternatively, you can download a screenshot from the 3D viewer above using the '📷 Download PNG' button, then upload it below:")
1205
+
1206
+ uploaded_screenshot = st.file_uploader(
1207
+ "Upload screenshot (PNG/JPG)",
1208
+ type=['png', 'jpg', 'jpeg'],
1209
+ key=f"screenshot_upload_fallback_{selected_viz_idx}"
1210
+ )
1211
+
1212
+ if uploaded_screenshot is not None:
1213
+ try:
1214
+ annotated_img = annotate_alignment_image(
1215
+ image_data=uploaded_screenshot.read(),
1216
+ rmsd=selected_row['RMSD'],
1217
+ ref_name=selected_row['Reference'],
1218
+ query_name=selected_row['Query'],
1219
+ ref_sequence=selected_row['Ref_Sequence'],
1220
+ query_sequence=selected_row['Query_Sequence'],
1221
+ output_format='JPEG'
1222
+ )
1223
+
1224
+ st.image(annotated_img, use_column_width=True)
1225
+
1226
+ ref_clean = selected_row['Reference'].replace('.pdb', '')
1227
+ query_clean = selected_row['Query'].replace('.pdb', '')
1228
+ filename = f"annotated_{ref_clean}_{query_clean}_RMSD_{selected_row['RMSD']:.3f}.jpg"
1229
+
1230
+ st.download_button(
1231
+ label="📥 Download Annotated JPEG",
1232
+ data=annotated_img.getvalue(),
1233
+ file_name=filename,
1234
+ mime="image/jpeg",
1235
+ use_container_width=True
1236
+ )
1237
+ except Exception as e2:
1238
+ st.error(f"Error annotating uploaded image: {str(e2)}")
1239
 
1240
  # Show transformation details
1241
  with st.expander("🔧 Transformation Details"):
 
1314
  help="Query structure aligned to reference"
1315
  )
1316
 
1317
+ # Combined aligned structure
1318
+ st.markdown("---")
1319
+ st.markdown("**Combined Aligned Structure (Reference + Query)**")
1320
+
1321
+ # Create combined PDB with both structures
1322
+ combined_pdb_lines = []
1323
+
1324
+ # Add header information as REMARK records
1325
+ combined_pdb_lines.append(f"REMARK Reference: {selected_row['Reference']}")
1326
+ combined_pdb_lines.append(f"REMARK Reference Residues: {','.join(map(str, [i+1 for i in selected_row['Ref_Window']]))}")
1327
+ combined_pdb_lines.append(f"REMARK Reference Sequence: {selected_row['Ref_Sequence']}")
1328
+ combined_pdb_lines.append(f"REMARK Query: {selected_row['Query']}")
1329
+ combined_pdb_lines.append(f"REMARK Query Residues: {','.join(map(str, [i+1 for i in selected_row['Query_Window']]))}")
1330
+ combined_pdb_lines.append(f"REMARK Query Sequence: {selected_row['Query_Sequence']}")
1331
+ combined_pdb_lines.append(f"REMARK RMSD: {selected_row['RMSD']:.3f} Angstroms")
1332
+ combined_pdb_lines.append("MODEL 1")
1333
+
1334
+ # Add reference atoms with chain A
1335
+ for line in ref_pdb.split('\n'):
1336
+ if line.startswith(('ATOM', 'HETATM')):
1337
+ # Set chain to A for reference
1338
+ modified_line = line[:21] + 'A' + line[22:]
1339
+ combined_pdb_lines.append(modified_line)
1340
+
1341
+ combined_pdb_lines.append("ENDMDL")
1342
+ combined_pdb_lines.append("MODEL 2")
1343
+
1344
+ # Add aligned query atoms with chain B
1345
+ for line in query_aligned_pdb.split('\n'):
1346
+ if line.startswith(('ATOM', 'HETATM')):
1347
+ # Set chain to B for query
1348
+ modified_line = line[:21] + 'B' + line[22:]
1349
+ combined_pdb_lines.append(modified_line)
1350
+
1351
+ combined_pdb_lines.append("ENDMDL")
1352
+ combined_pdb_lines.append("END")
1353
+
1354
+ combined_pdb = '\n'.join(combined_pdb_lines)
1355
+
1356
+ combined_filename = f"aligned_{selected_row['Reference'].replace('.pdb', '')}_{selected_row['Query'].replace('.pdb', '')}_rmsd_{selected_row['RMSD']:.3f}.pdb"
1357
+
1358
+ st.download_button(
1359
+ label="📥 Download Combined Aligned Structure",
1360
+ data=combined_pdb,
1361
+ file_name=combined_filename,
1362
+ mime="chemical/x-pdb",
1363
+ help="Reference (chain A) and aligned query (chain B) in one file",
1364
+ use_container_width=True
1365
+ )
1366
+
1367
+ st.info("💡 **Tip:** The combined PDB contains reference (chain A) and aligned query (chain B) - ready for PyMOL/Chimera")
1368
 
1369
  else:
1370
  st.warning("No comparisons below RMSD threshold to visualize")
image_annotator.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Image Annotation Utility for RNA Structure Alignments
3
+ Adds RMSD, reference/query names, and sequences directly to PNG/JPEG images
4
+ """
5
+
6
+ from PIL import Image, ImageDraw, ImageFont
7
+ import io
8
+
9
+
10
def annotate_alignment_image(image_data, rmsd, ref_name, query_name,
                             ref_sequence=None, query_sequence=None,
                             output_format='PNG'):
    """
    Add an RMSD / name / sequence annotation box to an alignment image.

    A near-opaque white rounded box is composited into the bottom-left
    corner of the image and the text lines are drawn on top of it.

    Args:
        image_data: A file path (str), raw image bytes (bytes/bytearray),
            or a PIL Image object. A PIL Image passed in is not modified.
        rmsd: RMSD value (float), rendered with three decimals.
        ref_name: Reference structure name (str)
        query_name: Query structure name (str)
        ref_sequence: Reference sequence (str, optional)
        query_sequence: Query sequence (str, optional)
        output_format: 'PNG' or 'JPEG' (case-insensitive; 'JPG' is
            accepted as an alias for 'JPEG')

    Returns:
        BytesIO object containing the annotated image, positioned at 0.

    Raises:
        ValueError: If image_data is not a path, bytes, or PIL Image.
    """
    # Normalize the format once so every later comparison agrees with what
    # is finally handed to Image.save().
    output_format = output_format.upper()
    if output_format == 'JPG':
        output_format = 'JPEG'
    is_jpeg = output_format == 'JPEG'

    # Load the image
    if isinstance(image_data, str):
        img = Image.open(image_data)
    elif isinstance(image_data, (bytes, bytearray)):
        img = Image.open(io.BytesIO(image_data))
    elif isinstance(image_data, Image.Image):
        img = image_data
    else:
        raise ValueError("image_data must be a file path, bytes, or PIL Image")

    # JPEG has no alpha channel: flatten to RGB first so the overlay is
    # blended against opaque pixels.
    if is_jpeg and img.mode != 'RGB':
        img = img.convert('RGB')

    # Image.alpha_composite requires both operands to be RGBA. Convert every
    # other mode (RGB, P, L, LA, ...), not just RGB, so palette or grayscale
    # inputs work with PNG output too.
    if img.mode != 'RGBA':
        img = img.convert('RGBA')

    # Prefer the DejaVu TrueType fonts; fall back to Pillow's built-in font
    # when the files are missing or FreeType support is unavailable.
    font_dir = "/usr/share/fonts/truetype/dejavu/"
    try:
        font_large = ImageFont.truetype(font_dir + "DejaVuSans-Bold.ttf", 20)
        font_medium = ImageFont.truetype(font_dir + "DejaVuSans.ttf", 16)
        font_small = ImageFont.truetype(font_dir + "DejaVuSansMono.ttf", 14)
    except (OSError, ImportError):
        font_large = font_medium = font_small = ImageFont.load_default()

    height = img.size[1]

    # Annotation box layout parameters (pixels)
    margin = 15
    padding = 12
    line_spacing = 8

    # (label, value, font) per line; an empty label and value is a spacer.
    entries = [
        ("RMSD:", f"{rmsd:.3f} Å", font_large),
        ("", "", font_medium),
        ("Reference:", ref_name, font_medium),
    ]
    if ref_sequence:
        entries.append((" Seq:", ref_sequence, font_small))
    entries.append(("", "", font_medium))
    entries.append(("Query:", query_name, font_medium))
    if query_sequence:
        entries.append((" Seq:", query_sequence, font_small))

    # Transparent layer that will carry the background box.
    overlay = Image.new('RGBA', img.size, (255, 255, 255, 0))
    overlay_draw = ImageDraw.Draw(overlay)

    # Measure every line once; the same text and height are reused when
    # drawing so the box size always matches what is rendered.
    measured = []  # (label, text or None for a spacer, font, text_height)
    max_width = 0
    box_height = padding * 2
    for label, value, font in entries:
        if not label and not value:
            measured.append((label, None, font, 0))
            box_height += line_spacing // 2
            continue

        # strip() also removes the leading space of the " Seq:" label, so
        # that space only acts as a marker for the gray color below.
        text = f"{label} {value}".strip() if value else label.strip()
        bbox = overlay_draw.textbbox((0, 0), text, font=font)
        text_height = bbox[3] - bbox[1]
        max_width = max(max_width, bbox[2] - bbox[0])
        box_height += text_height + line_spacing
        measured.append((label, text, font, text_height))

    box_width = max_width + padding * 2

    # Bottom-left corner; clamp so a box taller than the image still starts
    # inside it and the first (RMSD) line stays visible.
    box_x = margin
    box_y = max(0, height - box_height - margin)
    box_rect = [(box_x, box_y), (box_x + box_width, box_y + box_height)]

    # White background with 95% opacity, then a subtle border
    overlay_draw.rounded_rectangle(box_rect, radius=8, fill=(255, 255, 255, 242))
    overlay_draw.rounded_rectangle(
        box_rect, radius=8, outline=(200, 200, 200, 242), width=1
    )

    # alpha_composite returns a new image, so a caller-supplied PIL Image
    # is left untouched.
    img = Image.alpha_composite(img, overlay)
    if is_jpeg:
        img = img.convert('RGB')

    draw = ImageDraw.Draw(img)

    # Draw text
    current_y = box_y + padding
    for label, text, font, text_height in measured:
        if text is None:
            # Spacer
            current_y += line_spacing // 2
            continue

        if "RMSD" in label:
            color = (233, 75, 60)     # Red color for RMSD
        elif label.startswith(" "):
            color = (100, 100, 100)   # Gray for sequences
        else:
            color = (51, 51, 51)      # Dark gray for labels

        draw.text((box_x + padding, current_y), text, fill=color, font=font)
        current_y += text_height + line_spacing

    # Save to BytesIO; the quality setting is only passed for JPEG.
    save_kwargs = {'quality': 95} if is_jpeg else {}
    output = io.BytesIO()
    img.save(output, format=output_format, **save_kwargs)
    output.seek(0)

    return output
167
+
168
+
169
def annotate_alignment_image_file(input_path, output_path, rmsd, ref_name, query_name,
                                  ref_sequence=None, query_sequence=None):
    """
    Read an image from disk, annotate it, and write the result to a new file.

    The output format is chosen from the extension of output_path:
    '.jpg' / '.jpeg' (any case) produce JPEG, anything else produces PNG.

    Args:
        input_path: Path to input image
        output_path: Path to save annotated image
        rmsd: RMSD value (float)
        ref_name: Reference structure name (str)
        query_name: Query structure name (str)
        ref_sequence: Reference sequence (str, optional)
        query_sequence: Query sequence (str, optional)
    """
    wants_jpeg = output_path.lower().endswith(('.jpg', '.jpeg'))
    if wants_jpeg:
        output_format = 'JPEG'
    else:
        output_format = 'PNG'

    buffer = annotate_alignment_image(
        input_path,
        rmsd,
        ref_name,
        query_name,
        ref_sequence=ref_sequence,
        query_sequence=query_sequence,
        output_format=output_format,
    )

    # Persist the in-memory image bytes to the requested location.
    with open(output_path, 'wb') as out_file:
        out_file.write(buffer.getvalue())
192
+
193
+
194
+ # Example usage and testing
195
if __name__ == "__main__":
    # Manual smoke test: build a synthetic picture, annotate it, write it out.
    from PIL import Image, ImageDraw

    # Blank white canvas standing in for a structure rendering
    canvas = Image.new('RGB', (800, 600), color='white')
    pen = ImageDraw.Draw(canvas)

    # Two overlapping ellipses imitate a reference/query superposition
    shapes = (
        ([300, 200, 500, 400], 'lightblue', 'blue'),
        ([320, 220, 480, 380], 'lightcoral', 'red'),
    )
    for bounds, fill_color, edge_color in shapes:
        pen.ellipse(bounds, fill=fill_color, outline=edge_color, width=3)

    # Run the annotator with sample metadata (default PNG output)
    result = annotate_alignment_image(
        canvas,
        rmsd=1.234,
        ref_name="6TNA_reference",
        query_name="1EHZ_query",
        ref_sequence="GCGGAU",
        query_sequence="GCGGAU",
    )

    # Write the annotated bytes to the demo location
    with open('/tmp/test_annotated.png', 'wb') as handle:
        handle.write(result.getvalue())

    print("Test image created: /tmp/test_annotated.png")