---
license: mit
tags:
- ppo
- qlora
- reinforcement-learning
- llama-3
- mmlu
pipeline_tag: text-generation
---
# PPO-QLoRA Trained Model (spark-model-QLoRA)
This repository contains an agent (actor and critic models) trained using Proximal Policy Optimization (PPO) with QLoRA.
The training was performed using the scripts and models available in the `spark_rl` directory of the `explore-rl` project.
**Base Model:** `meta-llama/Meta-Llama-3-8B-Instruct` (confirm against the model ID passed to `train.py`, and adjust if a different base was used)
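QLoRA means the base model is loaded in 4-bit before the LoRA adapters are attached. The exact settings live in the included `models.py`; the sketch below shows the standard `bitsandbytes` NF4 recipe as an assumed reference point, not the confirmed training configuration:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Standard QLoRA quantization recipe (assumed; check models.py for the
# settings actually used by LLMActorLora / LLMCriticLora).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",              # NormalFloat4 quantization
    bnb_4bit_compute_dtype=torch.bfloat16,  # compute in bf16
    bnb_4bit_use_double_quant=True,         # also quantize the quantization constants
)
base_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    quantization_config=bnb_config,
    device_map="auto",
)
```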
## Model Components
The training run's `model_final` directory is uploaded at the root of this repository and contains:
* **`actor/`**: LoRA adapters for the actor (policy) model.
* **`critic/`**: LoRA adapters for the critic (value) model's base LLM, and a `value_head.pt` file for its custom value prediction head.
* **`tokenizer/`**: The Hugging Face tokenizer used during training.
* **`hyperparams.txt`**: Key hyperparameters used for the PPO training.
* **`models.py`**: Contains the `LLMActorLora` and `LLMCriticLora` class definitions required to load and use these models.
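Before loading anything, you can confirm this layout by listing the repository's files; a minimal sketch using the standard `huggingface_hub` client:

```python
from huggingface_hub import list_repo_files

# Print every file in the model repository to confirm the layout above.
for path in sorted(list_repo_files("gabrielbo/spark-model-QLoRA")):
    print(path)
```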
## How to Use
To use these models, you will need the `LLMActorLora` and `LLMCriticLora` classes from the included `models.py` file. The snippet below first downloads a local snapshot of the repository so that `models.py` and the adapter files are available on disk, then loads both models.
```python
import torch
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer

from models import LLMActorLora, LLMCriticLora  # models.py ships in this repository

# --- Configuration ---
BASE_MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"  # IMPORTANT: must match the base model used for training!
MODEL_REPO_ID = "gabrielbo/spark-model-QLoRA"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Download a local snapshot of the repository so the tokenizer and the
# custom loaders below can work with plain filesystem paths.
local_repo = snapshot_download(MODEL_REPO_ID)

# --- Load Tokenizer ---
try:
    tokenizer = AutoTokenizer.from_pretrained(f"{local_repo}/tokenizer")
except OSError:  # fallback if the tokenizer files live at the repo root
    tokenizer = AutoTokenizer.from_pretrained(local_repo)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # match the left padding used by the PPO agent, if applicable
# --- Load Actor ---
actor = LLMActorLora(
    device=DEVICE,
    model_id=BASE_MODEL_ID,
    # lora_r and disable_quantization can be left at their defaults
    # or set from the values recorded in hyperparams.txt
)
actor.load_pretrained(f"{local_repo}/actor")  # LoRA adapters for the policy
actor.model.eval()
print("Actor loaded successfully.")

# --- Load Critic ---
critic = LLMCriticLora(
    device=DEVICE,
    model_id=BASE_MODEL_ID,
    # lora_r and disable_quantization can be left at their defaults
    # or set from the values recorded in hyperparams.txt
)
critic.load_pretrained(f"{local_repo}/critic")  # LoRA adapters plus value_head.pt
critic.model.eval()
critic.value_head.eval()
print("Critic loaded successfully.")
# --- Example: Generating an action (conceptual) ---
# Input construction is specific to how PPOAgent.prepare_batch formats
# prompts; the template below is a generic illustration. Adapt as needed.
question = "What is the capital of France?"
state_text = "The current context is a geography quiz."
input_text = f"Question: {question}\n\nState: {state_text}\n\nAction:"
inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(DEVICE)

print(f"\nGenerating action for: {input_text}")
with torch.no_grad():
    # The actor generates token IDs. Sampling parameters such as
    # temperature and top_p should mirror hyperparams.txt / evaluate.py.
    generated_ids = actor.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=50,  # adjust as needed
        # temperature=0.7,  # example
        # top_p=0.9,        # example
        do_sample=True,     # example, if sampling was used during training
    )

# generate() returns the prompt followed by the new tokens. With left
# padding, the prompt (and any padding) occupies the first
# inputs.input_ids.shape[-1] positions of every row, so slicing them off
# leaves only the generated continuation.
response_ids = generated_ids[0][inputs.input_ids.shape[-1]:]
action_text = tokenizer.decode(response_ids, skip_special_tokens=True)
print(f"Generated Action: {action_text.strip()}")

# --- Example: Getting a value estimate (conceptual) ---
with torch.no_grad():
    value_prediction = critic.forward(inputs.input_ids, attention_mask=inputs.attention_mask)
print(f"Value prediction for the state: {value_prediction.item()}")
```
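The slicing above generalizes to batches as long as left padding is kept: every prompt is right-aligned to the same length, so the newly generated tokens for all rows start at the same index. A sketch, assuming `actor.generate` accepts batched inputs the same way the single-prompt call above does:

```python
prompts = [
    "Question: What is 2 + 2?\n\nState: An arithmetic quiz.\n\nAction:",
    "Question: Who wrote Hamlet?\n\nState: A literature quiz.\n\nAction:",
]
batch = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=512).to(DEVICE)

with torch.no_grad():
    out = actor.generate(
        batch.input_ids,
        attention_mask=batch.attention_mask,
        max_new_tokens=50,
        do_sample=True,
    )

# With left padding, every prompt ends at index batch.input_ids.shape[-1],
# so everything after that column is newly generated text.
actions = tokenizer.batch_decode(out[:, batch.input_ids.shape[-1]:], skip_special_tokens=True)
for action in actions:
    print(action.strip())
```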
## Training Details
The model was trained using the PPO algorithm with the following key settings (see `hyperparams.txt` for more details):
* **Learning Rate (Actor)**: see `lr` in `hyperparams.txt`
* **Learning Rate (Critic)**: see `critic_lr` in `hyperparams.txt`
* **PPO Clip Ratio**: see `clip_ratio` in `hyperparams.txt`
* **KL Coefficient**: see `kl_coef` in `hyperparams.txt`
* **Target KL**: see `target_kl` in `hyperparams.txt`
* **Batch Size**: set via `args.batch` in the training script
* **PPO Epochs**: set via `args.ppo_epochs` in the training script
* **Total PPO Iterations**: set via `args.steps` in the training script
Training data consisted of trajectories collected on the MMLU benchmark.
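The format of `hyperparams.txt` is not standardized; assuming simple `key: value` or `key=value` lines, a small parser like the sketch below recovers the settings listed above (it reuses the hypothetical `local_repo` snapshot path from the usage snippet):

```python
def load_hyperparams(path: str) -> dict:
    """Parse hyperparams.txt, assuming one 'key: value' or 'key=value' pair per line."""
    params = {}
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line or ("=" not in line and ":" not in line):
                continue  # skip blank or free-form lines
            key, _, value = line.replace("=", ":", 1).partition(":")
            params[key.strip()] = value.strip()
    return params

hp = load_hyperparams(f"{local_repo}/hyperparams.txt")
print(hp.get("lr"), hp.get("critic_lr"), hp.get("clip_ratio"), hp.get("kl_coef"))
```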
## Intended Use
This model is intended for tasks requiring sequential decision-making and reasoning, similar to the MMLU benchmark. It can be used as a starting point for further fine-tuning or for direct application in relevant domains.
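For deployment without the `spark_rl` classes, the actor adapters can likely be merged into the base model with PEFT, assuming they were saved in the standard PEFT adapter format (a hedged sketch, again reusing `local_repo` from the usage snippet above; merging requires loading the base weights unquantized):

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

# Load the base model in bf16 (merging LoRA weights needs full-precision weights).
base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct", torch_dtype=torch.bfloat16
)
# Attach the actor adapters (assumes standard PEFT save_pretrained output).
policy = PeftModel.from_pretrained(base, f"{local_repo}/actor")
# Fold the LoRA weights into the base weights for adapter-free inference.
policy = policy.merge_and_unload()
policy.save_pretrained("llama3-8b-spark-actor-merged")
```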
## Limitations
* The model's performance is tied to the quality and characteristics of the offline trajectory data it was trained on.
* As a LoRA-adapted model, it relies on the capabilities of the base `meta-llama/Llama-3-8B-Instruct` model.
* Generation quality is sensitive to prompt format; prompts should mirror the `Question/State/Action` template used during training.
## Citation
If you use this model or the `spark_rl` codebase, please consider citing the original `explore-rl` repository:
[Link to the `explore-rl` GitHub repository, if public]