---
library_name: transformers
datasets:
- whynlp/gsm8k-aug
base_model:
- openai-community/gpt2
---

# PCCoT

This is a [GPT-2 Small](https://huggingface.co/openai-community/gpt2) model finetuned for research purposes, as described in the paper "[Parallel Continuous Chain-of-Thought with Jacobi Iteration](https://arxiv.org/abs/2506.18582)".

## About

Parallel Continuous Chain-of-Thought (PCCoT) parallelizes the continuous chain-of-thought (CoT) reasoning process through Jacobi iteration. Instead of generating the latent thought tokens sequentially, it iteratively updates all latent thought tokens in parallel. This significantly improves the training and inference efficiency of continuous CoT reasoning while also achieving better performance on reasoning tasks. See our GitHub repo for more details: https://github.com/whyNLP/PCCoT
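
The difference between sequential generation and Jacobi-style parallel updates can be illustrated with a toy sketch. Everything here is an assumption made for illustration: the "latent thoughts" are plain floats, and `step`, `sequential`, and `jacobi` are hypothetical helpers, not functions from the PCCoT repo or the model's actual forward pass.

```python
import math

def step(prefix):
    """Hypothetical causal update: the next value depends on all earlier ones."""
    return math.tanh(0.5 * sum(prefix) + 1.0)

def sequential(question, num_latent):
    """Generate latent values one at a time (num_latent dependent steps)."""
    seq = [question]
    for _ in range(num_latent):
        seq.append(step(seq))
    return seq[1:]

def jacobi(question, num_latent, num_iters):
    """Update every latent value at once, reading only the previous iterate."""
    latents = [0.0] * num_latent  # arbitrary initial guess
    for _ in range(num_iters):
        prev = [question] + latents
        # Position i only needs prev[: i + 1], so in a real model all positions
        # could be computed together in a single forward pass.
        latents = [step(prev[: i + 1]) for i in range(num_latent)]
    return latents

exact = sequential(0.3, num_latent=6)
for k in (1, 3, 6):
    approx = jacobi(0.3, num_latent=6, num_iters=k)
    err = max(abs(a - b) for a, b in zip(approx, exact))
    print(f"iterations={k}: max error={err:.6f}")
```

In this toy, `k` Jacobi iterations make the first `k` values exactly equal to the sequential result, so running as many iterations as there are latent values reproduces it fully. Running fewer iterations gives an approximation with fewer dependent steps, which is where the efficiency gain of a parallel scheme would come from.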

## Quick Start

Using this model requires the code in our [GitHub repo](https://github.com/whyNLP/PCCoT). See the repo for more detailed instructions.

The quick-start example below shows how to use this model to solve a math word problem.

```python
import models  # from the GitHub repo
from transformers import AutoTokenizer, AutoConfig, HfArgumentParser
from transformers.utils.hub import cached_file
from peft import AutoPeftModel

# Example model name
model_name_or_path = "whyNLP/pccot-gpt2"
# Example question
question = "Every day, Wendi feeds each of her chickens three cups of mixed chicken feed, containing seeds, mealworms and vegetables to help keep them healthy. She gives the chickens their feed in three separate meals. In the morning, she gives her flock of chickens 15 cups of feed. In the afternoon, she gives her chickens another 25 cups of feed. How many cups of feed does she need to give her chickens in the final meal of the day if the size of Wendi's flock is 20 chickens?"

# Load the model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
config = AutoConfig.from_pretrained(model_name_or_path)
model = AutoPeftModel.from_pretrained(model_name_or_path)

# We have to override the model config after loading the model, because PEFT
# does not provide an interface for loading the base model with a custom
# config through AutoPeftModel.
model.get_base_model().config = config

# Load the PCCoT arguments
pccot_args_file = cached_file(model_name_or_path, models.PCCOT_ARGS_NAME)
parser = HfArgumentParser(models.PCCoTArguments)
(pccot_args,) = parser.parse_json_file(json_file=pccot_args_file)

# Load the data processor
data_processor = models.COTDataProcessor(
    tokenizer=tokenizer,
    pccot_args=pccot_args,
)
collated = data_processor.process(question)

# Generate the answer
decoded_tokens = model.generate(
    collated=collated,
    max_new_tokens=10,
    do_sample=False,
    # temperature=0.1,  # you can adjust the temperature for more diverse answers
    # top_p=0.9,        # you can adjust the top_p for more diverse answers
    # top_k=50,         # you can adjust the top_k for more diverse answers
)

# Remove the input_ids part so that only the newly generated tokens remain
decoded_tokens = decoded_tokens[:, collated["input_ids"].shape[1]:]
answers = tokenizer.batch_decode(decoded_tokens, skip_special_tokens=True)

# Print the answer
print("Question:", question)
print("Answer:", answers[0])  # 20
```

## Hyperparameters

The model is finetuned using LoRA with the following hyperparameters:

| Hyperparameter      | Value                          |
| ------------------- | ------------------------------ |
| LoRA Rank \\(r\\)   | 128                            |
| LoRA \\(\alpha\\)   | 32                             |
| LoRA dropout        | 0.1                            |
| LoRA bias           | False                          |
| LoRA target modules | `['c_attn', 'c_fc', 'c_proj']` |
| Learning Rate       | 3e-3                           |
| Weight Decay        | 1e-2                           |

See more details in the paper.
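
As a rough illustration, the LoRA rows of the table could map onto a PEFT configuration as sketched below. The argument mapping is an assumption rather than the authors' training script. In particular, `bias="none"` is assumed to correspond to "LoRA bias: False". The learning rate and weight decay belong to the optimizer and are not part of this config.

```python
# LoRA settings copied from the hyperparameter table; the mapping to peft
# argument names is an assumption, not taken from the authors' training code.
lora_kwargs = {
    "r": 128,
    "lora_alpha": 32,
    "lora_dropout": 0.1,
    "bias": "none",  # assumed equivalent of "LoRA bias: False"
    "target_modules": ["c_attn", "c_fc", "c_proj"],
}

# Standard LoRA scales the low-rank update by alpha / r.
scaling = lora_kwargs["lora_alpha"] / lora_kwargs["r"]
print("LoRA scaling:", scaling)  # 0.25

try:
    from peft import LoraConfig

    lora_config = LoraConfig(**lora_kwargs)
except ImportError:
    lora_config = None  # peft is not installed; only the plain dict is available
```

With rank 128 and \\(\alpha\\) = 32, the standard LoRA scaling factor \\(\alpha / r\\) works out to 0.25.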

## Performance

The model is finetuned on the [GSM8K-AUG](https://huggingface.co/datasets/whyNLP/gsm8k-aug) dataset, achieving the following performance:

| Model     | Accuracy (%)         |
| --------- | -------------------- |
| CoT       | 44.1                 |
| **PCCoT** | 49.48 \\(\pm\\) 0.31 |
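
For readers who want to score their own generations, the following is a minimal sketch of exact-match accuracy on numeric answers. It is not the authors' evaluation script: `exact_match_accuracy` is a hypothetical helper, all inputs are made up, and reading the \\(\pm\\) value as a standard deviation over repeated runs is an assumption that this card does not state.

```python
import statistics

def exact_match_accuracy(predictions, references):
    """Percentage of predictions whose numeric value equals the reference."""
    def to_number(text):
        try:
            return float(text.strip().replace(",", ""))
        except ValueError:
            return None  # unparseable answers count as wrong

    correct = sum(
        to_number(p) is not None and to_number(p) == to_number(r)
        for p, r in zip(predictions, references)
    )
    return 100.0 * correct / len(references)

# Made-up predictions and references, for illustration only.
preds = ["20", " 18", "7.0", "abc"]
golds = ["20", "18", "7", "5"]
print(exact_match_accuracy(preds, golds))  # 75.0

# Made-up per-run accuracies, only to show a "mean ± std" aggregation.
runs = [50.0, 52.0, 54.0]
print(f"{statistics.mean(runs):.2f} ± {statistics.stdev(runs):.2f}")  # 52.00 ± 2.00
```

The sketch uses the sample standard deviation (`statistics.stdev`). If the reported spread is a population standard deviation, `statistics.pstdev` would be the matching choice.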