Spaces:
Running
Running
| import gradio as gr | |
| from slices.core import SLICES | |
| from pymatgen.core.structure import Structure | |
| from pymatgen.io.cif import CifWriter | |
| from pymatgen.io.ase import AseAtomsAdaptor | |
| from ase.io import write as ase_write | |
| import tempfile | |
| from pymatgen.symmetry.analyzer import SpacegroupAnalyzer | |
| import os | |
# Module-wide SLICES backend shared by the conversion and augmentation
# handlers below.
# NOTE(review): relax_model/fmax/steps presumably configure the CHGNet
# relaxation performed inside SLICES2structure — confirm against slices.core.
backend = SLICES(
    relax_model="chgnet",
    fmax=0.4,
    steps=25,
)
def wrap_structure(structure):
    """Wrap all atoms back into the unit cell, in place.

    Each site's fractional coordinates are reduced modulo 1 so they lie in
    [0, 1). Species and per-site properties are carried over unchanged.

    Args:
        structure: A pymatgen ``Structure``; it is modified in place.

    Returns:
        The same ``structure`` object, for call chaining.
    """
    for i, site in enumerate(structure):
        frac_coords = site.frac_coords % 1.0
        # Floating-point modulo of a tiny negative value rounds up to exactly
        # 1.0 (e.g. -1e-20 % 1.0 == 1.0); fold those back to 0.0 so the
        # result really is in [0, 1).
        frac_coords[frac_coords >= 1.0] = 0.0
        # Structure.replace() drops site properties unless they are passed
        # explicitly, so forward them to keep them attached to the site.
        structure.replace(
            i,
            species=site.species,
            coords=frac_coords,
            coords_are_cartesian=False,
            properties=site.properties,
        )
    return structure
def get_primitive_structure(structure):
    """Return the standardized primitive cell of *structure*.

    Symmetry detection is delegated to pymatgen's ``SpacegroupAnalyzer``
    with its default tolerances (no ``symprec`` is passed here).
    """
    return SpacegroupAnalyzer(structure).get_primitive_standard_structure()
def visualize_structure(structure):
    """Render *structure* to a PNG file and return that file's path.

    The structure is converted to ASE ``Atoms`` and drawn by ASE's PNG writer
    with a fixed ``10x,10y,10z`` viewing rotation. The temporary file is
    created with ``delete=False`` so it outlives this call; nothing here
    removes it afterwards.
    """
    ase_atoms = AseAtomsAdaptor.get_atoms(structure)
    # Reserve a unique .png path, then release our handle so the ASE writer
    # is the only one holding the file when it writes by name.
    png_handle = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    png_path = png_handle.name
    png_handle.close()
    ase_write(png_path, ase_atoms, format='png', rotation='10x,10y,10z')
    return png_path
def process_structure(structure):
    """Normalize a structure for SLICES: wrap sites into the cell, then reduce it to the primitive cell.

    Note that the wrapping step mutates *structure* in place before the
    primitive cell is derived from it.
    """
    return get_primitive_structure(wrap_structure(structure))
def cif_to_slices(cif_file):
    """Convert a CIF file into a SLICES string plus a preview image.

    Args:
        cif_file: The value delivered by a ``gr.File`` component. Depending
            on the Gradio version this is either a file-path string or a
            tempfile-like object exposing ``.name``; ``None`` when no file
            has been uploaded.

    Returns:
        A 5-tuple matching the click handler's outputs:
        ``(slices_output, cif_image, conversion_status, slices_input,
        aug_slices_input)``. On success the SLICES string is also copied into
        both input boxes so it can be round-tripped or augmented. On failure
        the error text goes to the status slot and the text slots are
        cleared.
    """
    if cif_file is None:
        return "", None, "Conversion failed. Error: no CIF file provided; please upload one.", "", ""
    # Gradio 4+ passes a plain path string; older versions pass a tempfile
    # wrapper whose .name is the path.
    cif_path = cif_file if isinstance(cif_file, (str, os.PathLike)) else cif_file.name
    try:
        structure = Structure.from_file(cif_path)
        structure = process_structure(structure)
        slices_string = backend.structure2SLICES(structure)
        image_file = visualize_structure(structure)
    except Exception as e:
        # Report in the status box (as slices_to_cif does) rather than in the
        # SLICES output box, where it would look like a SLICES string.
        return "", None, f"Conversion failed. Error: {e}", "", ""
    return slices_string, image_file, "Conversion successful.", slices_string, slices_string
def slices_to_cif(slices_string):
    """Reconstruct a crystal structure from a SLICES string and save it as CIF.

    Args:
        slices_string: SLICES text from the input box. ``None`` or blank
            input is rejected up front instead of being sent to the backend.

    Returns:
        ``(cif_path, image_path, status_message)``. Both paths are ``None``
        on failure. The reported energy is the second value returned by
        ``backend.SLICES2structure``.
    """
    if not slices_string or not slices_string.strip():
        return None, None, "Conversion failed. Error: please enter a SLICES string."
    try:
        structure, energy = backend.SLICES2structure(slices_string)
        structure = process_structure(structure)
        # Only reserve a unique .cif path here; the handle is closed before
        # CifWriter reopens the file by name (reopening a still-open
        # temporary file is not portable across platforms).
        with tempfile.NamedTemporaryFile(mode='w', suffix='.cif', delete=False) as temp_file:
            cif_path = temp_file.name
        CifWriter(structure).write_file(cif_path)
        image_file = visualize_structure(structure)
        return cif_path, image_file, f"Conversion successful. Energy: {energy:.4f} eV/atom"
    except Exception as e:
        return None, None, f"Conversion failed. Error: {str(e)}"
