Spaces:
Running
Running
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +196 -0
src/streamlit_app.py
CHANGED
|
@@ -616,6 +616,202 @@ SAMPLE_STRUCTURES = {
|
|
| 616 |
"hBN Monolayer (4x4)": "hBN_monolayer_4x4_supercell.extxyz",
|
| 617 |
}
|
| 618 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 619 |
def get_structure_viz2(atoms_obj, style='stick', show_unit_cell=True, width=400, height=400):
|
| 620 |
xyz_str = ""
|
| 621 |
xyz_str += f"{len(atoms_obj)}\n"
|
|
|
|
| 616 |
"hBN Monolayer (4x4)": "hBN_monolayer_4x4_supercell.extxyz",
|
| 617 |
}
|
| 618 |
|
| 619 |
+
def get_trajectory_viz(trajectory, style='stick', show_unit_cell=True, width=400, height=400,
|
| 620 |
+
show_path=True, path_color='red', path_radius=0.02):
|
| 621 |
+
"""
|
| 622 |
+
Visualize optimization trajectory with multiple frames
|
| 623 |
+
|
| 624 |
+
Args:
|
| 625 |
+
trajectory: List of ASE atoms objects representing the optimization steps
|
| 626 |
+
style: Visualization style ('stick', 'ball', 'ball-stick')
|
| 627 |
+
show_unit_cell: Whether to show unit cell
|
| 628 |
+
show_path: Whether to show trajectory paths for each atom
|
| 629 |
+
path_color: Color of trajectory paths
|
| 630 |
+
path_radius: Radius of trajectory path cylinders
|
| 631 |
+
"""
|
| 632 |
+
if not trajectory:
|
| 633 |
+
return None
|
| 634 |
+
|
| 635 |
+
view = py3Dmol.view(width=width, height=height)
|
| 636 |
+
|
| 637 |
+
# Add all frames to the viewer
|
| 638 |
+
for frame_idx, atoms_obj in enumerate(trajectory):
|
| 639 |
+
xyz_str = ""
|
| 640 |
+
xyz_str += f"{len(atoms_obj)}\n"
|
| 641 |
+
xyz_str += f"Frame {frame_idx}\n"
|
| 642 |
+
for atom in atoms_obj:
|
| 643 |
+
xyz_str += f"{atom.symbol} {atom.position[0]:.6f} {atom.position[1]:.6f} {atom.position[2]:.6f}\n"
|
| 644 |
+
|
| 645 |
+
view.addModel(xyz_str, "xyz")
|
| 646 |
+
|
| 647 |
+
# Set style for all models
|
| 648 |
+
if style.lower() == 'ball-stick':
|
| 649 |
+
view.setStyle({'stick': {'radius': 0.1}, 'sphere': {'scale': 0.3}})
|
| 650 |
+
elif style.lower() == 'stick':
|
| 651 |
+
view.setStyle({'stick': {}})
|
| 652 |
+
elif style.lower() == 'ball':
|
| 653 |
+
view.setStyle({'sphere': {'scale': 0.4}})
|
| 654 |
+
else:
|
| 655 |
+
view.setStyle({'stick': {'radius': 0.15}})
|
| 656 |
+
|
| 657 |
+
# Add trajectory paths
|
| 658 |
+
if show_path and len(trajectory) > 1:
|
| 659 |
+
for atom_idx in range(len(trajectory[0])):
|
| 660 |
+
for frame_idx in range(len(trajectory) - 1):
|
| 661 |
+
start_pos = trajectory[frame_idx][atom_idx].position
|
| 662 |
+
end_pos = trajectory[frame_idx + 1][atom_idx].position
|
| 663 |
+
|
| 664 |
+
view.addCylinder({
|
| 665 |
+
'start': {'x': start_pos[0], 'y': start_pos[1], 'z': start_pos[2]},
|
| 666 |
+
'end': {'x': end_pos[0], 'y': end_pos[1], 'z': end_pos[2]},
|
| 667 |
+
'radius': path_radius,
|
| 668 |
+
'color': path_color,
|
| 669 |
+
'alpha': 0.5
|
| 670 |
+
})
|
| 671 |
+
|
| 672 |
+
# Add unit cell for the last frame
|
| 673 |
+
if show_unit_cell and trajectory[-1].pbc.any():
|
| 674 |
+
cell = trajectory[-1].get_cell()
|
| 675 |
+
origin = np.array([0.0, 0.0, 0.0])
|
| 676 |
+
if cell is not None and cell.any():
|
| 677 |
+
edges = [
|
| 678 |
+
(origin, cell[0]), (origin, cell[1]), (cell[0], cell[0] + cell[1]), (cell[1], cell[0] + cell[1]),
|
| 679 |
+
(cell[2], cell[2] + cell[0]), (cell[2], cell[2] + cell[1]),
|
| 680 |
+
(cell[2] + cell[0], cell[2] + cell[0] + cell[1]), (cell[2] + cell[1], cell[2] + cell[0] + cell[1]),
|
| 681 |
+
(origin, cell[2]), (cell[0], cell[0] + cell[2]), (cell[1], cell[1] + cell[2]),
|
| 682 |
+
(cell[0] + cell[1], cell[0] + cell[1] + cell[2])
|
| 683 |
+
]
|
| 684 |
+
for start, end in edges:
|
| 685 |
+
view.addCylinder({
|
| 686 |
+
'start': {'x': start[0], 'y': start[1], 'z': start[2]},
|
| 687 |
+
'end': {'x': end[0], 'y': end[1], 'z': end[2]},
|
| 688 |
+
'radius': 0.05, 'color': 'black', 'alpha': 0.7
|
| 689 |
+
})
|
| 690 |
+
|
| 691 |
+
view.zoomTo()
|
| 692 |
+
view.setBackgroundColor('white')
|
| 693 |
+
return view
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
def get_animated_trajectory_viz(trajectory, style='stick', show_unit_cell=True, width=400, height=400):
|
| 697 |
+
"""
|
| 698 |
+
Create an animated trajectory visualization
|
| 699 |
+
"""
|
| 700 |
+
if not trajectory:
|
| 701 |
+
return None
|
| 702 |
+
|
| 703 |
+
view = py3Dmol.view(width=width, height=height)
|
| 704 |
+
|
| 705 |
+
# Add all frames
|
| 706 |
+
for frame_idx, atoms_obj in enumerate(trajectory):
|
| 707 |
+
xyz_str = ""
|
| 708 |
+
xyz_str += f"{len(atoms_obj)}\n"
|
| 709 |
+
xyz_str += f"Frame {frame_idx}\n"
|
| 710 |
+
for atom in atoms_obj:
|
| 711 |
+
xyz_str += f"{atom.symbol} {atom.position[0]:.6f} {atom.position[1]:.6f} {atom.position[2]:.6f}\n"
|
| 712 |
+
|
| 713 |
+
view.addModel(xyz_str, "xyz")
|
| 714 |
+
|
| 715 |
+
# Set style
|
| 716 |
+
if style.lower() == 'ball-stick':
|
| 717 |
+
view.setStyle({'stick': {'radius': 0.1}, 'sphere': {'scale': 0.3}})
|
| 718 |
+
elif style.lower() == 'stick':
|
| 719 |
+
view.setStyle({'stick': {}})
|
| 720 |
+
elif style.lower() == 'ball':
|
| 721 |
+
view.setStyle({'sphere': {'scale': 0.4}})
|
| 722 |
+
else:
|
| 723 |
+
view.setStyle({'stick': {'radius': 0.15}})
|
| 724 |
+
|
| 725 |
+
# Add unit cell for last frame
|
| 726 |
+
if show_unit_cell and trajectory[-1].pbc.any():
|
| 727 |
+
cell = trajectory[-1].get_cell()
|
| 728 |
+
origin = np.array([0.0, 0.0, 0.0])
|
| 729 |
+
if cell is not None and cell.any():
|
| 730 |
+
edges = [
|
| 731 |
+
(origin, cell[0]), (origin, cell[1]), (cell[0], cell[0] + cell[1]), (cell[1], cell[0] + cell[1]),
|
| 732 |
+
(origin, cell[2]), (cell[0], cell[0] + cell[2]), (cell[1], cell[1] + cell[2]),
|
| 733 |
+
(cell[0] + cell[1], cell[0] + cell[1] + cell[2]),
|
| 734 |
+
(cell[2], cell[2] + cell[0]), (cell[2], cell[2] + cell[1]),
|
| 735 |
+
(cell[2] + cell[0], cell[2] + cell[0] + cell[1]), (cell[2] + cell[1], cell[2] + cell[0] + cell[1])
|
| 736 |
+
]
|
| 737 |
+
for start, end in edges:
|
| 738 |
+
view.addCylinder({
|
| 739 |
+
'start': {'x': start[0], 'y': start[1], 'z': start[2]},
|
| 740 |
+
'end': {'x': end[0], 'y': end[1], 'z': end[2]},
|
| 741 |
+
'radius': 0.05, 'color': 'black', 'alpha': 0.7
|
| 742 |
+
})
|
| 743 |
+
|
| 744 |
+
view.zoomTo()
|
| 745 |
+
view.setBackgroundColor('white')
|
| 746 |
+
|
| 747 |
+
# Enable animation
|
| 748 |
+
view.animate({'loop': 'forward', 'reps': 0, 'interval': 500})
|
| 749 |
+
|
| 750 |
+
return view
|
| 751 |
+
|
| 752 |
+
|
| 753 |
+
# Streamlit implementation example
|
| 754 |
+
def display_optimization_trajectory(trajectory, viz_style='stick'):
|
| 755 |
+
"""
|
| 756 |
+
Display optimization trajectory in Streamlit with controls
|
| 757 |
+
"""
|
| 758 |
+
if not trajectory:
|
| 759 |
+
st.error("No trajectory data available")
|
| 760 |
+
return
|
| 761 |
+
|
| 762 |
+
st.subheader(f"Optimization Trajectory ({len(trajectory)} steps)")
|
| 763 |
+
|
| 764 |
+
# Trajectory options
|
| 765 |
+
col1, col2 = st.columns(2)
|
| 766 |
+
|
| 767 |
+
with col1:
|
| 768 |
+
viz_mode = st.selectbox(
|
| 769 |
+
"Visualization Mode",
|
| 770 |
+
["Static with paths", "Animation", "Step-by-step"],
|
| 771 |
+
key="viz_mode"
|
| 772 |
+
)
|
| 773 |
+
|
| 774 |
+
with col2:
|
| 775 |
+
if viz_mode == "Static with paths":
|
| 776 |
+
show_paths = st.checkbox("Show trajectory paths", value=True)
|
| 777 |
+
path_color = st.selectbox("Path color", ["red", "blue", "green", "orange"], index=0)
|
| 778 |
+
elif viz_mode == "Step-by-step":
|
| 779 |
+
frame_idx = st.slider("Frame", 0, len(trajectory)-1, 0, key="frame_slider")
|
| 780 |
+
|
| 781 |
+
# Display visualization based on mode
|
| 782 |
+
if viz_mode == "Static with paths":
|
| 783 |
+
opt_view = get_trajectory_viz(
|
| 784 |
+
trajectory,
|
| 785 |
+
style=viz_style,
|
| 786 |
+
show_unit_cell=True,
|
| 787 |
+
width=400,
|
| 788 |
+
height=400,
|
| 789 |
+
show_path=show_paths,
|
| 790 |
+
path_color=path_color
|
| 791 |
+
)
|
| 792 |
+
st.components.v1.html(opt_view._make_html(), width=400, height=400)
|
| 793 |
+
|
| 794 |
+
elif viz_mode == "Animation":
|
| 795 |
+
opt_view = get_animated_trajectory_viz(
|
| 796 |
+
trajectory,
|
| 797 |
+
style=viz_style,
|
| 798 |
+
show_unit_cell=True,
|
| 799 |
+
width=400,
|
| 800 |
+
height=400
|
| 801 |
+
)
|
| 802 |
+
st.components.v1.html(opt_view._make_html(), width=400, height=400)
|
| 803 |
+
|
| 804 |
+
elif viz_mode == "Step-by-step":
|
| 805 |
+
opt_view = get_structure_viz2(
|
| 806 |
+
trajectory[frame_idx],
|
| 807 |
+
style=viz_style,
|
| 808 |
+
show_unit_cell=True,
|
| 809 |
+
width=400,
|
| 810 |
+
height=400
|
| 811 |
+
)
|
| 812 |
+
st.components.v1.html(opt_view._make_html(), width=400, height=400)
|
| 813 |
+
st.write(f"Step {frame_idx + 1} of {len(trajectory)}")
|
| 814 |
+
|
| 815 |
def get_structure_viz2(atoms_obj, style='stick', show_unit_cell=True, width=400, height=400):
|
| 816 |
xyz_str = ""
|
| 817 |
xyz_str += f"{len(atoms_obj)}\n"
|