CrystalReasoner: Reasoning and RL for Property-Conditioned Crystal Structure Generation

CrystalReasoner (CrysReas) is an end-to-end LLM framework for generating crystal structures from natural language instructions. It combines three components: supervised fine-tuning (SFT) to teach crystal-structure generation, thinking traces that introduce crystallographic and physical priors before the coordinates are generated, and reinforcement learning (RL) with verifiable rewards to improve validity, stability, and property conditioning.

Qwen2.5-3B-CrysReas-CrystalTextLLM

Quick Start

You can use this model directly with the transformers library:

from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import torch

model_id = "CrystalReasoner/Qwen2.5-3B-CrysReas-CrystalTextLLM"

tokenizer = AutoTokenizer.from_pretrained(model_id)
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    config=config,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True
)

messages = [
    {"role": "user", "content": "Below is a description of a bulk material. The chemical formula is NaCl. The bulk_modulus is about 100 GPa. Generate a description of the lengths and angles of the lattice vectors and then the element type and coordinates for each atom within the lattice:"},
]

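# Render the chat messages into the model's prompt format.
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
model_inputs = tokenizer(text, return_tensors="pt").to(model.device)

# Pass input_ids and attention_mask together so generate() does not have to infer the mask.
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=2048,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
    use_cache=True,
)
# The decoded text contains the prompt followed by the generated response.
generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)[0]
print(generated_text)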
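
The prompt in the example follows a fixed pattern: an opening sentence, the chemical formula, one sentence per target property, and the generation instruction. The helper below is a minimal sketch that reproduces that pattern for other inputs. `build_prompt` is a hypothetical name, not part of the released code, and only `bulk_modulus` appears in the example above; whether the model accepts other property names or several properties at once is an assumption.

```python
from typing import Dict, Optional


def build_prompt(formula: str, properties: Optional[Dict[str, str]] = None) -> str:
    """Assemble a prompt shaped like the NaCl example.

    `properties` maps a property name to a target string, for example
    {"bulk_modulus": "100 GPa"}. Property names other than `bulk_modulus`
    are assumptions, not documented options.
    """
    parts = [
        "Below is a description of a bulk material.",
        f"The chemical formula is {formula}.",
    ]
    # One sentence per requested property, in insertion order.
    for name, target in (properties or {}).items():
        parts.append(f"The {name} is about {target}.")
    parts.append(
        "Generate a description of the lengths and angles of the lattice "
        "vectors and then the element type and coordinates for each atom "
        "within the lattice:"
    )
    return " ".join(parts)


messages = [
    {"role": "user", "content": build_prompt("NaCl", {"bulk_modulus": "100 GPa"})},
]
```

With these arguments the helper returns the same string as the hand-written prompt, so `messages` can be passed to `tokenizer.apply_chat_template` unchanged.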

To convert the generated text into a pymatgen Structure, run the following helper after the generation step above:

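def get_structure(generated_text: str):
    """Parse the model output into a pymatgen Structure."""
    import re
    from pymatgen.core import Lattice, Structure

    # Prefer the contents of a <CIF>...</CIF> block when one is present.
    cif_match = re.search(r'<CIF>(.*?)</CIF>', generated_text, re.DOTALL)
    if cif_match:
        generated_text = cif_match.group(1)

    # Keep non-empty lines; drop one leading non-numeric line, if any,
    # so that lines[0] is expected to hold the lattice lengths.
    lines = [line.strip() for line in generated_text.strip().split('\n') if line.strip()]
    if lines and not re.match(r'^[-+0-9.eE\s]+$', lines[0]):
        lines = lines[1:]

    # First line: lengths a, b, c. Second line: angles alpha, beta, gamma in degrees.
    lengths = list(map(float, lines[0].split()))
    angles = list(map(float, lines[1].split()))
    lattice = Lattice.from_parameters(*lengths, *angles)

    # Remaining lines: element symbol in field 0, coordinates in fields 2-4
    # (field 1 is not used). Structure treats them as fractional by default.
    species = []
    coords = []
    for line in lines[2:]:
        parts = line.split()
        species.append(parts[0])
        coords.append([float(parts[2]), float(parts[3]), float(parts[4])])

    structure = Structure(lattice, species, coords)
    return structure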

structure = get_structure(generated_text)
print(structure)
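
Generated lattice parameters are not guaranteed to describe a physical cell, so a quick check before building the Structure can be useful. The sketch below uses the standard unit-cell volume formula in pure Python; `cell_volume` is a hypothetical helper and is not the validity reward used to train CrystalReasoner, which this card does not specify.

```python
import math


def cell_volume(a, b, c, alpha, beta, gamma):
    """Unit-cell volume from lattice lengths and angles (angles in degrees).

    Raises ValueError when the parameters cannot form a 3D cell.
    The result is in the cube of whatever length unit a, b, c use.
    """
    if min(a, b, c) <= 0:
        raise ValueError("lattice lengths must be positive")
    ca, cb, cg = (math.cos(math.radians(x)) for x in (alpha, beta, gamma))
    # Normalized Gram determinant of the lattice vectors; it must be
    # strictly positive for the three vectors to span a volume.
    gram = 1.0 - ca * ca - cb * cb - cg * cg + 2.0 * ca * cb * cg
    if gram <= 0:
        raise ValueError("lattice angles do not describe a valid cell")
    return a * b * c * math.sqrt(gram)
```

Calling `cell_volume(*lengths, *angles)` on the two parsed lattice lines, before `Lattice.from_parameters`, rejects degenerate cells with a clear error message.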

Model size: 3B params
Tensor type: BF16 (Safetensors)
Base model: Qwen/Qwen2.5-3B