---
license: mit
library_name: transformers
tags:
- biology
- genomics
- plant-science
- mamba
- caduceus
- moe
- genome-annotation
- sequence-segmentation
---

# PlantGenoANN

**PlantGenoANN** is a plant genomic segmentation model that predicts a range of plant genomic elements at single-nucleotide resolution. The model is built on the **PlantBiMoE** architecture with a 1D U-Net segmentation head and is designed for automated plant genome annotation. It predicts gene structures, including genes, CDSs, and exons, on both the forward and reverse strands.

PlantGenoANN can also serve as a **long-context plant genomic foundation model** (up to 49,152 bp) that can be fine-tuned to predict plant omic signal tracks such as RNA-seq or ATAC-seq.

Developed by: **hu-lab**

## Model Sources

* **Repository:** [PlantGenoANN](https://huggingface.co/qzzhang/PlantGenoANN)
* **GitHub:** https://github.com/qzzhang0131/PlantGenoANN

## How to use

The model requires the `mamba-ssm` and `causal-conv1d` libraries for its backbone. The following snippet retrieves both the genomic feature probabilities and the sequence embeddings:

```python
import torch
from transformers import AutoTokenizer, AutoModel

# Load the model and tokenizer (custom model code is fetched from the Hub)
repo_id = "qzzhang/PlantGenoANN"
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)

# The number of DNA tokens (excluding the [CLS] and [SEP] tokens) must be divisible by 8,
# as required by the U-Net downsampling blocks.
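# Optional sketch, not part of the PlantGenoANN API: `crop_to_multiple_of_8` is a
# hypothetical helper that trims trailing bases so each raw sequence length is a
# multiple of 8. It assumes the tokenizer emits exactly one token per nucleotide
# (consistent with single-nucleotide resolution, but not confirmed here).
def crop_to_multiple_of_8(seq: str) -> str:
    """Return `seq` with its last len(seq) % 8 bases removed.

    Example: crop_to_multiple_of_8("ACGTACGTACG") returns "ACGTACGT" (11 bp -> 8 bp).
    """
    return seq[: len(seq) - len(seq) % 8]

# The two example sequences below are 16 bp each, so they already satisfy the constraint.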
sequences = ["ACTAGAGCGAGAGAAA", "TTTGAGAGCGCGCGGA"]

# Tokenize
tokenized_sequences = tokenizer(
    sequences,
    return_tensors="pt",
    padding="longest",
)["input_ids"]

# Run inference on the GPU
model.to("cuda")
model.eval()
with torch.no_grad():
    outs = model(input_ids=tokenized_sequences.to("cuda"))

# Logits over the genomic features
# Shape: [batch, sequence_length, num_features]
logits = outs.logits

# Probabilities associated with CDS on the forward (+) strand
pos_strand_cds_probs = model.get_feature_logits(feature="CDS", strand="+", logtis=logits).detach()
print(f"CDS probabilities on the forward strand: {pos_strand_cds_probs}")

# Sequence embeddings
# Shape: [batch, sequence_length, 1024]
hidden_states = outs.hidden_states.detach()
print(f"Sequence embeddings shape: {hidden_states.shape}")
```

## Architecture

PlantGenoANN combines the **PlantBiMoE** encoder (a 116M-parameter foundation model) with a custom **U-Net** segmentation head.

## 🛠️ Training Procedure

PlantGenoANN was trained for **30 hours** on **4x NVIDIA A800-80G** GPUs, processing a total of **18B tokens**. Training used a high-quality dataset of **9 model plant genomes** and their annotations. The model was optimized with **AdamW** (learning rate 1e-4, weight decay 0.01) and a **cosine learning rate scheduler**.

## BibTeX entry and citation info

```