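{ "cells": [ { "cell_type": "markdown", "id": "212e1052-e0d9-404f-a4ee-db199a4c6d17", "metadata": {}, "source": [ "# 3.2 Sequence Structure Prediction" ] }, { "cell_type": "markdown", "id": "0eb5d83c-8dd6-498b-adc9-1f74c97c3427", "metadata": {}, "source": [ "Protein structure is described at four levels:\n", "\n", "1. The primary structure is the amino acid sequence.\n", "2. The secondary structure consists of regular, repeating local conformations such as α-helices and β-sheets.\n", "3. The tertiary structure is the three-dimensional structure of an entire polypeptide chain.\n", "4. The quaternary structure is the structure of a complex formed by several protein chains, such as a trimer or tetramer.\n", "\n", "\n", "Secondary structure refers to the local, regular spatial conformations found in biological macromolecules such as proteins and nucleic acids (RNA and DNA). These conformations are stabilized by bonds and interactions within the molecule, and they do not describe the overall fold of the whole molecule. The following is a brief introduction to protein secondary structure:\n", "\n", "### Protein secondary structure\n", "\n", "Protein secondary structure is formed mainly by hydrogen bonds between backbone atoms. The common types are:\n", "\n", "1. **α-helix (Alpha Helix)**\n", "   - **Description**: A right-handed helix in which each residue rotates about 100 degrees around the helix axis and rises about 1.5 Å along it.\n", "   - **Features**: Stabilized by hydrogen bonds between the backbone C=O of residue i and the N-H of residue i+4; one turn contains about 3.6 residues.\n", "\n", "2. **β-sheet (Beta Sheet)**\n", "   - **Description**: Formed by several polypeptide strands arranged almost parallel or antiparallel to each other and connected by hydrogen bonds between strands.\n", "   - **Features**: Can be parallel (all strands run in the same direction) or antiparallel (adjacent strands run in opposite directions), and forms a fairly rigid sheet.\n", "\n", "3. **Turns**\n", "   - **Description**: Short segments, typically 3 to 4 residues, that change the direction of the polypeptide chain.\n", "   - **Features**: The most common type is the β-turn (beta turn), which lets the chain fold back on itself.\n", "\n", "4. **Random coil (Random Coil)**\n", "   - **Description**: Regions with no fixed pattern, possibly because they lack enough hydrogen bonds or other stabilizing forces.\n", "   - **Features**: Although called \"random\", these regions may be functionally important in specific environments.\n", "\n", "\n", "\n", "\n", "Protein secondary structure is often drawn as a diagram. In the figure below, yellow arrows mark residues in β-strands, wavy lines mark helices, and small bumps mark turns. A secondary structure string written in letters gives a more precise description.\n", "In that string, E stands for β-strand, H for α-helix, and T for turn. Positions with no letter are loose coil. Many sequence prediction datasets do not distinguish turns from coil.\n", "\n", "\n", "\n" ] }, { "cell_type": "markdown", "id": "c90a583c-f6a5-4a41-8e7e-da27b7e95c50", "metadata": {}, "source": [ "Experimentally determined secondary structure data for proteins or RNA usually comes from laboratory techniques and from published experimental results in public databases. The following commonly used resources and methods can help you obtain experimentally validated secondary structure data:\n", "\n", "### 1. **Protein secondary structure data**\n", "\n", "#### a. **PDB (Protein Data Bank)**\n", "\n", "- **URL**: [RCSB PDB](https://www.rcsb.org/)\n", "- **Features**: PDB is a global database of biological macromolecular structures. It contains three-dimensional protein structures determined by experimental methods such as X-ray crystallography, nuclear magnetic resonance (NMR), and cryo-electron microscopy (Cryo-EM).\n", "- **Usage**:\n", "  - Search by the PDB ID or name of a specific protein.\n", "  - Open the detailed entry page, which contains the tertiary structure; visualization tools such as PyMOL or Chimera can display the secondary structure elements (α-helices, β-sheets, and so on).\n", "\n", "\n", "\n", "Source: https://www.rcsb.org/sequence/9rsa\n", "\n", "#### b. 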
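**PDBe (Protein Data Bank in Europe)**\n", "\n", "- **URL**: [PDBe](https://www.ebi.ac.uk/pdbe/)\n", "- **Features**: PDBe is the European member of the worldwide PDB. It offers functionality similar to RCSB PDB, plus additional analysis tools and annotations.\n", "- **Usage**:\n", "  - Search by the PDB ID or name of the protein.\n", "  - Use PDBe-KB and other tools to obtain detailed structural information and secondary structure annotations.\n", "\n", "#### c. **PDBj (Protein Data Bank Japan)**\n", "\n", "- **URL**: [PDBj](https://pdbj.org/)\n", "- **Features**: PDBj is the Japanese member of the worldwide PDB and likewise provides rich structural data and analysis tools.\n", "- **Usage**:\n", "  - Search by the PDB ID or name of the protein.\n", "  - Browse the entry for detailed structural information and secondary structure annotations.\n", "\n", "\n", "\n", "### 2. **Experimental methods**\n", "\n", "If you need the latest secondary structure data, or data obtained under specific conditions, you may need to consult the experimental methods described in the literature. Some common techniques are:\n", "\n", "#### a. **X-ray crystallography**\n", "\n", "- **Principle**: Determines the three-dimensional structure by analyzing the diffraction pattern of a protein or RNA crystal.\n", "- **Application**: Suitable for molecules that form stable crystals.\n", "\n", "#### b. **Nuclear magnetic resonance (NMR)**\n", "\n", "- **Principle**: Uses NMR spectroscopy to determine the structure of molecules in solution.\n", "- **Application**: Suitable for smaller proteins and RNA molecules.\n", "\n", "#### c. **Cryo-electron microscopy (Cryo-EM)**\n", "\n", "- **Principle**: Determines molecular structure by imaging flash-frozen samples in an electron microscope.\n", "- **Application**: Suitable for large complexes and molecules that are hard to crystallize.\n", "\n", "\n", "\n", "### 3. **Literature search**\n", "\n", "#### a. **PubMed**\n", "\n", "- **URL**: [PubMed](https://pubmed.ncbi.nlm.nih.gov/)\n", "- **Features**: PubMed is a widely used biomedical literature database with a large number of research papers on the function and structure of proteins and RNA.\n", "- **Usage**:\n", "  - Search by keyword for experimental studies of a specific protein or RNA.\n", "  - Read the papers for detailed experimental data and secondary structure descriptions.\n", "\n", "### Summary\n", "\n", "Experimentally determined secondary structure data for proteins or RNA comes mainly from public databases such as PDB and NDB, which collect structures determined by a variety of experimental methods. Searching the literature is another important route and can turn up the latest results or results obtained under specific conditions. The experimental methods themselves, such as X-ray crystallography, NMR, and Cryo-EM, each have their own applicable scenarios and strengths.\n" ] }, { "cell_type": "markdown", "id": "1cadfd11-2130-429d-848f-39371356ca10", "metadata": {}, "source": [ "## Curated datasets\n", "\n", "https://huggingface.co/datasets/proteinea/secondary_structure_prediction\n", "\n", "\n", "\n", "https://huggingface.co/datasets/genbio-ai/rna-secondary-structure-prediction" ] }, { "cell_type": "code", "execution_count": 24, "id": "134a72e3-597a-446e-9193-d060a6e677f6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\"\\nimport os\\n\\n# Set the environment variable (AutoDL dedicated zones, other IDCs)\\nos.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'\\n\\n# 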
Print the environment variable to confirm it was set\\nprint(os.environ.get('HF_ENDPOINT'))\\n\"" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import subprocess\n", "import os\n", "# Set the proxy environment variables (AutoDL general regions)\n", "result = subprocess.run('bash -c \"source /etc/network_turbo && env | grep proxy\"', shell=True, capture_output=True, text=True)\n", "output = result.stdout\n", "for line in output.splitlines():\n", "    if '=' in line:\n", "        var, value = line.split('=', 1)\n", "        os.environ[var] = value\n", "\n", "\"\"\"\n", "import os\n", "\n", "# Set the environment variable (AutoDL dedicated zones, other IDCs)\n", "os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'\n", "\n", "# Print the environment variable to confirm it was set\n", "print(os.environ.get('HF_ENDPOINT'))\n", "\"\"\"" ] }, { "cell_type": "code", "execution_count": 25, "id": "b43dd5f2-6b23-4b51-ad04-7b7ded732cb7", "metadata": {}, "outputs": [], "source": [ "from transformers import AutoTokenizer, AutoModel, TrainingArguments, Trainer\n", "from tokenizers import Tokenizer\n", "from transformers import GPT2LMHeadModel, AutoConfig, GPT2Tokenizer\n", "from transformers import AutoModelForTokenClassification\n", "from transformers import DataCollatorWithPadding" ] }, { "cell_type": "code", "execution_count": 26, "id": "4c66fa5b-b8b8-4dfd-ada1-32ed9e690c33", "metadata": {}, "outputs": [], "source": [ "# Set up the tokenizer (DNA/protein); GPT-2 has no pad token, so reuse EOS for padding\n", "tokenizer = GPT2Tokenizer.from_pretrained(\"dnagpt/gene_eng_gpt2_v0\")\n", "tokenizer.pad_token = tokenizer.eos_token" ] }, { "cell_type": "code", "execution_count": 27, "id": "70a3fd79-48bf-4452-a7ee-689f1b11e987", "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset\n", "# 1. 
Load the ~11k-sample protein secondary structure dataset and hold out 10% as a test split\n", "dataset = load_dataset(\"proteinea/secondary_structure_prediction\")['train'].train_test_split(test_size=0.1)" ] }, { "cell_type": "code", "execution_count": 28, "id": "13cd141e-98c3-47da-8e21-cba5576707fe", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", "    train: Dataset({\n", "        features: ['input', 'dssp3', 'dssp8', 'disorder', 'cb513_mask'],\n", "        num_rows: 9712\n", "    })\n", "    test: Dataset({\n", "        features: ['input', 'dssp3', 'dssp8', 'disorder', 'cb513_mask'],\n", "        num_rows: 1080\n", "    })\n", "})" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset" ] }, { "cell_type": "code", "execution_count": 29, "id": "7936af74-3f5f-43c1-aa69-fd7b08989e24", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'input': 'MTQTQPVTPTPPASFQTQHDPRTRLGATPLPGGAGTRFRLWTSTARTVAVRVNGTEHVMTSLGGGIYELELPVGPGARYLFVLDGVPTPDPYARFLPDGVHGEAEVVDFGTFDWTDADWHGIKLADCVFYEVHVGTFTPEGTYRAAAEKLPYLKELGVTAIQVMPLAAFDGQRGWGYDGAAFYAPYAPYGRPEDLMALVDAAHRLGLGVFLDVVYNHFGPSGNYLSSYAPSYFTDRFSSAWGMGLDYAEPHMRRYVTGNARMWLRDYHFDGLRLDATPYMTDDSETHILTELAQEIHELGGTHLLLAEDHRNLPDLVTVNHLDGIWTDDFHHETRVTLTGEQEGYYAGYRGGAEALAYTIRRGWRYEGQFWAVKGEEHERGHPSDALEAPNFVYCIQNHDQIGNRPLGERLHQSDGVTLHEYRGAAALLLTLPMTPLLFQGQEWAASTPFQFFSDHAGELGQAVSEGRKKEFGGFSGFSGEDVPDPQAEQTFLNSKLNWAEREGGEHARTLRLYRDLLRLRREDPVLHNRQRENLTTGHDGDVLWVRTVTGAGERVLLWNLGQDTRAVAEVKLPFTVPRRLLLHTEGREDLTLGAGEAVLVG',\n", " 'dssp3': 
'CCCCCCCCCCCCCCCCCCCCHHHCCEEEECHHHCCEEEEEECCCCCCEEEEECCEEEECEEEECCEEEEEECCCCCCEEEEEECCEEECCCCCCCCCCCCCCCEECCCCCCCCCCCCCCCCCCHHHCCEEEECHHHHCCCCCHHHHHHCHHHHHHHCCCEEEECCCEECCCCCCCCCCCCEEEEECHHHCCHHHHHHHHHHHHHCCCEEEEEECCCCCCCCCCCHHHHCHHHEEEEEECCCCEEECCCCHHHHHHHHHHHHHHHHHHCCCEEEECCHHHCCCCCCCCHHHHHHHHHHCCCCCCEEEEECCCCCCHHHHCCCCCEEECCHHHHHHHHHHHCCCCHHHHHCCCCHHHHHHHHHHCCCCEEEEECCCCCCEEEECCCCCCCHHHEEEECCCHHHHHCCCCCCCHHHCCCCCHHHHHHHHHHHHHCCCEEEEECCHHHCCCCCCCCCCCCCHHHHHHHHHHHHHHCCCCCCCCCCCCCCCCCHHHHHCCCCCCHHHHCHHHHHHHHHHHHHHHHHHHCCCCCCCCHHHEEEEEECCEEEEEEEECCEEEEEEEECCCCCEEHHHCCCCCCCCCCEEEECCCCCCCEECCCCEEEEC',\n", " 'dssp8': 'CCCCCCCCCCCCCCCCCSCCGGGCSEEEECGGGCCEEEEEECSSCSSEEEEETTEEEECEEEETTEEEEEESCCTTCEEEEEETTEEECCTTCSCCTTCTTSCEECCCTTSSCCCCTTCCCCCGGGCCEEEECHHHHSSSCSHHHHHHTHHHHHHHTCCEEEECCCEECSSSCCCSTTCCEEEEECGGGCCHHHHHHHHHHHHHTTCEEEEEECCSCCCSSSCCHHHHCGGGEEEEEECSSSEEECTTSHHHHHHHHHHHHIIIIIHCCSEEEETTGGGCCCCSSSCHHHHHHHHHHTTCSCCEEEEECSSCCTHHHHTTCCSEEECTHHHHHHHHHHHCCCSGGGGGCCCSHHHHHHHHHHSSSCEEEEECCTTCCEEEECCCTTCCGGGEEEESCCHHHHHTSTTCCCGGGSTTCCHHHHHHHHHHHHHSSSEEEEETTGGGTCSSCCCCCCCCCHHHHHHHHHHHHHHCCCCCCCCCCCCCCTTSHHHHHTTSCCSGGGGSHHHHHHHHHHHHHHHHHHHCTTTTCCCGGGEEEEEETTEEEEEEEETTEEEEEEEECSSSCEEGGGSCCSSCCCCCEEEETTCCSSSEECTTCEEEEC',\n", " 'disorder': '0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0',\n", " 'cb513_mask': '1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0'}" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset[\"train\"][0]" ] }, { "cell_type": "code", "execution_count": 30, "id": "47b1ac0c-e934-4ac3-b869-509515b15aa1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "protein dataset mean token length 96.07685185185186 min token length 7 max token length 576\n" ] } ], "source": [ "# Measure how many tokens each test sequence produces.\n", "token_len_list = []\n", "for item in dataset[\"test\"]:\n", "    inputs = tokenizer.tokenize(item[\"input\"])\n", "    token_len_list.append(len(inputs))\n", "\n", "mean_len = sum(token_len_list) / len(token_len_list)\n", "min_len = min(token_len_list)\n", "max_len = max(token_len_list)\n", "\n", "print(\"protein dataset\", \"mean token length\", mean_len, \"min token length\", min_len, \"max token length\", max_len)" ] }, { "cell_type": "code", "execution_count": 31, "id": "1b32de6e-fe08-426e-983e-7dd157c9af62", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of unique labels: 3\n", "Label to ID mapping: {'C': 0, 'H': 1, 'E': 2, '': 3}\n" ] } ], "source": [ "from collections import Counter\n", "\n", "# Confirm the number of labels and create a mapping from string labels to integer IDs.\n", "all_labels = [label for item in dataset[\"train\"] for label in item[\"dssp3\"]]\n", "label_counts = Counter(all_labels)\n", "num_labels = len(label_counts)\n", "\n", "# Define a special ID for padding. 
Make sure this ID is not used by any actual label.\n", "# If you have 3 classes, start with 3 or higher.\n", "pad_token_label_id = num_labels # Assuming no other labels have this ID.\n", "\n", "label_to_id = {label: i for i, (label, _) in enumerate(label_counts.items())}\n", "label_to_id[''] = pad_token_label_id # Add padding token to the mapping.\n", "id_to_label = {v: k for k, v in label_to_id.items()}\n", "\n", "print(f\"Number of unique labels: {num_labels}\")\n", "print(\"Label to ID mapping:\", label_to_id)" ] }, { "cell_type": "code", "execution_count": 32, "id": "2bd65f47-3325-4357-a896-9a0abf160e8a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of GPT2ForTokenClassification were not initialized from the model checkpoint at dnagpt/gene_eng_gpt2_v0 and are newly initialized: ['classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "#set model\n", "#model = AutoModelForTokenClassification.from_pretrained('dnagpt/gene_eng_gpt2_v0', )\n", "model = AutoModelForTokenClassification.from_pretrained(\n", " 'dnagpt/gene_eng_gpt2_v0',\n", " num_labels=num_labels + 1, # Include the padding label in the count.\n", " id2label=id_to_label,\n", " label2id=label_to_id\n", ")" ] }, { "cell_type": "code", "execution_count": 33, "id": "e247ac1e-bcd4-4aaf-9f91-dc939e5abe89", "metadata": {}, "outputs": [], "source": [ "# 5. 
Preprocess the data\n", "from transformers import DataCollatorForTokenClassification\n", "import torch\n", "# Define the maximum sequence length based on your model or dataset requirements.\n", "max_seq_length = 128  # Adjust this value as needed.\n", "\n", "def preprocess_function(examples):\n", "    tokenized_inputs = tokenizer(\n", "        examples[\"input\"],\n", "        truncation=True,\n", "        padding='max_length',\n", "        max_length=max_seq_length,\n", "        return_tensors=\"pt\"  # Return PyTorch tensors directly.\n", "    )\n", "\n", "    # Caveat: dssp3 has one label per residue, while the BPE tokenizer can merge several\n", "    # residues into one token, so position i of label_ids need not describe token i.\n", "    labels = []\n", "    for label in examples['dssp3']:\n", "        label_ids = [label_to_id.get(l, pad_token_label_id) for l in label]\n", "        # Ensure labels are padded/truncated to the same length as inputs.\n", "        if len(label_ids) > max_seq_length:\n", "            label_ids = label_ids[:max_seq_length]\n", "        else:\n", "            label_ids = label_ids + [pad_token_label_id] * (max_seq_length - len(label_ids))\n", "\n", "        labels.append(label_ids)\n", "\n", "    tokenized_inputs[\"labels\"] = torch.tensor(labels)\n", "\n", "    return tokenized_inputs" ] }, { "cell_type": "code", "execution_count": 34, "id": "8144d093-e8d3-41ff-ae4f-82aa1f28d689", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "707978d4f8304cada1041f8e794d79b7", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/9712 [00:00\n", " \n", " \n", " [ 8001/12140 09:41 < 05:00, 13.76 it/s, Epoch 13.18/20]\n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
| Epoch | Training Loss | Validation Loss | Precision | Recall | F1 | Accuracy |
|---|---|---|---|---|---|---|
| 1 | 1.102200 | 0.923186 | 0.314843 | 0.125521 | 0.179485 | 0.503214 |
| 2 | 0.942100 | 0.883362 | 0.357413 | 0.153466 | 0.214731 | 0.521366 |
| 3 | 0.898500 | 0.895442 | 0.355443 | 0.194234 | 0.251199 | 0.522545 |
| 4 | 0.870200 | 0.891230 | 0.367170 | 0.050731 | 0.089145 | 0.526761 |
| 5 | 0.831900 | 0.890030 | 0.373252 | 0.197096 | 0.257971 | 0.530358 |
| 6 | 0.815800 | 0.867876 | 0.378696 | 0.236628 | 0.291262 | 0.540153 |
| 7 | 0.800900 | 0.873521 | 0.380925 | 0.212640 | 0.272927 | 0.544393 |
| 8 | 0.785100 | 0.872138 | 0.385372 | 0.156363 | 0.222462 | 0.547684 |
| 9 | 0.774100 | 0.885855 | 0.384813 | 0.180280 | 0.245531 | 0.549681 |
| 10 | 0.750800 | 0.884582 | 0.388464 | 0.206529 | 0.269681 | 0.555933 |
| 11 | 0.737500 | 0.886323 | 0.396929 | 0.202713 | 0.268369 | 0.557624 |
| 12 | 0.731000 | 0.878285 | 0.365956 | 0.315728 | 0.338991 | 0.555857 |
| 13 | 0.708900 | 0.912278 | 0.377030 | 0.249346 | 0.300174 | 0.555030 |

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/root/miniconda3/lib/python3.12/site-packages/seqeval/metrics/sequence_labeling.py:171: UserWarning: C seems not to be NE tag.\n", " warnings.warn('{} seems not to be NE tag.'.format(chunk))\n", "/root/miniconda3/lib/python3.12/site-packages/seqeval/metrics/sequence_labeling.py:171: UserWarning: H seems not to be NE tag.\n", " warnings.warn('{} seems not to be NE tag.'.format(chunk))\n", "/root/miniconda3/lib/python3.12/site-packages/seqeval/metrics/sequence_labeling.py:171: UserWarning: seems not to be NE tag.\n", " warnings.warn('{} seems not to be NE tag.'.format(chunk))\n", "/root/miniconda3/lib/python3.12/site-packages/seqeval/metrics/v1.py:57: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, msg_start, len(result))\n", "/root/miniconda3/lib/python3.12/site-packages/seqeval/metrics/sequence_labeling.py:171: UserWarning: C seems not to be NE tag.\n", " warnings.warn('{} seems not to be NE tag.'.format(chunk))\n", "/root/miniconda3/lib/python3.12/site-packages/seqeval/metrics/sequence_labeling.py:171: UserWarning: H seems not to be NE tag.\n", " warnings.warn('{} seems not to be NE tag.'.format(chunk))\n", "/root/miniconda3/lib/python3.12/site-packages/seqeval/metrics/sequence_labeling.py:171: UserWarning: seems not to be NE tag.\n", " warnings.warn('{} seems not to be NE tag.'.format(chunk))\n", "/root/miniconda3/lib/python3.12/site-packages/seqeval/metrics/v1.py:57: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 in labels with no true samples. 
Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, msg_start, len(result))\n" ] }, { "ename": "RuntimeError", "evalue": "[enforce fail at inline_container.cc:595] . 
unexpected pos 1216226560 vs 1216226452", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", "File \u001b[0;32m~/miniconda3/lib/python3.12/site-packages/torch/serialization.py:628\u001b[0m, in \u001b[0;36msave\u001b[0;34m(obj, f, pickle_module, pickle_protocol, _use_new_zipfile_serialization, _disable_byteorder_record)\u001b[0m\n\u001b[1;32m 627\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m _open_zipfile_writer(f) \u001b[38;5;28;01mas\u001b[39;00m opened_zipfile:\n\u001b[0;32m--> 628\u001b[0m \u001b[43m_save\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mopened_zipfile\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpickle_module\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpickle_protocol\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m_disable_byteorder_record\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 629\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n", "File \u001b[0;32m~/miniconda3/lib/python3.12/site-packages/torch/serialization.py:862\u001b[0m, in \u001b[0;36m_save\u001b[0;34m(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_record)\u001b[0m\n\u001b[1;32m 861\u001b[0m num_bytes \u001b[38;5;241m=\u001b[39m storage\u001b[38;5;241m.\u001b[39mnbytes()\n\u001b[0;32m--> 862\u001b[0m \u001b[43mzip_file\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwrite_record\u001b[49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstorage\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_bytes\u001b[49m\u001b[43m)\u001b[49m\n", "\u001b[0;31mRuntimeError\u001b[0m: [enforce fail at inline_container.cc:764] . 
PytorchStreamWriter failed writing file data/94: file write failed", "\nDuring handling of the above exception, another exception occurred:\n", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[38], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Start training\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/miniconda3/lib/python3.12/site-packages/transformers/trainer.py:2164\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 2162\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n\u001b[1;32m 2163\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 2164\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2165\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2166\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2167\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2168\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2169\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/miniconda3/lib/python3.12/site-packages/transformers/trainer.py:2591\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, 
ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 2589\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mepoch \u001b[38;5;241m=\u001b[39m epoch \u001b[38;5;241m+\u001b[39m (step \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m steps_skipped) \u001b[38;5;241m/\u001b[39m steps_in_epoch\n\u001b[1;32m 2590\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_step_end(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n\u001b[0;32m-> 2591\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_maybe_log_save_evaluate\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2592\u001b[0m \u001b[43m \u001b[49m\u001b[43mtr_loss\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrad_norm\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepoch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstart_time\u001b[49m\n\u001b[1;32m 2593\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2594\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 2595\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_substep_end(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n", "File 
\u001b[0;32m~/miniconda3/lib/python3.12/site-packages/transformers/trainer.py:3056\u001b[0m, in \u001b[0;36mTrainer._maybe_log_save_evaluate\u001b[0;34m(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time)\u001b[0m\n\u001b[1;32m 3053\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol\u001b[38;5;241m.\u001b[39mshould_save \u001b[38;5;241m=\u001b[39m is_new_best_metric\n\u001b[1;32m 3055\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol\u001b[38;5;241m.\u001b[39mshould_save:\n\u001b[0;32m-> 3056\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_save_checkpoint\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3057\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_save(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n", "File \u001b[0;32m~/miniconda3/lib/python3.12/site-packages/transformers/trainer.py:3192\u001b[0m, in \u001b[0;36mTrainer._save_checkpoint\u001b[0;34m(self, model, trial)\u001b[0m\n\u001b[1;32m 3188\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msave_model(output_dir, _internal_call\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 3190\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39msave_only_model:\n\u001b[1;32m 3191\u001b[0m \u001b[38;5;66;03m# Save optimizer and scheduler\u001b[39;00m\n\u001b[0;32m-> 3192\u001b[0m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_save_optimizer_and_scheduler\u001b[49m\u001b[43m(\u001b[49m\u001b[43moutput_dir\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3193\u001b[0m \u001b[38;5;66;03m# Save RNG state\u001b[39;00m\n\u001b[1;32m 3194\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_save_rng_state(output_dir)\n", "File \u001b[0;32m~/miniconda3/lib/python3.12/site-packages/transformers/trainer.py:3313\u001b[0m, in \u001b[0;36mTrainer._save_optimizer_and_scheduler\u001b[0;34m(self, output_dir)\u001b[0m\n\u001b[1;32m 3308\u001b[0m save_fsdp_optimizer(\n\u001b[1;32m 3309\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfsdp_plugin, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptimizer, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel, output_dir\n\u001b[1;32m 3310\u001b[0m )\n\u001b[1;32m 3311\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mshould_save:\n\u001b[1;32m 3312\u001b[0m \u001b[38;5;66;03m# deepspeed.save_checkpoint above saves model/optim/sched\u001b[39;00m\n\u001b[0;32m-> 3313\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstate_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[43moutput_dir\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mOPTIMIZER_NAME\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3315\u001b[0m \u001b[38;5;66;03m# Save SCHEDULER & SCALER\u001b[39;00m\n\u001b[1;32m 3316\u001b[0m is_deepspeed_custom_scheduler \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mis_deepspeed_enabled \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\n\u001b[1;32m 3317\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlr_scheduler, DeepSpeedSchedulerWrapper\n\u001b[1;32m 3318\u001b[0m )\n", "File \u001b[0;32m~/miniconda3/lib/python3.12/site-packages/torch/serialization.py:627\u001b[0m, in \u001b[0;36msave\u001b[0;34m(obj, f, pickle_module, pickle_protocol, _use_new_zipfile_serialization, _disable_byteorder_record)\u001b[0m\n\u001b[1;32m 624\u001b[0m _check_save_filelike(f)\n\u001b[1;32m 626\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _use_new_zipfile_serialization:\n\u001b[0;32m--> 627\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mwith\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43m_open_zipfile_writer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mas\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mopened_zipfile\u001b[49m\u001b[43m:\u001b[49m\n\u001b[1;32m 628\u001b[0m \u001b[43m \u001b[49m\u001b[43m_save\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mopened_zipfile\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpickle_module\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpickle_protocol\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m_disable_byteorder_record\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 629\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mreturn\u001b[39;49;00m\n", "File \u001b[0;32m~/miniconda3/lib/python3.12/site-packages/torch/serialization.py:475\u001b[0m, in 
\u001b[0;36m_open_zipfile_writer_file.__exit__\u001b[0;34m(self, *args)\u001b[0m\n\u001b[1;32m 474\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__exit__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 475\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfile_like\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwrite_end_of_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 476\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfile_stream \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 477\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfile_stream\u001b[38;5;241m.\u001b[39mclose()\n", "\u001b[0;31mRuntimeError\u001b[0m: [enforce fail at inline_container.cc:595] . 
unexpected pos 1216226560 vs 1216226452" ] } ], "source": [ "# Start training\n", "trainer.train()" ] }, { "cell_type": "code", "execution_count": 39, "id": "950c460f-5631-4c9a-819b-1e3ac484cc65", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'eval_loss': 0.9122781157493591,\n", " 'eval_precision': 0.3770299145299145,\n", " 'eval_recall': 0.2493464283190843,\n", " 'eval_f1': 0.3001743716242079,\n", " 'eval_accuracy': 0.5550300748427384}" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "results = trainer.evaluate()\n", "results" ] }, { "cell_type": "code", "execution_count": 40, "id": "8174c1c6-a5bc-4fe3-8f9b-356625531e7d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ ">>> Perplexity: 2.49\n" ] } ], "source": [ "import math\n", "# Exponentiate the mean evaluation cross-entropy loss\n", "eval_results = trainer.evaluate()\n", "print(f\">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}\")" ] }, { "cell_type": "code", "execution_count": 41, "id": "6a22f131-9e5f-4125-942a-22d1b1e6373b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "('./secondary_structure_model/tokenizer_config.json',\n", " './secondary_structure_model/special_tokens_map.json',\n", " './secondary_structure_model/vocab.json',\n", " './secondary_structure_model/merges.txt',\n", " './secondary_structure_model/added_tokens.json')" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Save the model and tokenizer\n", "model.save_pretrained(\"./secondary_structure_model\")\n", "tokenizer.save_pretrained(\"./secondary_structure_model\")" ] }, { "cell_type": "code", "execution_count": 42, "id": "d5817a6c-c707-4005-9210-2a12ff0d43b0", "metadata": {}, "outputs": [], "source": [ "# Load the model and tokenizer\n", "model = AutoModelForTokenClassification.from_pretrained(\"./secondary_structure_model\")\n", "tokenizer = GPT2Tokenizer.from_pretrained(\"./secondary_structure_model\")" ] }, { "cell_type": "code", "execution_count": 43, "id":
"2f6ebdc6-8ff8-4947-ada4-05ff4b28e0f3", "metadata": {}, "outputs": [], "source": [ "# 进行预测\n", "def predict_secondary_structure(sequence):\n", " inputs = tokenizer(sequence, return_tensors=\"pt\", truncation=True, padding=True)\n", " outputs = model(**inputs)\n", " predictions = outputs.logits.argmax(dim=-1)\n", " return predictions" ] }, { "cell_type": "code", "execution_count": 44, "id": "841ebba8-7619-411f-a11e-841de3a3f064", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "tensor([[0, 0, 0, 0, 2, 2, 2, 2]])\n" ] } ], "source": [ "# 示例预测\n", "sequence = \"ACDEFGHIKLMNPQRSTVWY\"\n", "predictions = predict_secondary_structure(sequence)\n", "print(predictions)" ] }, { "cell_type": "code", "execution_count": null, "id": "37e7d22e-0545-422b-b8ba-7990ca127d8a", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 5 }