Text Generation
PEFT
lora
trl
naming
brand-generation
controllable-generation
nomen-ai / scripts /update_adapter_card.py
krystv's picture
Add adapter model card updater
77aed1a verified
"""Update SFT/DPO adapter model cards after training.
Usage:
python scripts/update_adapter_card.py --repo krystv/nomen-ai-sft-lora --phase sft
python scripts/update_adapter_card.py --repo krystv/nomen-ai-dpo-lora --phase dpo
"""
import argparse
from huggingface_hub import HfApi
# Markdown model card written to README.md of an adapter repo.
# Filled with str.format() using exactly these placeholders:
#   {repo}, {phase}, {phase_upper}, {dataset}
# The text contains no other braces; any literal brace added later must be
# doubled ({{ / }}) or str.format() will raise.
# The YAML front matter must begin on the very first line so the Hub parses it
# as model-card metadata. `library_name: peft` marks the repo as a PEFT adapter.
CARD_TEMPLATE = """---
license: apache-2.0
library_name: peft
base_model: Qwen/Qwen2.5-1.5B-Instruct
datasets:
- {dataset}
tags:
- peft
- lora
- trl
- nomen-ai
- {phase}
pipeline_tag: text-generation
---

# Nomen-AI {phase_upper} LoRA Adapter

This adapter is part of the Nomen-AI controllable cross-lingual morpho-phonetic name synthesis pipeline.

- Main repo: https://huggingface.co/krystv/nomen-ai
- Dataset: https://huggingface.co/datasets/{dataset}
- Base model: https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct

## Usage

```python
from nomen_ai.control import ControlVector
from nomen_ai.inference import NomenAI

engine = NomenAI("{repo}", base_model="Qwen/Qwen2.5-1.5B-Instruct")
cv = ControlVector(roots=["japanese", "nordic"], blend=[40, 60], theme="gaming", syllables=3, char_len=8, creativity=0.8)
print(engine.generate(cv, n=10))
```

## Training recipe

See `train_config.yaml` in the main repo:
https://huggingface.co/krystv/nomen-ai/blob/main/train_config.yaml

## Validation

Run:

```bash
python scripts/check_artifacts.py
python scripts/evaluate.py --model_id {repo} --base_model Qwen/Qwen2.5-1.5B-Instruct
```
"""
def main(argv=None):
    """Render the adapter model card and upload it as README.md.

    Parses command-line options, fills CARD_TEMPLATE for the chosen training
    phase and commits the result to the root of the target model repo,
    replacing any existing README.md.

    Args:
        argv: Optional list of command-line arguments (without the program
            name). ``None`` makes argparse read ``sys.argv[1:]``, so shell
            usage is unchanged; passing a list lets other Python code drive
            the script.
    """
    # Dataset repo referenced in the card for each training phase, unless
    # overridden with --dataset.
    default_datasets = {
        'sft': 'krystv/nomen-ai-sft',
        'dpo': 'krystv/nomen-ai-dpo',
    }

    ap = argparse.ArgumentParser(
        description='Update the SFT/DPO adapter model card on the Hugging Face Hub.',
    )
    ap.add_argument('--repo', required=True,
                    help='Target model repo id, e.g. krystv/nomen-ai-sft-lora.')
    ap.add_argument('--phase', choices=sorted(default_datasets), required=True,
                    help='Training phase the adapter was produced by.')
    ap.add_argument('--dataset', default=None,
                    help='Dataset repo id to reference in the card '
                         '(default: krystv/nomen-ai-<phase>).')
    ap.add_argument('--dry-run', action='store_true',
                    help='Print the rendered card instead of uploading it.')
    args = ap.parse_args(argv)

    dataset = args.dataset or default_datasets[args.phase]
    card = CARD_TEMPLATE.format(
        repo=args.repo,
        phase=args.phase,
        phase_upper=args.phase.upper(),
        dataset=dataset,
    )

    if args.dry_run:
        # Preview mode: no Hub access, no commit.
        print(card)
        return

    # Uploading in-memory bytes avoids writing a temporary README to disk.
    HfApi().upload_file(
        repo_id=args.repo,
        repo_type='model',
        path_in_repo='README.md',
        path_or_fileobj=card.encode('utf-8'),
        commit_message=f'Update {args.phase} adapter model card',
    )
    print(f'Updated README.md for {args.repo}')
# Run the CLI only when executed as a script, not when the module is imported.
if __name__ == '__main__':
    main()