---
license: mit
tags:
- ppo
- qlora
- reinforcement-learning
- llama-3
- mmlu
pipeline_tag: text-generation
---

# PPO-QLoRA Trained Model (spark-model-QLoRA)

This repository contains an agent (actor and critic models) trained with Proximal Policy Optimization (PPO) and QLoRA. Training was performed with the scripts and models in the `spark_rl` directory of the `explore-rl` project.

**Base Model:** `meta-llama/Llama-3-8B-Instruct` (the base model is set by the `train.py` arguments used for the run)

## Model Components

The `model_final` directory (uploaded here as the repository root) contains:

* **`actor/`**: LoRA adapters for the actor (policy) model.
* **`critic/`**: LoRA adapters for the critic (value) model's base LLM, plus a `value_head.pt` file for its custom value prediction head.
* **`tokenizer/`**: The Hugging Face tokenizer used during training.
* **`hyperparams.txt`**: Key hyperparameters used for PPO training.
* **`models.py`**: The `LLMActorLora` and `LLMCriticLora` class definitions required to load and use these models.

## How to Use

To use these models, you need the `LLMActorLora` and `LLMCriticLora` classes from the included `models.py`. The snippet below first downloads the repository so that `models.py` can be imported and the adapter directories can be referenced by local path.

```python
import sys

import torch
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer

# --- Configuration ---
BASE_MODEL_ID = "meta-llama/Llama-3-8B-Instruct"  # IMPORTANT: must match the base model used for training!
MODEL_REPO_ID = "gabrielbo/spark-model-QLoRA"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Download the repository (or point local_dir at a local copy).
local_dir = snapshot_download(repo_id=MODEL_REPO_ID)
sys.path.append(local_dir)
from models import LLMActorLora, LLMCriticLora  # models.py ships with this repository

# --- Load Tokenizer ---
try:
    tokenizer = AutoTokenizer.from_pretrained(f"{local_dir}/tokenizer")
except Exception:
    # Fallback if the tokenizer files are stored in the repository root.
    tokenizer = AutoTokenizer.from_pretrained(local_dir)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # keep consistent with the PPO agent's left padding

# --- Load Actor ---
actor = LLMActorLora(
    device=DEVICE,
    model_id=BASE_MODEL_ID,
    # lora_r and disable_quantization can be left at their defaults
    # or set from the values in hyperparams.txt.
)
actor.load_pretrained(f"{local_dir}/actor")  # actor adapters
actor.model.eval()
print("Actor loaded successfully.")

# --- Load Critic ---
critic = LLMCriticLora(
    device=DEVICE,
    model_id=BASE_MODEL_ID,
    # lora_r and disable_quantization can be left at their defaults
    # or set from the values in hyperparams.txt.
)
critic.load_pretrained(f"{local_dir}/critic")  # critic adapters + value_head.pt
critic.model.eval()
critic.value_head.eval()
print("Critic loaded successfully.")

# --- Example: generating an action (conceptual) ---
# Input construction depends on how the PPOAgent prepares batches
# (see PPOAgent.prepare_batch); the prompt below is a generic example
# that you will need to adapt.
question = "What is the capital of France?"
state_text = "The current context is a geography quiz."
input_text = f"Question: {question}\n\nState: {state_text}\n\nAction:"
inputs = tokenizer(
    input_text, return_tensors="pt", padding=True, truncation=True, max_length=512
).to(DEVICE)

print(f"\nGenerating action for: {input_text}")
with torch.no_grad():
    # Generation kwargs (temperature, top_p, ...) may be needed; take them
    # from hyperparams.txt or evaluate.py to reproduce training behavior.
    generated_ids = actor.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=50,  # adjust as needed
        do_sample=True,     # if sampling was used
        # temperature=0.7,  # example
        # top_p=0.9,        # example
    )

# generate() returns the (padded) prompt followed by the new tokens, so
# slicing at the prompt length leaves only the generated action.
response_ids = generated_ids[0][inputs.input_ids.shape[-1]:]
action_text = tokenizer.decode(response_ids, skip_special_tokens=True)
print(f"Generated Action: {action_text.strip()}")

# --- Example: getting a value estimate (conceptual) ---
value_prediction = critic.forward(inputs.input_ids, attention_mask=inputs.attention_mask)
print(f"Value prediction for the state: {value_prediction.item()}")
```

## Training Details

The model was trained with the PPO algorithm using the following key settings (see `hyperparams.txt` for the exact values):

* **Learning Rate (Actor)**: `lr` in `hyperparams.txt`
* **Learning Rate (Critic)**: `critic_lr` in `hyperparams.txt`
* **PPO Clip Ratio**: `clip_ratio` in `hyperparams.txt`
* **KL Coefficient**: `kl_coef` in `hyperparams.txt`
* **Target KL**: `target_kl` in `hyperparams.txt`
* **Batch Size**: set via the training script's `args.batch`
* **PPO Epochs**: set via `args.ppo_epochs`
* **Total PPO Iterations**: set via `args.steps`

The training data consisted of MMLU trajectories.
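Several of these values live in `hyperparams.txt` rather than in this card. The sketch below is one way to read them, assuming the file stores one `key = value` pair per line; the actual format may differ, and `load_hyperparams` is just an illustrative helper, not part of the `spark_rl` codebase.

```python
from pathlib import Path

from huggingface_hub import snapshot_download

def load_hyperparams(path: str) -> dict:
    """Parse hyperparams.txt, assuming one `key = value` pair per line."""
    params = {}
    for line in Path(path).read_text().splitlines():
        if "=" not in line:
            continue  # skip blank lines and non key-value lines
        key, _, value = line.partition("=")
        params[key.strip()] = value.strip()
    return params

local_dir = snapshot_download(repo_id="gabrielbo/spark-model-QLoRA")
hparams = load_hyperparams(f"{local_dir}/hyperparams.txt")
print(hparams.get("lr"), hparams.get("clip_ratio"), hparams.get("kl_coef"))
```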
input_text = f"Question: {question}\n\nState: {state_text}\n\nAction:" inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(DEVICE) print(f"\nGenerating action for: {input_text}") with torch.no_grad(): # Actor generates token IDs # Note: Generation kwargs might be needed (e.g., temperature, top_p from hyperparams.txt or evaluate.py) generated_ids = actor.generate( inputs.input_ids, attention_mask=inputs.attention_mask, max_new_tokens=50, # Adjust as needed # temperature=0.7, # Example # top_p=0.9, # Example do_sample=True # Example, if sampling was used ) # Decode the generated action # The generated output includes the input_text, so we need to slice it off. # This depends on tokenizer.padding_side; if "left", then slicing logic changes. # Assuming tokenizer.padding_side = "right" (default for many models) or handled by generate # If tokenizer.padding_side was "left" for generation, the input is at the end. # For simplicity, let's assume the output only contains new tokens after input. # This might need adjustment based on specific generation config. # A common way to get only the generated part: response_ids = generated_ids[0][inputs.input_ids.shape[-1]:] action_text = tokenizer.decode(response_ids, skip_special_tokens=True) print(f"Generated Action: {action_text.strip()}") # --- Example: Getting a value estimate (conceptual) --- value_prediction = critic.forward(inputs.input_ids, attention_mask=inputs.attention_mask) print(f"Value prediction for the state: {value_prediction.item()}") ``` ## Training Details The model was trained using the PPO algorithm with the following key settings (see `hyperparams.txt` for more details): * **Learning Rate (Actor)**: (Refer to `lr` in `hyperparams.txt`) * **Learning Rate (Critic)**: (Refer to `critic_lr` in `hyperparams.txt`) * **PPO Clip Ratio**: (Refer to `clip_ratio` in `hyperparams.txt`) * **KL Coefficient**: (Refer to `kl_coef` in `hyperparams.txt`) * **Target KL**: (Refer to `target_kl` in `hyperparams.txt`) * **Batch Size**: (As per your training script, e.g., `args.batch`) * **PPO Epochs**: (As per your training script, e.g., `args.ppo_epochs`) * **Total PPO Iterations**: (As per your training script, e.g., `args.steps`) The specific dataset used for training was MMLU trajectories. ## Intended Use This model is intended for tasks requiring sequential decision-making and reasoning, similar to the MMLU benchmark. It can be used as a starting point for further fine-tuning or for direct application in relevant domains. ## Limitations * The model's performance is tied to the quality and characteristics of the offline trajectory data it was trained on. * As a LoRA-adapted model, it relies on the capabilities of the base `meta-llama/Llama-3-8B-Instruct` model. * The generation behavior may require careful prompt engineering. ## Citation If you use this model or the `spark_rl` codebase, please consider citing the original `explore-rl` repository: [Link to your explore-rl GitHub repository, if public]