---
license: mit
tags:
- ppo
- qlora
- reinforcement-learning
- llama-3
- mmlu
pipeline_tag: text-generation
---
# PPO-QLoRA Trained Model (spark-model-QLoRA)
This repository contains an agent (actor and critic models) trained using Proximal Policy Optimization (PPO) with QLoRA.
The training was performed using the scripts and models available in the `spark_rl` directory of the `explore-rl` project.
**Base Model:** `meta-llama/Llama-3-8B-Instruct` (the adapters must be loaded on the same base model that was passed to `train.py` during training)
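For context, QLoRA keeps the base model frozen in 4-bit precision and trains only low-rank LoRA adapters on top of it. The sketch below shows a typical QLoRA setup with `bitsandbytes` and `peft`; the quantization and LoRA values shown (e.g., `r=16` and the target modules) are illustrative assumptions, while the actual configuration used for this model lives in `models.py` and `hyperparams.txt`.
```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

# Illustrative QLoRA setup; the real configuration is defined in models.py.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                    # frozen base weights quantized to 4-bit
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3-8B-Instruct",
    quantization_config=bnb_config,
    device_map="auto",
)
lora_config = LoraConfig(
    r=16,                                 # assumed rank; see hyperparams.txt for the real value
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],  # assumed target modules
    task_type="CAUSAL_LM",
)
model = get_peft_model(base, lora_config)  # only the LoRA adapter weights are trainable
```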
## Model Components
The `model_final` directory (uploaded as the root of this repository) contains:
* **`actor/`**: LoRA adapters for the actor (policy) model.
* **`critic/`**: LoRA adapters for the critic (value) model's base LLM, and a `value_head.pt` file for its custom value prediction head.
* **`tokenizer/`**: The Hugging Face tokenizer used during training.
* **`hyperparams.txt`**: Key hyperparameters used for the PPO training.
* **`models.py`**: Contains the `LLMActorLora` and `LLMCriticLora` class definitions required to load and use these models.
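The loading code below accepts either the Hugging Face repo id or a local copy of these files. If you prefer to work with local paths (for example, in case the custom `load_pretrained` methods in `models.py` expect local directories), one way to fetch everything first is `huggingface_hub.snapshot_download`, sketched here:
```python
from huggingface_hub import snapshot_download

# Download actor/, critic/, tokenizer/, hyperparams.txt and models.py into a local
# cache directory; the returned path can be used as MODEL_REPO_PATH in the code below.
local_dir = snapshot_download(repo_id="gabrielbo/spark-model-QLoRA")
print(local_dir)
```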
## How to Use
To use these models, you will need the `LLMActorLora` and `LLMCriticLora` classes from the included `models.py` file.
```python
import torch
from transformers import AutoTokenizer
from models import LLMActorLora, LLMCriticLora # models.py is in this repository
# --- Configuration ---
BASE_MODEL_ID = "meta-llama/Llama-3-8B-Instruct" # IMPORTANT: Ensure this matches the model used for training!
MODEL_REPO_PATH = "gabrielbo/spark-model-QLoRA" # Or a local path (e.g., the directory returned by snapshot_download above)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# --- Load Tokenizer ---
try:
    # Tokenizer files are stored in the tokenizer/ subfolder of this repository
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO_PATH, subfolder="tokenizer")
except Exception:  # Fallback if the tokenizer files are in the repository root
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO_PATH)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # Match the left padding used by the PPO agent during training
# --- Load Actor ---
actor = LLMActorLora(
    device=DEVICE,
    model_id=BASE_MODEL_ID,
    # lora_r and disable_quantization can be defaults or from hyperparams.txt
)
# Path to actor adapters within the model repo
actor_adapters_path = f"{MODEL_REPO_PATH}/actor"
actor.load_pretrained(actor_adapters_path)
actor.model.eval()
print("Actor loaded successfully.")
# --- Load Critic ---
critic = LLMCriticLora(
    device=DEVICE,
    model_id=BASE_MODEL_ID,
    # lora_r and disable_quantization can be defaults or from hyperparams.txt
)
# Path to critic components within the model repo
critic_components_path = f"{MODEL_REPO_PATH}/critic"
critic.load_pretrained(critic_components_path)
critic.model.eval()
critic.value_head.eval()
print("Critic loaded successfully.")
# --- Example: Generating an action (conceptual) ---
# This part is highly dependent on how your PPOAgent prepares inputs.
# The following is a generic example. You'll need to adapt it.
# Example input construction (refer to PPOAgent.prepare_batch)
question = "What is the capital of France?"
state_text = "The current context is a geography quiz."
input_text = f"Question: {question}\n\nState: {state_text}\n\nAction:"
inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(DEVICE)
print(f"\nGenerating action for: {input_text}")
with torch.no_grad():
    # Actor generates token IDs.
    # Generation kwargs (e.g., temperature, top_p) should match the values in
    # hyperparams.txt or evaluate.py if you want behavior consistent with training.
    generated_ids = actor.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=50,  # Adjust as needed
        # temperature=0.7,  # Example
        # top_p=0.9,        # Example
        do_sample=True,     # Example, if sampling was used
    )

# Decode only the newly generated tokens. generate() returns the prompt followed by
# the new tokens, so slice off the first input_ids.shape[-1] positions; with the
# left padding set above, the prompt is right-aligned and this slice is correct.
response_ids = generated_ids[0][inputs.input_ids.shape[-1]:]
action_text = tokenizer.decode(response_ids, skip_special_tokens=True)
print(f"Generated Action: {action_text.strip()}")
# --- Example: Getting a value estimate (conceptual) ---
with torch.no_grad():
    value_prediction = critic.forward(inputs.input_ids, attention_mask=inputs.attention_mask)
print(f"Value prediction for the state: {value_prediction.item()}")
```
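Left padding matters once several prompts of different lengths are batched together: with `padding_side="left"` every prompt ends at the same position, so a single slice removes the prompt from every row. A minimal sketch, reusing the `tokenizer`, `actor`, and `DEVICE` from above and assuming `actor.generate` accepts batched inputs the way `model.generate` does:
```python
prompts = [
    "Question: What is the capital of France?\n\nState: Geography quiz.\n\nAction:",
    "Question: Which gas do plants absorb from the air?\n\nState: Biology quiz.\n\nAction:",
]
batch = tokenizer(prompts, return_tensors="pt", padding=True,
                  truncation=True, max_length=512).to(DEVICE)

with torch.no_grad():
    out = actor.generate(batch.input_ids, attention_mask=batch.attention_mask,
                         max_new_tokens=50, do_sample=True)

# With left padding the prompts are right-aligned, so the first input width of every
# row is padding plus prompt and everything after it is newly generated text.
new_tokens = out[:, batch.input_ids.shape[-1]:]
for text in tokenizer.batch_decode(new_tokens, skip_special_tokens=True):
    print(text.strip())
```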
## Training Details
The model was trained using the PPO algorithm with the following key settings (see `hyperparams.txt` for more details):
* **Learning Rate (Actor)**: (Refer to `lr` in `hyperparams.txt`)
* **Learning Rate (Critic)**: (Refer to `critic_lr` in `hyperparams.txt`)
* **PPO Clip Ratio**: (Refer to `clip_ratio` in `hyperparams.txt`)
* **KL Coefficient**: (Refer to `kl_coef` in `hyperparams.txt`)
* **Target KL**: (Refer to `target_kl` in `hyperparams.txt`)
* **Batch Size**: (see `args.batch` in the training script)
* **PPO Epochs**: (see `args.ppo_epochs` in the training script)
* **Total PPO Iterations**: (see `args.steps` in the training script)
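The exact values are stored in `hyperparams.txt`. Its format is not specified here; assuming it consists of simple `key: value` (or `key=value`) lines, a small helper like the following can read it into a dict for use when reconstructing the actor and critic:
```python
def load_hyperparams(path="hyperparams.txt"):
    """Parse simple 'key: value' or 'key=value' lines into a dict (format assumed)."""
    params = {}
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            sep = ":" if ":" in line else "="
            key, _, value = line.partition(sep)
            params[key.strip()] = value.strip()
    return params

hp = load_hyperparams()
print(hp.get("lr"), hp.get("critic_lr"), hp.get("clip_ratio"))
```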
The training data consisted of MMLU trajectories.
## Intended Use
This model is intended for tasks requiring sequential decision-making and reasoning, similar to the MMLU benchmark. It can be used as a starting point for further fine-tuning or for direct application in relevant domains.
## Limitations
* The model's performance is tied to the quality and characteristics of the offline trajectory data it was trained on.
* As a LoRA-adapted model, it relies on the capabilities of the base `meta-llama/Llama-3-8B-Instruct` model.
* The generation behavior may require careful prompt engineering.
## Citation
If you use this model or the `spark_rl` codebase, please consider citing the original `explore-rl` repository:
[Link to your explore-rl GitHub repository, if public]