---
license: apache-2.0
datasets:
- yamboo/Griffin_datasets_joint_v65
- yamboo/Griffin_datasets_single_pretrain_v3
metrics:
- accuracy
- roc_auc
- mse
pipeline_tag: graph-ml
---
# Griffin: Pretrained Checkpoints
This repository contains pretrained checkpoints for the [Griffin model](https://github.com/yanxwb/Griffin). The paper is available on [arXiv](https://arxiv.org/abs/2505.05568).
## Checkpoints
The checkpoints are organized as follows:
```bash
./checkpoints/
├── single-completion   # Pretrained single-table completion model.
├── single-sft          # Pretrained single-table SFT model. Used in the main experiments.
└── transfer            # Pretrained transfer models. Used in the transfer experiments.
    ├── commerce-1      # Split name.
    │   ├── FULL        # RDB-SFT setting name. Used in the main transfer experiments.
    │   ├── MIXED       # RDB-SFT setting name. Used in the RDB-SFT ablation.
    │   └── LIMITED     # RDB-SFT setting name. Used in the RDB-SFT ablation.
    ├── commerce-2      # Same structure as commerce-1.
    │   ├── FULL
    │   ├── MIXED
    │   └── LIMITED
    ├── others-1
    │   ├── FULL
    │   ├── MIXED
    │   └── LIMITED
    └── others-2
        ├── FULL
        ├── MIXED
        └── LIMITED
```
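Because the transfer checkpoints are nested by split and RDB-SFT setting, it can help to build the repository paths programmatically. The sketch below is a hypothetical helper (`checkpoint_files` is not part of the Griffin code). It assumes that every leaf folder holds a `model.safetensors` and a `config.json`, and that paths are relative to the repository root without a `checkpoints/` prefix, mirroring the `single-sft` loading example further down. Check the repository's file listing before relying on these names.

```python
from typing import Optional, Tuple

# Names taken from the checkpoint tree above.
SINGLE_TABLE_MODELS = ("single-completion", "single-sft")
TRANSFER_SPLITS = ("commerce-1", "commerce-2", "others-1", "others-2")
RDB_SFT_SETTINGS = ("FULL", "MIXED", "LIMITED")


def checkpoint_files(model: str, split: Optional[str] = None,
                     setting: Optional[str] = None) -> Tuple[str, str]:
    """Return (weights, config) repo paths for one checkpoint.

    Hypothetical helper: assumes each leaf folder contains
    `model.safetensors` and `config.json`, as in the single-sft example.
    """
    if model in SINGLE_TABLE_MODELS:
        if split is not None or setting is not None:
            raise ValueError("single-table checkpoints take no split/setting")
        folder = model
    elif model == "transfer":
        if split not in TRANSFER_SPLITS:
            raise ValueError(f"unknown split: {split!r}")
        if setting not in RDB_SFT_SETTINGS:
            raise ValueError(f"unknown RDB-SFT setting: {setting!r}")
        folder = f"transfer/{split}/{setting}"
    else:
        raise ValueError(f"unknown model: {model!r}")
    return f"{folder}/model.safetensors", f"{folder}/config.json"


# The main transfer experiments use the FULL setting, e.g. on commerce-1.
weights_file, config_file = checkpoint_files("transfer", split="commerce-1", setting="FULL")
print(weights_file)  # transfer/commerce-1/FULL/model.safetensors
print(config_file)   # transfer/commerce-1/FULL/config.json
```

The returned strings can be passed as `filename` to `hf_hub_download`. Validating the split and setting names up front turns a typo into a clear error instead of a failed download.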
## How to use
To get started, you need the model architecture defined in your code; it is provided in the [GitHub repository](https://github.com/yanxwb/Griffin). You can then use the `huggingface_hub` library to download a specific checkpoint and load its weights.
```python
import json

import accelerate
from huggingface_hub import hf_hub_download

# Assume `GriffinMod` is the model class defined in the Griffin GitHub repo.
# Adjust this import to wherever the class lives in your project:
# from your_project_position.hmodel import GriffinMod

# 1. Define the repository ID and the specific files you want to load
repo_id = "yamboo/Griffin_models"
# Example: loading the main single-table SFT model
checkpoint_path = "single-sft/model.safetensors"
config_path = "single-sft/config.json"

# 2. Download the checkpoint and its config from the Hub (cached locally)
model_weights_path = hf_hub_download(repo_id=repo_id, filename=checkpoint_path)
model_config_path = hf_hub_download(repo_id=repo_id, filename=config_path)

# Read the downloaded config, not a "config.json" in the working directory
with open(model_config_path, "r") as f:
    config = json.load(f)

# 3. Instantiate the model and load the weights. We use accelerate to align
#    with the experiment pipeline in the GitHub repo.
model = GriffinMod(**config)  # pass any other required arguments here
accelerate.load_checkpoint_in_model(model, model_weights_path)
```