Model Card for margin_reg_baseline_20260328_121542

This model is a LoRA adapter that turns meta-llama/Llama-3.1-8B-Instruct into a reward model: a sequence-classification head with a single scalar output. It was trained using TRL.

Quick start

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from peft import PeftModel, PeftConfig

# -----------------------------
# 1. Define the PEFT model ID & Checkpoint (Epoch)
# -----------------------------
peft_model_id = "xxccho/margin_reg_baseline"

# [Optional] To load an intermediate checkpoint from a specific epoch, set the variable below.
# If left as None, the final trained model at the root of the repository is loaded.
#
# [ Checkpoint-to-epoch mapping ]
# Epoch 1  : "checkpoint-246"
# Epoch 2  : "checkpoint-492"
# Epoch 3  : "checkpoint-738"
# Epoch 4  : "checkpoint-984"
# Epoch 5  : "checkpoint-1230"
# Epoch 6  : "checkpoint-1476"
# Epoch 7  : "checkpoint-1722"
# Epoch 8  : "checkpoint-1968"
# Epoch 9  : "checkpoint-2214"
# Epoch 10 : "checkpoint-2460"

# Example: to use the epoch-5 checkpoint, change the line below to:
# checkpoint = "checkpoint-1230"
checkpoint = None  # None (the default) loads the last saved model.

# 2. Load the PEFT config (from the checkpoint subfolder if one is specified)
if checkpoint:
    config = PeftConfig.from_pretrained(peft_model_id, subfolder=checkpoint)
else:
    config = PeftConfig.from_pretrained(peft_model_id)

# 3. Load the tokenizer from the base model named in the adapter config
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

# Llama tokenizers may not define a pad token; fall back to EOS
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# 4. Load base model
base_model = AutoModelForSequenceClassification.from_pretrained(
    config.base_model_name_or_path,
    num_labels=1,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

# 5. Apply the LoRA adapter (from the checkpoint subfolder if one is specified)
if checkpoint:
    model = PeftModel.from_pretrained(base_model, peft_model_id, subfolder=checkpoint)
else:
    model = PeftModel.from_pretrained(base_model, peft_model_id)

model.config.pad_token_id = tokenizer.pad_token_id
model.eval()

# -----------------------------
# Example Usage (chat format)
# -----------------------------
messages = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "The capital of France is Paris."}
]

# Format prompt using chat template
formatted_prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=False
)

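# The chat template already inserts <|begin_of_text|>, so skip the tokenizer's own BOS
inputs = tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=False).to(model.device)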

# Get reward score
with torch.no_grad():
    outputs = model(**inputs)
    reward_score = outputs.logits.squeeze().item()

print(f"[Chat] Reward Score: {reward_score:.4f}")


# -----------------------------
# Example Usage (plain text)
# -----------------------------
text = "User: What is the capital of France?\nAssistant: Paris."

inputs = tokenizer(text, return_tensors="pt").to(model.device)

with torch.no_grad():
    outputs = model(**inputs)
    reward_score = outputs.logits.squeeze().item()

print(f"[Plain] Reward Score: {reward_score:.4f}")
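
The checkpoint table in the quick start follows a fixed stride of 246 training steps per epoch (246, 492, ..., 2460). Assuming that stride holds for every epoch, as the table suggests, a small helper can derive the subfolder name instead of copying it by hand. This is a sketch: checkpoint_for_epoch, STEPS_PER_EPOCH, and NUM_EPOCHS are hypothetical names introduced here, and the helper does not check that the subfolder actually exists in the repository.

```python
from typing import Optional

# Hypothetical constants read off the checkpoint table (assumed fixed stride).
STEPS_PER_EPOCH = 246
NUM_EPOCHS = 10


def checkpoint_for_epoch(epoch: Optional[int]) -> Optional[str]:
    """Return the checkpoint subfolder for a 1-based epoch, or None for the final model."""
    if epoch is None:
        return None  # keep the quick-start default: load the repository root
    if not 1 <= epoch <= NUM_EPOCHS:
        raise ValueError(f"epoch must be in 1..{NUM_EPOCHS}, got {epoch}")
    return f"checkpoint-{epoch * STEPS_PER_EPOCH}"


# Usage: pass the result as `checkpoint` in the quick start.
checkpoint = checkpoint_for_epoch(5)
print(checkpoint)  # checkpoint-1230
```

Passing the returned string as checkpoint reproduces the manual choice in the quick start, while None keeps the default of loading the final model.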

Training procedure

Training runs were logged to Weights & Biases.

This model was trained with reward modeling, using TRL's RewardTrainer.
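
TRL's RewardTrainer fits a Bradley-Terry style pairwise objective: for each prompt it scores a preferred ("chosen") and a dispreferred ("rejected") response and minimizes -log sigmoid(r_chosen - r_rejected), optionally subtracting a margin inside the sigmoid. The sketch below restates that objective in plain Python, together with the preference probability it implies. It is an illustration, not TRL's implementation: the function names are hypothetical, margin is a single scalar here, and this card does not say which optional terms (for example, a margin) this run enabled.

```python
import math
from typing import Sequence


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def preference_probability(reward_a: float, reward_b: float) -> float:
    """Bradley-Terry probability that response A is preferred over response B."""
    return math.exp(log_sigmoid(reward_a - reward_b))


def pairwise_loss(
    chosen: Sequence[float], rejected: Sequence[float], margin: float = 0.0
) -> float:
    """Mean of -log sigmoid(r_chosen - r_rejected - margin) over a batch of pairs."""
    if len(chosen) != len(rejected) or not chosen:
        raise ValueError("chosen and rejected must be non-empty and the same length")
    total = sum(-log_sigmoid(c - r - margin) for c, r in zip(chosen, rejected))
    return total / len(chosen)
```

Because only the difference between two rewards enters this objective, adding a constant to every score leaves it unchanged. Under this assumption, a single reward score from the quick start is not meaningful on its own; compare scores of different responses to the same prompt, for example with preference_probability(score_a, score_b).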

Framework versions

  • PEFT: 0.18.0
  • TRL: 0.26.1
  • Transformers: 4.57.3
  • PyTorch: 2.9.0
  • Datasets: 4.4.1
  • Tokenizers: 0.22.1
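
The versions above are the ones used for training. When reproducing results, it can help to compare them against the local environment. The following sketch does that with importlib.metadata; the mapping of names to PyPI distributions (for example PyTorch as torch) is an assumption, local build tags such as +cu128 are ignored, and a mismatch is only a hint, not a guarantee of different behavior. version_mismatches and EXPECTED_VERSIONS are hypothetical names.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Callable, Dict, Optional

# Versions listed on this card, keyed by PyPI distribution name (assumed mapping).
EXPECTED_VERSIONS: Dict[str, str] = {
    "peft": "0.18.0",
    "trl": "0.26.1",
    "transformers": "4.57.3",
    "torch": "2.9.0",
    "datasets": "4.4.1",
    "tokenizers": "0.22.1",
}


def version_mismatches(
    expected: Dict[str, str],
    get_version: Callable[[str], str] = version,
) -> Dict[str, Optional[str]]:
    """Return {package: installed version or None} for packages that differ from expected."""
    mismatches: Dict[str, Optional[str]] = {}
    for name, wanted in expected.items():
        try:
            installed: Optional[str] = get_version(name)
        except PackageNotFoundError:
            installed = None  # package is not installed at all
        # Drop a local build tag such as "+cu128" before comparing.
        if installed is None or installed.split("+")[0] != wanted:
            mismatches[name] = installed
    return mismatches


if __name__ == "__main__":
    for name, installed in version_mismatches(EXPECTED_VERSIONS).items():
        print(f"{name}: installed {installed}, card lists {EXPECTED_VERSIONS[name]}")
```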

Citations

Cite TRL as:

@misc{vonwerra2022trl,
    title        = {{TRL: Transformer Reinforcement Learning}},
    author       = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallou{\'e}dec},
    year         = 2020,
    journal      = {GitHub repository},
    publisher    = {GitHub},
    howpublished = {\url{https://github.com/huggingface/trl}}
}