{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f93b7d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Set before any tokenizer is created: this switches off the Hugging Face\n",
    "# tokenizers library's internal parallelism, which otherwise warns when the\n",
    "# process is forked (e.g. by DataLoader workers).\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from peft import get_peft_config, get_peft_model, get_peft_model_state_dict, PrefixTuningConfig, TaskType\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "from transformers import (\n",
    "    AutoModelForSeq2SeqLM,\n",
    "    AutoTokenizer,\n",
    "    default_data_collator,\n",
    "    get_linear_schedule_with_warmup,\n",
    ")\n",
    "\n",
    "# Pick the device to run on. `torch.accelerator` only exists on recent PyTorch\n",
    "# releases, and `current_accelerator()` may return None (or a backend that is\n",
    "# compiled in but not usable at runtime), so check `is_available()` first and\n",
    "# then fall back to CUDA and finally to the CPU instead of assuming a GPU.\n",
    "if hasattr(torch, \"accelerator\") and torch.accelerator.is_available():\n",
    "    device = torch.accelerator.current_accelerator().type\n",
    "elif torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "\n",
    "# Pretrained checkpoint to adapt and the tokenizer that goes with it.\n",
    "model_name_or_path = \"t5-large\"\n",
    "tokenizer_name_or_path = \"t5-large\"\n",
    "\n",
    "# Settings read by the cells below. NOTE(review): the cells that consume the\n",
    "# column names, max_length, lr, num_epochs and batch_size are not shown next to\n",
    "# this cell -- confirm their meaning there before changing them.\n",
    "checkpoint_name = \"financial_sentiment_analysis_prefix_tuning_v1.pt\"\n",
    "text_column = \"sentence\"\n",
    "label_column = \"text_label\"\n",
    "max_length = 128\n",
    "lr = 1e-2\n",
    "num_epochs = 5\n",
    "batch_size = 8\n",
    "# Number of virtual tokens passed to PrefixTuningConfig in the next cell.\n",
    "num_virtual_tokens = 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8d0850ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the prefix-tuning configuration for a sequence-to-sequence model in\n",
    "# training mode (inference_mode=False).\n",
    "peft_config = PrefixTuningConfig(\n",
    "    task_type=TaskType.SEQ_2_SEQ_LM,\n",
    "    inference_mode=False,\n",
    "    num_virtual_tokens=num_virtual_tokens,\n",
    ")\n",
    "\n",
    "# Load the pretrained base model, then wrap it with PEFT. The base model keeps\n",
    "# its own name so `model` always refers to the PEFT-wrapped model.\n",
    "base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)\n",
    "model = get_peft_model(base_model, peft_config)\n",
    "\n",
    "# Report the trainable parameter count, then display the wrapped model.\n",
    "model.print_trainable_parameters()\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4ee2babf",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using the latest cached version of the dataset since financial_phrasebank couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'sentences_allagree' at /root/.cache/huggingface/datasets/financial_phrasebank/sentences_allagree/1.0.0/550bde12e6c30e2674da973a55f57edde5181d53f5a5a34c1531c53f93b7e141 (last 
modified on Thu Jul 31 06:23:15 2025).\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3b321971d6f942418bd5ef6105a1aa65", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/2037 [00:00