{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "machine_shape": "hm", "gpuType": "V100" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "code", "source": [ "from google.colab import drive\n", "drive.mount('/content/drive')" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "_aRV8S4m0ZHl", "outputId": "2b72e648-d5fd-4121-e5da-853f35b932cb" }, "execution_count": 1, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n" ] } ] }, { "cell_type": "code", "source": [ "!pip install PyArabic" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "xjPR3Dp6yCB1", "outputId": "8603b4e4-5a7b-4c92-8260-48e693eaf621" }, "execution_count": 2, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Requirement already satisfied: PyArabic in /usr/local/lib/python3.10/dist-packages (0.6.15)\n", "Requirement already satisfied: six>=1.14.0 in /usr/local/lib/python3.10/dist-packages (from PyArabic) (1.16.0)\n" ] } ] }, { "cell_type": "code", "source": [ "!cp /content/drive/MyDrive/preprocess.py /content/preprocess.py" ], "metadata": { "id": "gDk2FAdA1SIs" }, "execution_count": 3, "outputs": [] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "KUcx6LbbxSU_", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "ff8d2420-2e37-497c-ad38-96a2a699984a" }, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:88: UserWarning: \n", "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", "You will be able to reuse this secret in all of your notebooks.\n", "Please note that authentication is recommended but still optional to access public models or datasets.\n", " warnings.warn(\n", "/usr/local/lib/python3.10/dist-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n", " return self.fget.__get__(instance, owner)()\n" ] } ], "source": [ "from transformers import ElectraForQuestionAnswering, ElectraForSequenceClassification, AutoTokenizer, pipeline\n", "from preprocess import ArabertPreprocessor\n", "\n", "# Define ArabertPreprocessor if not already defined\n", "prep_object = ArabertPreprocessor(\"araelectra-base-discriminator\")\n", "\n", "# Preprocess the question and context\n", "question = ('ماذا اكلت اليوم ؟')\n", "context = ('''\n", "اليوم اكلت تفاحة''')\n", "\n", "# a) Get predictions\n", "qa_modelname = 'ZeyadAhmed/AraElectra-Arabic-SQuADv2-QA'\n", "qa_pipe = pipeline('question-answering', model=qa_modelname, tokenizer=qa_modelname)\n", "\n", "QA_input = {\n", " 'question': question,\n", " 'context': context\n", "}\n", "qa_res = qa_pipe(QA_input)\n", "threshold = 0.5 #hyperparameter can be tweaked\n", "## note classification results label0 probability it can be answered label1 probability can't be answered\n", "## if label1 probability > threshold then consider the output of qa_res is empty string else take the qa_res\n", "# b) Load model & tokenizer\n", "qa_model = ElectraForQuestionAnswering.from_pretrained(qa_modelname)\n", "tokenizer = AutoTokenizer.from_pretrained(qa_modelname)\n" ] }, { "cell_type": "code", "source": [ "# Example of evaluating classification results\n", "print(\"الجواب:\")\n", "print(qa_res['answer'])\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "hBMw89MIzqvh", "outputId": "5aebb9ae-5b2a-4e57-8fb2-a912ba238610" }, "execution_count": 5, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "الجواب:\n", "تفاحة\n" ] } ] }, { "cell_type": "code", "source": [ "!pip install tnkeeh" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Ehrs4yQqY5Yr", "outputId": "5fb0ea69-870f-4a7c-9b61-94f65630cc78" }, "execution_count": 6, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Requirement already satisfied: tnkeeh in /usr/local/lib/python3.10/dist-packages (0.0.9)\n", "Requirement already satisfied: farasapy in /usr/local/lib/python3.10/dist-packages (from tnkeeh) (0.0.14)\n", "Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (from tnkeeh) (2.18.0)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets->tnkeeh) (3.13.1)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets->tnkeeh) (1.25.2)\n", "Requirement already satisfied: pyarrow>=12.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets->tnkeeh) (14.0.2)\n", "Requirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/dist-packages (from datasets->tnkeeh) (0.6)\n", "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets->tnkeeh) (0.3.8)\n", "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets->tnkeeh) (1.5.3)\n", "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets->tnkeeh) (2.31.0)\n", "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets->tnkeeh) (4.66.2)\n", "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets->tnkeeh) (3.4.1)\n", "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets->tnkeeh) (0.70.16)\n", "Requirement already satisfied: fsspec[http]<=2024.2.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from datasets->tnkeeh) (2023.6.0)\n", "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets->tnkeeh) (3.9.3)\n", "Requirement already satisfied: huggingface-hub>=0.19.4 in /usr/local/lib/python3.10/dist-packages (from datasets->tnkeeh) (0.20.3)\n", "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets->tnkeeh) (23.2)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets->tnkeeh) (6.0.1)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->tnkeeh) (1.3.1)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->tnkeeh) (23.2.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->tnkeeh) (1.4.1)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->tnkeeh) (6.0.5)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->tnkeeh) (1.9.4)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->tnkeeh) (4.0.3)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.19.4->datasets->tnkeeh) (4.10.0)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets->tnkeeh) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets->tnkeeh) (3.6)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets->tnkeeh) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets->tnkeeh) (2024.2.2)\n", "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets->tnkeeh) (2.8.2)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets->tnkeeh) (2023.4)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas->datasets->tnkeeh) (1.16.0)\n" ] } ] }, { "cell_type": "code", "source": [ "import tnkeeh as tn\n", "\n", "tn.clean_data(file_path = '/content/drive/MyDrive/arabic222_qna_222dataset_3.txt', save_path = '/content/cleaned_data.txt',)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ynNjkmYqWv3s", "outputId": "d0d7900a-7d86-419f-cb44-47e4ed1ece21" }, "execution_count": 7, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Saving to /content/cleaned_data.txt\n" ] } ] }, { "cell_type": "code", "source": [ "with open(\"/content/cleaned_data.txt\", \"r\", encoding=\"utf-8\") as file:\n", " data = file.readlines()\n" ], "metadata": { "id": "_dQqujxkZMMN" }, "execution_count": 8, "outputs": [] }, { "cell_type": "code", "source": [ "contexts = []\n", "questions = []\n", "answers = []\n", "\n", "for i in range(0, len(data), 6):\n", " if len(data) > i+5:\n", " context = data[i+1].strip()\n", " question = data[i+3].strip()\n", " answer = data[i+5].strip()\n", "\n", " if context and question and answer and context != \"سؤال :\" and question != 'جواب :' and answer != 'السياق :':\n", " contexts.append(context)\n", " questions.append(question)\n", " answers.append(answer)\n" ], "metadata": { "id": "NZFO-WAXco7d" }, "execution_count": 9, "outputs": [] }, { "cell_type": "code", "source": [ "!pip install accelerate -U" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "AL7LJEkbpJdE", "outputId": "297622be-1261-42dc-a8c5-b8bce6234098" }, "execution_count": 10, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Requirement already satisfied: accelerate in /usr/local/lib/python3.10/dist-packages (0.27.2)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from accelerate) (1.25.2)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (23.2)\n", "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate) (5.9.5)\n", "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from accelerate) (6.0.1)\n", "Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (2.1.0+cu121)\n", "Requirement already satisfied: huggingface-hub in /usr/local/lib/python3.10/dist-packages (from accelerate) (0.20.3)\n", "Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from accelerate) (0.4.2)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.13.1)\n", "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (4.10.0)\n", "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (1.12)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.2.1)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.1.3)\n", "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (2023.6.0)\n", "Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (2.1.0)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from huggingface-hub->accelerate) (2.31.0)\n", "Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub->accelerate) (4.66.2)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->accelerate) (2.1.5)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub->accelerate) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub->accelerate) (3.6)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub->accelerate) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub->accelerate) (2024.2.2)\n", "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10.0->accelerate) (1.3.0)\n" ] } ] }, { "cell_type": "code", "source": [ "from transformers import ElectraForQuestionAnswering, TrainingArguments, Trainer\n", "from torch.utils.data import Dataset\n", "import torch\n", "\n", "# Define model and tokenizer\n", "model_name = \"ZeyadAhmed/AraElectra-Arabic-SQuADv2-QA\"\n", "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", "model = ElectraForQuestionAnswering.from_pretrained(model_name)\n", "\n", "class QADataset(Dataset):\n", " def __init__(self, contexts, questions, answers, tokenizer, max_length):\n", " self.encodings = tokenizer(contexts, questions, truncation=True, padding=True, max_length=max_length)\n", " self.answers = []\n", " for i, (context, answer) in enumerate(zip(contexts, answers)):\n", " answer_start_idx = context.find(answer)\n", " if answer_start_idx == -1: # Answer not found in context\n", " answer_token_start_idx = answer_token_end_idx = -1\n", " else:\n", " answer_end_idx = answer_start_idx + len(answer)\n", " answer_token_start_idx = self.encodings.char_to_token(i, answer_start_idx)\n", " answer_token_end_idx = self.encodings.char_to_token(i, answer_end_idx)\n", " self.answers.append((answer_token_start_idx, answer_token_end_idx))\n", "\n", " def __len__(self):\n", " return len(self.encodings[\"input_ids\"])\n", "\n", " def __getitem__(self, idx):\n", " item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}\n", " start_pos, end_pos = self.answers[idx]\n", " if start_pos == -1 or end_pos == -1: # Answer not found in context\n", " item[\"start_positions\"] = torch.tensor(-100) # Assign a tensor with special value for 'ignored' spans\n", " item[\"end_positions\"] = torch.tensor(-100) # Assign a tensor with special value for 'ignored' spans\n", " else:\n", " item[\"start_positions\"] = torch.tensor(start_pos)\n", " item[\"end_positions\"] = torch.tensor(end_pos) if end_pos is not None else torch.tensor(start_pos) # Use start position if end position is None\n", " return item\n", "\n", "\n", "\n", "train_dataset = QADataset(contexts, questions, answers, tokenizer, max_length=512)\n", "\n", "# Define training arguments\n", "training_args = TrainingArguments(\n", " output_dir=\"./output\", # Specify output directory\n", " per_device_train_batch_size=8,\n", " num_train_epochs=3,\n", " logging_dir='./logs',\n", ")\n", "\n", "# Define trainer\n", "trainer = Trainer(\n", " model=model,\n", " args=training_args,\n", " train_dataset=train_dataset,\n", ")\n", "\n", "# Fine-tune the model\n", "trainer.train()\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 173 }, "id": "mz8PfQ1QoMWI", "outputId": "01a4fd93-7f83-4702-c283-02dbba627050" }, "execution_count": 11, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", "
\n", " \n", " \n", " [1191/1191 04:44, Epoch 3/3]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining Loss
5001.654000
10001.139700

" ] }, "metadata": {} }, { "output_type": "execute_result", "data": { "text/plain": [ "TrainOutput(global_step=1191, training_loss=1.330872261053769, metrics={'train_runtime': 286.4473, 'train_samples_per_second': 33.21, 'train_steps_per_second': 4.158, 'total_flos': 2272099824055152.0, 'train_loss': 1.330872261053769, 'epoch': 3.0})" ] }, "metadata": {}, "execution_count": 11 } ] }, { "cell_type": "code", "source": [ "# Save the fine-tuned model\n", "trainer.save_model(\"/content/model\")" ], "metadata": { "id": "am9J2LqIohcT" }, "execution_count": 12, "outputs": [] }, { "cell_type": "code", "source": [ "from transformers import ElectraForQuestionAnswering, ElectraForSequenceClassification, AutoTokenizer, pipeline\n", "from preprocess import ArabertPreprocessor\n", "\n", "# Define ArabertPreprocessor if not already defined\n", "prep_object = ArabertPreprocessor(\"araelectra-base-discriminator\")\n", "\n", "# Preprocess the question and context\n", "question = ('ماذا اكلت اليوم ؟')\n", "context = ('''\n", "اليوم اكلت تفاحة''')\n", "\n", "# a) Get predictions\n", "qa_modelname = 'ZeyadAhmed/AraElectra-Arabic-SQuADv2-QA'\n", "qa_pipe = pipeline('question-answering', model=qa_modelname, tokenizer=qa_modelname)\n", "\n", "QA_input = {\n", " 'question': question,\n", " 'context': context\n", "}\n", "qa_res = qa_pipe(QA_input)\n", "threshold = 0.5 #hyperparameter can be tweaked\n", "## note classification results label0 probability it can be answered label1 probability can't be answered\n", "## if label1 probability > threshold then consider the output of qa_res is empty string else take the qa_res\n", "# b) Load model & tokenizer\n", "qa_model = ElectraForQuestionAnswering.from_pretrained(qa_modelname)\n", "tokenizer = AutoTokenizer.from_pretrained(qa_modelname)" ], "metadata": { "id": "TYWogpVE3tqr" }, "execution_count": null, "outputs": [] } ] }