{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "md-title",
   "metadata": {},
   "source": [
    "# Category classification — inference\n",
    "\n",
    "Loads a fine-tuned sequence-classification checkpoint, splits the input rows by whether `pri_cate_name` looks English, and predicts a category for each `product_name`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97211288-7031-49f0-9916-844e0f1c7128",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library\n",
    "import json\n",
    "import os\n",
    "import random\n",
    "import re\n",
    "import sys\n",
    "from collections import OrderedDict\n",
    "\n",
    "# Third-party\n",
    "import pandas as pd\n",
    "import torch\n",
    "from sklearn.metrics import accuracy_score, precision_recall_fscore_support\n",
    "from transformers import (\n",
    "    AutoModelForSequenceClassification,\n",
    "    AutoTokenizer,\n",
    "    BertModel,\n",
    "    BertTokenizer,\n",
    "    DataCollatorWithPadding,\n",
    "    Trainer,\n",
    "    TrainingArguments,\n",
    "    pipeline,\n",
    ")\n",
    "\n",
    "# NOTE(review): random, sys, OrderedDict, the sklearn metrics, BertModel, BertTokenizer,\n",
    "# DataCollatorWithPadding, Trainer and TrainingArguments are not used by the cells below."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-load",
   "metadata": {},
   "source": [
    "## Load model, data and label map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68632e9b-1933-4f62-a32e-d111dc2964bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The checkpoint and the label map are resolved against the kernel's working directory.\n",
    "home_path = os.getcwd()\n",
    "home_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "005d6545-4e86-496b-a61b-246c45a596f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The directory name (including its \"calssification\" spelling) must match the checkpoint on disk.\n",
    "model_path = os.path.join(home_path, \"cate_calssification_en_thai_checkpoint_150320\")\n",
    "assert os.path.isdir(model_path), f\"model checkpoint not found: {model_path}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c80d4a18-be09-477c-b352-7fd382775523",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = AutoModelForSequenceClassification.from_pretrained(model_path)\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6a3e36a-5284-4913-9b53-9e1ea5a05692",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input CSV to classify; the cells below read its `pri_cate_name` and `product_name` columns.\n",
    "DATA_PATH = \"your_data.csv\"  # TODO: point at the real validation file\n",
    "val_df = pd.read_csv(DATA_PATH)\n",
    "\n",
    "# Load the label-id / category-name mappings (JSON object keys are always strings).\n",
    "with open(os.path.join(home_path, \"label_name.json\"), \"r\", encoding=\"utf-8\") as f:\n",
    "    cat_dic = json.load(f)\n",
    "label_name_dict = cat_dic[\"label_name_dict\"]  # label id (str) -> category name\n",
    "name_label_dict = cat_dic[\"name_label_dict\"]  # presumably the inverse mapping; unused below\n",
    "\n",
    "# Use the first GPU when one is present; otherwise run on CPU (-1) instead of failing.\n",
    "device = 0 if torch.cuda.is_available() else -1\n",
    "classifier = pipeline(\"text-classification\", model=model, tokenizer=tokenizer, device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-helpers",
   "metadata": {},
   "source": [
    "## Helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01613942-c304-4980-92bf-05b3da07fdae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ASCII letters, whitespace and common punctuation only; used by is_english.\n",
    "EN_PREFIX_RE = re.compile(r'[A-Za-z\\s.,!?\\'\\\";:-]*')\n",
    "\n",
    "\n",
    "def _label_to_name(raw_label):\n",
    "    \"\"\"Map a pipeline label such as \"LABEL_3\" to its category name.\n",
    "\n",
    "    Assumes the model config uses the default ``LABEL_<id>`` names; the ``<id>``\n",
    "    suffix (a string) is looked up in ``label_name_dict``.\n",
    "    \"\"\"\n",
    "    return label_name_dict[raw_label.rsplit(\"_\", 1)[-1]]\n",
    "\n",
    "\n",
    "def get_label(name):\n",
    "    \"\"\"Classify a single product name and return its category name.\"\"\"\n",
    "    # truncation=True keeps over-long names from exceeding the model's max input length.\n",
    "    return _label_to_name(classifier(name, truncation=True)[0][\"label\"])\n",
    "\n",
    "\n",
    "def is_english(text, prefix_len=4):\n",
    "    \"\"\"Heuristically decide whether ``text`` is English.\n",
    "\n",
    "    Returns True when the first ``prefix_len`` characters consist only of ASCII\n",
    "    letters, whitespace and common punctuation (an empty string also passes).\n",
    "    Non-string values, e.g. NaN from an empty CSV cell, return False.\n",
    "    \"\"\"\n",
    "    if not isinstance(text, str):\n",
    "        return False\n",
    "    return EN_PREFIX_RE.fullmatch(text[:prefix_len]) is not None"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-split",
   "metadata": {},
   "source": [
    "## Split rows by language"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cd99694-d8c6-4a6e-a592-6444d7226bd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Language is decided from the primary category name; prediction later uses product_name.\n",
    "val_df[\"is_en\"] = val_df[\"pri_cate_name\"].apply(is_english).astype(bool)\n",
    "\n",
    "# .copy() gives independent frames, so adding columns later does not trigger SettingWithCopyWarning.\n",
    "en_df = val_df[val_df[\"is_en\"]].copy()\n",
    "th_df = val_df[~val_df[\"is_en\"]].copy()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-predict",
   "metadata": {},
   "source": [
    "## Predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f317115-278b-421d-b42c-c56d91064417",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_df(df, text_col=\"product_name\", batch_size=32):\n",
    "    \"\"\"Return a copy of ``df`` with a ``predict_label`` column of predicted category names.\n",
    "\n",
    "    The texts in ``text_col`` are classified with a single batched pipeline call\n",
    "    (``batch_size`` texts per forward pass) and truncated to the model's maximum\n",
    "    input length. The input frame is not modified.\n",
    "    \"\"\"\n",
    "    texts = df[text_col].tolist()\n",
    "    # Skip the pipeline for an empty frame; it then just gets an empty column.\n",
    "    outputs = classifier(texts, batch_size=batch_size, truncation=True) if texts else []\n",
    "    return df.assign(predict_label=[_label_to_name(out[\"label\"]) for out in outputs])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e673c24-aece-44a7-aff1-fa5de49cc8e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict a category for each row whose pri_cate_name passed is_english.\n",
    "en_pred_df = predict_df(en_df)\n",
    "en_pred_df.head()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "vllm_xtts",
   "language": "python",
   "name": "myenv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}