def augment_and_canonicalize_slices(slices_string, num_augmentations):
    """Generate atom-order augmentations of a SLICES string and their canonical forms.

    Args:
        slices_string: SLICES text to augment.
        num_augmentations: Requested number of augmentations. A Gradio slider
            may deliver this as a float, so it is coerced to ``int`` before
            being passed to the backend.

    Returns:
        Always a 2-tuple of strings, one per Gradio output text box:
        ``(augmented_text, canonical_text)``, each newline-separated.
        ``augmented_text`` lists every string the backend returned
        (duplicates included). ``canonical_text`` lists the distinct
        canonical SLICES in first-seen order. On failure the error message
        is in the first slot and the second is empty.
    """
    if not slices_string or not slices_string.strip():
        return "Augmentation failed. Error: please enter a SLICES string.", ""
    try:
        augmented_slices = backend.SLICES2SLICESAug_atom_order(
            slices_string, num=int(num_augmentations)
        )
        # dict.fromkeys de-duplicates while keeping a stable, first-seen
        # order; set() would reorder results between runs.
        unique_augmented_slices = list(dict.fromkeys(augmented_slices))
        canonical_slices = list(dict.fromkeys(
            backend.get_canonical_SLICES(s) for s in unique_augmented_slices
        ))
        return "\n".join(augmented_slices), "\n".join(canonical_slices)
    except Exception as e:
        # Exactly two values, matching the two Gradio outputs this feeds.
        return f"Augmentation failed. Error: {e}", ""
# Gradio interface
# Radio labels for the CIF source. Defined once so the widget and the
# handlers that compare against them cannot drift apart.
EXAMPLE_CIF_CHOICE = "Use example CIF (NdSiRu.cif)"
CUSTOM_CIF_CHOICE = "Upload custom CIF"

with gr.Blocks() as iface:
    gr.Markdown("# Crystal Structure and SLICES Converter", elem_classes=["center"])
    with gr.Row(elem_classes=["center"]):
        gr.Image("1.png", label="SLICES Representation", show_label=False, width=600, height=250)
    gr.Markdown(
        "SLICES provides a text-based encoding of crystal structures, allowing for efficient manipulation and generation of new materials.",
        elem_classes=["center"],
    )

    with gr.Tab("CIF-SLICES Conversion"):
        with gr.Row():
            with gr.Column():
                file_choice = gr.Radio(
                    [EXAMPLE_CIF_CHOICE, CUSTOM_CIF_CHOICE],
                    label="Choose CIF source",
                    value=EXAMPLE_CIF_CHOICE,
                )
                # Initial visibility mirrors update_file_visibility() for the
                # default choice, so the example file is shown from the start
                # instead of only appearing after the radio is toggled.
                example_file = gr.File(value="NdSiRu.cif", visible=True, interactive=False)
                custom_file = gr.File(label="Upload CIF file", file_types=[".cif"], visible=False)
                convert_cif_button = gr.Button("Convert CIF to SLICES")
                slices_input = gr.Textbox(label="Enter SLICES String")
                convert_slices_button = gr.Button("Convert SLICES to CIF")
            with gr.Column():
                slices_output = gr.Textbox(label="SLICES String")
                cif_output = gr.File(label="Download CIF", file_types=[".cif"])
                conversion_status = gr.Textbox(label="Conversion Status")
        with gr.Row():
            cif_image = gr.Image(label="Original Structure")
            slices_image = gr.Image(label="Converted Structure")

    with gr.Tab("SLICES Augmentation and Canonicalization"):
        aug_slices_input = gr.Textbox(label="Enter SLICES String")
        num_augmentations = gr.Slider(minimum=1, maximum=50, step=1, value=10, label="Number of Augmentations")
        augment_button = gr.Button("Augment and Canonicalize")
        aug_slices_output = gr.Textbox(label="Augmented SLICES Strings")
        canon_slices_output = gr.Textbox(label="Canonical SLICES Strings")

    # Event handlers (registered inside the Blocks context)
    def update_file_visibility(choice):
        """Show only the file widget that matches the selected CIF source."""
        return (
            gr.update(visible=choice == EXAMPLE_CIF_CHOICE),
            gr.update(visible=choice == CUSTOM_CIF_CHOICE),
        )

    file_choice.change(
        update_file_visibility,
        inputs=[file_choice],
        outputs=[example_file, custom_file],
    )

    def get_active_file(choice, example, custom):
        """Return the file value belonging to the selected CIF source."""
        return example if choice == EXAMPLE_CIF_CHOICE else custom

    def convert_selected_cif(choice, example, custom):
        """Run cif_to_slices on whichever file the radio currently selects."""
        return cif_to_slices(get_active_file(choice, example, custom))

    convert_cif_button.click(
        convert_selected_cif,
        inputs=[file_choice, example_file, custom_file],
        # cif_to_slices also seeds both SLICES input boxes with its result.
        outputs=[slices_output, cif_image, conversion_status, slices_input, aug_slices_input],
    )
    convert_slices_button.click(
        slices_to_cif,
        inputs=[slices_input],
        outputs=[cif_output, slices_image, conversion_status],
    )
    augment_button.click(
        augment_and_canonicalize_slices,
        inputs=[aug_slices_input, num_augmentations],
        outputs=[aug_slices_output, canon_slices_output],
    )

# Launch only when run as a script, so importing this module (e.g. from tests
# or another app) does not start a server as a side effect.
if __name__ == "__main__":
    iface.launch(share=True)