---
language:
- en
base_model:
- microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext
---

# LitGene: An Interpretable Transformer Model for Gene Representation Learning

LitGene is a transformer-based model that learns rich gene representations by integrating textual information from the scientific literature with structured knowledge from the Gene Ontology (GO). Using contrastive learning, the model refines gene embeddings that capture both sequence and functional annotations, enabling improved prediction of protein properties, gene-disease associations, and functional annotations such as GO terms and KEGG pathways.

This repository provides the weights of the pre-trained LitGene model. It is intended to serve as a base representation model that can be further adapted or fine-tuned for specific biomedical tasks.

## Intended Usage

This model is intended for tasks that require learned representations of genes. LitGene can be used for any of the following:

- Inference: providing predictions for gene functions, gene-disease and gene-protein associations, and specific biological pathway information. LitGene can be prompted through the web interface [here](http://64.106.39.56:5000/).
- Gene embeddings: producing embeddings that capture both textual (literature-based) information and specific biological properties of gene function. See https://github.com/vinash85/LitGene/tree/master.
- Fine-tuning: the base representation model can be fine-tuned for a multitude of biomedical tasks (e.g. protein solubility prediction, drug dosage sensitivity). Example tasks can be found in this [repo](https://github.com/vinash85/LitGene/tree/master).
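
Once gene embeddings have been computed, a common downstream use is ranking genes by similarity to a query gene. The sketch below assumes the embeddings (for instance, the CLS token vectors produced by the PyTorch usage code below) are already stored in a plain dictionary. Cosine similarity is an assumption here, since this card does not prescribe a similarity metric, and the gene names `GENE_A`, `GENE_B`, `GENE_C` and their 3-dimensional vectors are hypothetical placeholders (the model's actual embedding dimension is 768).

```python
import math
from typing import Dict, List, Sequence, Tuple


def cosine_similarity(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def nearest_genes(
    query: Sequence[float],
    embeddings: Dict[str, Sequence[float]],
    k: int = 3,
) -> List[Tuple[str, float]]:
    """Return the k (gene, score) pairs most similar to `query`, best first."""
    scored = [(gene, cosine_similarity(query, vec)) for gene, vec in embeddings.items()]
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:k]


# Hypothetical toy embeddings; real LitGene vectors are 768-dimensional.
toy_embeddings = {
    "GENE_A": [1.0, 0.0, 0.0],
    "GENE_B": [0.9, 0.1, 0.0],
    "GENE_C": [0.0, 1.0, 0.0],
}
result = nearest_genes([1.0, 0.0, 0.0], toy_embeddings, k=2)
print(result)
```

With real embeddings, the query vector would typically be the embedding of a gene of interest, and the dictionary would hold embeddings for a gene panel.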

## Usage (PyTorch)

Below is example PyTorch code to load the LitGene weights:

```python
import torch
from transformers import AutoModel, AutoTokenizer

# Load the model and tokenizer
model_name = "tumorailab/LitGene_ContrastiveLearning"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# Move the model to GPU if one is available
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model.eval()  # disable dropout for inference
```

Below is example code to get an embedding for an example sentence:

```python
# Prepare your sentence
sentence = "Your text goes here"

# Tokenize the sentence
inputs = tokenizer(
    sentence,
    padding=True,
    truncation=True,
    max_length=512,
    return_tensors="pt",
)

# Move inputs to the same device as the model
inputs = {k: v.to(device) for k, v in inputs.items()}

# Run the model without tracking gradients
with torch.no_grad():
    outputs = model(**inputs)

# Take the CLS token embedding (first token); shape: (batch_size, hidden_size)
cls_embedding = outputs.last_hidden_state[:, 0, :]
print(cls_embedding.shape)
```

## Training Details

##### Hyperparameters

| Hyperparameter | Value |
| --- | --- |
| Embedding Dimension | 768 |
| Batch Size | 64 |
| Optimizer | AdamW |
| Learning Rate | 2e-5 (with linear decay) |
| Weight Decay | 0.01 |
| Contrastive Learning Loss Function | Margin-based ranking loss |
| Contrastive Loss Margin (δ) | 0.5 |
| Number of Training Steps | 100k |
| Dropout Rate | 0.1 |
| Gradient Clipping | 1.0 |
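
The margin-based ranking loss with margin δ = 0.5 can be illustrated with a minimal sketch. It assumes each training example is an (anchor, positive, negative) triplet of embeddings and that similarity is measured with cosine similarity; neither detail is specified in this card, so treat the code as an illustration of the general technique, not LitGene's training code. The per-triplet loss is max(0, δ − s(a, p) + s(a, n)), which is zero once the positive outscores the negative by at least δ.

```python
import math
from typing import Iterable, Sequence, Tuple

MARGIN = 0.5  # contrastive loss margin (delta) from the hyperparameter table

Triplet = Tuple[Sequence[float], Sequence[float], Sequence[float]]


def cosine_similarity(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def margin_ranking_loss(
    anchor: Sequence[float],
    positive: Sequence[float],
    negative: Sequence[float],
    margin: float = MARGIN,
) -> float:
    """Hinge loss that pushes s(anchor, positive) above s(anchor, negative) by `margin`.

    Assumption: cosine similarity and triplet sampling; not specified by the model card.
    """
    s_pos = cosine_similarity(anchor, positive)
    s_neg = cosine_similarity(anchor, negative)
    return max(0.0, margin - s_pos + s_neg)


def batch_loss(triplets: Iterable[Triplet], margin: float = MARGIN) -> float:
    """Mean margin ranking loss over a batch of triplets."""
    losses = [margin_ranking_loss(a, p, n, margin) for a, p, n in triplets]
    return sum(losses) / len(losses)


# Toy 2-D vectors: the first triplet is already well separated, the second is inverted.
easy = ([1.0, 0.0], [1.0, 0.0], [0.0, 1.0])
hard = ([1.0, 0.0], [0.0, 1.0], [1.0, 0.0])
print(margin_ranking_loss(*easy), margin_ranking_loss(*hard), batch_loss([easy, hard]))
```

In PyTorch, the same per-triplet quantity can be computed from tensors of similarity scores with `torch.nn.functional.margin_ranking_loss(s_pos, s_neg, torch.ones_like(s_pos), margin=0.5)`, which evaluates max(0, −y·(x1 − x2) + margin) and averages over the batch by default.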