Model Card for margin_reg_baseline_20260328_121542
This model is a reward model fine-tuned from meta-llama/Llama-3.1-8B-Instruct as a LoRA (PEFT) adapter. It has been trained using TRL.
Quick start
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from peft import PeftModel, PeftConfig
# -----------------------------
# 1. Define the PEFT model ID & Checkpoint (Epoch)
# -----------------------------
peft_model_id = "xxccho/margin_reg_baseline"
# [Optional] To load an intermediate checkpoint from a specific epoch, set the
# variable below. If left as None, the final trained model at the top of the
# repository is loaded.
#
# [ Checkpoints to Epochs Mapping ]
# Epoch 1 : "checkpoint-246"
# Epoch 2 : "checkpoint-492"
# Epoch 3 : "checkpoint-738"
# Epoch 4 : "checkpoint-984"
# Epoch 5 : "checkpoint-1230"
# Epoch 6 : "checkpoint-1476"
# Epoch 7 : "checkpoint-1722"
# Epoch 8 : "checkpoint-1968"
# Epoch 9 : "checkpoint-2214"
# Epoch 10 : "checkpoint-2460"
# Example: to use the epoch-5 checkpoint, change the line below as follows.
# checkpoint = "checkpoint-1230"
checkpoint = None  # None loads the last saved (final) model by default.
# 2. Load the PEFT config (use subfolder if a checkpoint is specified)
if checkpoint:
    config = PeftConfig.from_pretrained(peft_model_id, subfolder=checkpoint)
else:
    config = PeftConfig.from_pretrained(peft_model_id)
# 3. Load tokenizer from base model (safer)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
# Llama padding fix
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# 4. Load base model
base_model = AutoModelForSequenceClassification.from_pretrained(
    config.base_model_name_or_path,
    num_labels=1,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
# 5. Apply LoRA adapter (use subfolder if a checkpoint is specified)
if checkpoint:
    model = PeftModel.from_pretrained(base_model, peft_model_id, subfolder=checkpoint)
else:
    model = PeftModel.from_pretrained(base_model, peft_model_id)
model.config.pad_token_id = tokenizer.pad_token_id
model.eval()
# -----------------------------
# Example Usage (chat format)
# -----------------------------
messages = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "The capital of France is Paris."}
]
# Format prompt using chat template
formatted_prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=False
)
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
# Get reward score
with torch.no_grad():
    outputs = model(**inputs)
reward_score = outputs.logits.squeeze().item()
print(f"[Chat] Reward Score: {reward_score:.4f}")
# -----------------------------
# Example Usage (plain text)
# -----------------------------
text = "User: What is the capital of France?\nAssistant: Paris."
inputs = tokenizer(text, return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model(**inputs)
reward_score = outputs.logits.squeeze().item()
print(f"[Plain] Reward Score: {reward_score:.4f}")
Training procedure
This model was trained with reward modeling using TRL's RewardTrainer.
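The checkpoint mapping above implies 10 training epochs at 246 steps per epoch. The following is a minimal, hypothetical sketch of such a run with TRL's RewardTrainer; the dataset, LoRA settings, and all hyperparameters are illustrative assumptions, not the actual training recipe.

import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardConfig, RewardTrainer

base = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(base)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Scalar-head classifier: one logit per sequence serves as the reward.
model = AutoModelForSequenceClassification.from_pretrained(
    base, num_labels=1, torch_dtype=torch.bfloat16
)
model.config.pad_token_id = tokenizer.pad_token_id

# Stand-in preference dataset with "chosen"/"rejected" pairs.
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

trainer = RewardTrainer(
    model=model,
    args=RewardConfig(output_dir="margin_reg_baseline", num_train_epochs=10),
    train_dataset=dataset,
    processing_class=tokenizer,
    peft_config=LoraConfig(task_type="SEQ_CLS"),  # actual LoRA rank/alpha unknown
)
trainer.train()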
Framework versions
- PEFT: 0.18.0
- TRL: 0.26.1
- Transformers: 4.57.3
- PyTorch: 2.9.0
- Datasets: 4.4.1
- Tokenizers: 0.22.1
Citations
Cite TRL as:
@misc{vonwerra2022trl,
    title        = {{TRL: Transformer Reinforcement Learning}},
    author       = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallou{\'e}dec},
    year         = 2020,
    journal      = {GitHub repository},
    publisher    = {GitHub},
    howpublished = {\url{https://github.com/huggingface/trl}}
}