{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"#### Introduction to Fine-Tuning BERT for Predicting Affective Meanings\n",
"\n",
"This Jupyter notebook demonstrates how to fine-tune the BERT model to estimate affective meanings in language using advanced NLP techniques. The process is designed to refine BERT embeddings for specific tasks, continuing the work developed in BERTNN, a tool by Moeen Mostafavi, Michael D. Porter, and Dawn T. Robinson. \n",
"\n",
"This notebook borrows from a post by Chris McCormick, available [here](https://mccormickml.com/2019/05/14/BERT-word-embeddings-tutorial/), which provides a foundational tutorial on using BERT embeddings.\n",
"\n",
"Depending on your computational resources (e.g., CPU/GPU), the notebook might take a few minutes to execute. The functions defined herein can be used to generate estimated EPA (Evaluation, Potency, Activity) values, contributing to our ongoing exploration of affective meaning in language."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Loading the required libraries and selecting the compute device (GPU if available)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 162,
"referenced_widgets": [
"5b4469b58a744da1a15ea28a918ad1aa",
"081eda02be0b40edad3cbab832a79667",
"817e41eac37448fc82e1ab4a4eaa7f54",
"eb2d14c36ee74587afd67d216dcf5ef8",
"b95438470d39403491e43171994c2468",
"d1fc7c2fafd74effbf11b42c74ffb7b0",
"3e28697501b2497a94ca53174a5d647c",
"dc14d974b2734c6e997c415d885d1b79",
"f6953e66e3d445938d1650c850e6f45a",
"885e801a160c405d83d42f9dbea58ced",
"9df19fff43aa44d1b08816730ec83029",
"9faf62522de4493d9ee007edde7601de",
"b8346621449d4a09a0db59845e1e292e",
"16a8c19ffa4c4e55978cd6ee01c1f9f1",
"1ae4f6f212d8403ebf22ba448ce6a39d",
"59909d8ca5af415aae0e2e7f64001d78",
"77900d6e71924c998a5bd41d2097e4a0",
"1e89ca8823414adcb55b57953bd2b25b",
"8aefab5c69bb49e68aa1a7ce56c219cf",
"c3fbd6eef6014346a51db8a808867835",
"2cf178b819ca4a3db05cc9ad0d2490e4",
"860f138d9a4c44729176915246acfe3c",
"a082e16c77c7446086f61b4125867a82",
"e8c08401834f4724b1abcee5ad17293d",
"bb25c0c52c724b0f9af44199f5843ab4",
"5af534d8e04e4e769eecb9b71ce097b3",
"0f638f53620245d3b5d67f172a0d4abc",
"81b86cf2fbb443368ad2584d3e80ea7a",
"e8433bee864a4dcaa55f9a20345c79a5",
"4036232b409b467eaa5b9e4f2ddcced4",
"11075e61698a485896267f188e133958",
"c928829cd6394d50b991e569bfd5caf2",
"1e8f16fa716f466f8fd8772d765ad7a1",
"61c8d744d6be4250a9ecacd387114314",
"ed3def74372245c38b386b3b40ce4efc",
"0784db0a7fad434ea5c58e56f3f9db48",
"67b617180f634cd7a87825bd4b717f12",
"355008ec0b0e4fee9ad82cff6fce97ef",
"a997f2a2c54f4e1a840254b0a4ebfc2b",
"a9a34097523c4bcf9e6b100dc50bc3bf",
"dc05deaed5f04aea9270527e989e378d",
"4c94071de1ba46479762d0576e0a5dac",
"8deb2d8214b54a6ab0f7fac14c78b0dc",
"1ea6b565ffe945d5a4c527f5607e52fc"
]
},
"id": "lJEnBJ3gHTsQ",
"outputId": "81f967d3-346d-4f9d-dd90-da6fed7682e2"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"There are 2 GPU(s) available.\n",
"Device name: NVIDIA GeForce RTX 2080 Ti\n"
]
}
],
"source": [
"from datetime import datetime\n",
"import numpy as np\n",
"import pandas as pd\n",
"import random\n",
"from random import sample\n",
"import torch\n",
"from sklearn.linear_model import LinearRegression\n",
"from sklearn.preprocessing import PolynomialFeatures\n",
"\n",
"# Check if a GPU is available for PyTorch, otherwise use CPU\n",
"if torch.cuda.is_available(): \n",
" device = torch.device(\"cuda:0\")\n",
" print(f'There are {torch.cuda.device_count()} GPU(s) available.')\n",
" print('Device name:', torch.cuda.get_device_name(0))\n",
"\n",
"else:\n",
" print('No GPU available, using the CPU instead.')\n",
" device = torch.device(\"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Setting random seed for NumPy, random, and PyTorch to ensure reproducibility of results\n",
"rnd_st=42\n",
"np.random.seed(rnd_st)\n",
"random.seed(rnd_st)\n",
"torch.manual_seed(rnd_st)\n",
"torch.cuda.manual_seed(rnd_st)\n",
"# Ensuring the PyTorch backend uses deterministic algorithms for reproducibility\n",
"torch.backends.cudnn.deterministic = True\n",
"torch.backends.cudnn.benchmark = False\n",
"\n",
"# Set a fixed value for the hash seed Ref:https://wandb.ai/sauravmaheshkar/RSNA-MICCAI/reports/How-to-Set-Random-Seeds-in-PyTorch-and-Tensorflow--VmlldzoxMDA2MDQy\n",
"# Uncommenting the following lines would fix Python's hash seed across runs (useful for some random operations)\n",
"# import os\n",
"# os.environ[\"PYTHONHASHSEED\"] = str(rnd_st)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Clustering Concepts in the Affective Dictionaries\n",
"\n",
"Concepts in affective dictionaries are represented in a 3-dimensional EPA (Evaluation, Potency, Activity) space. To ensure that the training, test, and validation sets have similar distributions across these dimensions, it's important to balance the data. Previously, we used the 'elbow method' to determine the optimal number of clusters. In this step, we apply K-means clustering to group the concepts in the dictionaries into five distinct clusters."
]
},
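{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick illustration of the elbow method mentioned above, the sketch below runs KMeans for several cluster counts on synthetic EPA-like data (not the actual dictionaries) and inspects how the inertia levels off:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative elbow-method sketch on synthetic stand-in data for the E, P, A columns;\n",
"# the real dictionaries are loaded later, so the numbers here are for intuition only.\n",
"import numpy as np\n",
"from sklearn.cluster import KMeans\n",
"\n",
"rng = np.random.default_rng(42)\n",
"epa_demo = rng.normal(size=(200, 3))  # synthetic stand-in for E, P, A values\n",
"inertias = [KMeans(n_clusters=k, random_state=42, n_init=10).fit(epa_demo).inertia_\n",
"            for k in range(1, 9)]\n",
"# Inertia drops quickly at first and then flattens; the 'elbow' marks a reasonable k\n",
"print([round(i, 1) for i in inertias])"
]
},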
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Importing the KMeans clustering algorithm from scikit-learn\n",
"from sklearn.cluster import KMeans\n",
"\n",
"# Function to perform KMeans clustering on a DataFrame containing EPA (Evaluation, Potency, Activity) values\n",
"def add_cluster(df, n_clusters=5, seed_random=rnd_st):\n",
"    kmeans = KMeans(n_clusters=n_clusters, random_state=seed_random)\n",
"    # Perform clustering on the 'E', 'P', and 'A' columns of the DataFrame\n",
"    labels_clusters = kmeans.fit_predict(df[['E', 'P', 'A']])\n",
"    # Add the cluster labels as a new column 'cluster' to the original DataFrame\n",
"    return pd.concat([df, pd.DataFrame({'cluster': labels_clusters})], axis=1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Loading Affective Dictionaries\n",
"\n",
"We load the affective dictionaries into pandas DataFrames and add several additional columns.\n",
"- 'index_in_dic': The index of each concept in the dictionary. After training the model, we use this index to map EPA values back to their corresponding concepts.\n",
"- 'term2': A copy of the original 'term' with underscores (_) kept; in 'term' itself, underscores are replaced with spaces.\n",
"- 'len_Bert': The number of tokens the BERT tokenizer uses to represent each concept.\n",
"- 'cluster': The cluster index of the concept, created using KMeans clustering to group similar concepts."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\ProgramData\\Anaconda3\\envs\\py39latest\\lib\\site-packages\\sklearn\\cluster\\_kmeans.py:1334: UserWarning: KMeans is known to have a memory leak on Windows with MKL, when there are less chunks than available threads. You can avoid it by setting the environment variable OMP_NUM_THREADS=3.\n",
" warnings.warn(\n",
"c:\\ProgramData\\Anaconda3\\envs\\py39latest\\lib\\site-packages\\sklearn\\cluster\\_kmeans.py:1334: UserWarning: KMeans is known to have a memory leak on Windows with MKL, when there are less chunks than available threads. You can avoid it by setting the environment variable OMP_NUM_THREADS=4.\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Max length for ABMOs: 37\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\ProgramData\\Anaconda3\\envs\\py39latest\\lib\\site-packages\\sklearn\\cluster\\_kmeans.py:1334: UserWarning: KMeans is known to have a memory leak on Windows with MKL, when there are less chunks than available threads. You can avoid it by setting the environment variable OMP_NUM_THREADS=4.\n",
" warnings.warn(\n"
]
}
],
"source": [
"from transformers import BertTokenizer, BertModel\n",
"\n",
"# Load the BERT tokenizer (using a pre-trained 'bert-large-uncased' model and applying lowercase)\n",
"tokenizer = BertTokenizer.from_pretrained('bert-large-uncased', do_lower_case=True)\n",
"\n",
"# Function to load and process the dictionary from a CSV file\n",
"def load_dictionary(file):\n",
" df=pd.read_csv(file).reset_index().rename(columns={\"index\": 'index_in_dic'})\n",
" df['term2']=df['term']\n",
" df.term=df.term.str.replace(\"_\", \" \") \n",
" df['len_Bert']=df.apply(lambda x: len(tokenizer.tokenize(x['term'])),axis=1)\n",
" df=add_cluster(df)\n",
" return(df)\n",
"\n",
"Modifiers =load_dictionary(\"E:/ACT/FullSurveyorInteract_Modifiers.csv\")\n",
"Behaviors=load_dictionary(\"E:/ACT/mod_dicts/FullSurveyorInteract_Behaviors.csv\")\n",
"Identities=load_dictionary(\"E:/ACT/FullSurveyorInteract_Identities.csv\")\n",
"\n",
"# Print the maximum token length for ABMO events\n",
"# (modifier + actor identity + behavior + modifier + object identity)\n",
"print('Max length for ABMOs:', 2*np.max(Identities.len_Bert)+np.max(Behaviors.len_Bert)+2*np.max(Modifiers.len_Bert))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Stratified Sampling of Affective Dictionaries\n",
"\n",
"To ensure balanced and representative training, validation, and test sets, we perform stratified sampling on the affective dictionaries (Identities, Behaviors, and Modifiers). We want each set to have similar distributions across the clusters created in the EPA (Evaluation, Potency, Activity) space. This function splits the data into training, validation, and test sets by maintaining proportional representation of each cluster.\n",
"\n",
"- Identities, Behaviors, and Modifiers are clustered into five groups.\n",
"- The function splits each cluster into training, validation, and test sets based on the given proportions (section1 and section2).\n",
"\n",
"The stratified sampling ensures that all sets have consistent distributions of EPA dimensions for the concepts."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Behavior total 813 , Train: 649 , Validation: 96 , Test: 68\n"
]
}
],
"source": [
"# Function to perform stratified sampling on Identities, Behaviors, and Modifiers based on their clusters\n",
"def Stratified_sample(ID=Identities, BE=Behaviors, MO=Modifiers, section1=.8, section2=.92):\n",
"    # Initialize empty DataFrames for training (T), validation (V), and test (TE) sets for Identities, Behaviors, and Modifiers\n",
"    I_T,I_V,I_TE,B_T,B_V,B_TE,M_T,M_V,M_TE=pd.DataFrame(),pd.DataFrame(),pd.DataFrame(),pd.DataFrame(),pd.DataFrame(),pd.DataFrame(),pd.DataFrame(),pd.DataFrame(),pd.DataFrame()\n",
"    # Iterate over the 5 clusters\n",
"    for i in range(5):\n",
"        # Filter the data for the current cluster (i) for Identities, Behaviors, and Modifiers\n",
"        ident_tmp=ID.loc[ID.cluster==i,]\n",
"        beh_tmp=BE.loc[BE.cluster==i,]\n",
"        mod_tmp=MO.loc[MO.cluster==i,]\n",
"\n",
"        # Randomly shuffle and split the data for the current cluster into train, validation, and test sets\n",
"        I_train_tmp, I_v_tmp, I_test_tmp = np.split(ident_tmp.sample(frac=1, random_state=rnd_st), [int(section1*len(ident_tmp)), int(section2*len(ident_tmp))])\n",
"        B_train_tmp, B_v_tmp, B_test_tmp = np.split(beh_tmp.sample(frac=1, random_state=rnd_st), [int(section1*len(beh_tmp)), int(section2*len(beh_tmp))])\n",
"        M_train_tmp, M_v_tmp, M_test_tmp = np.split(mod_tmp.sample(frac=1, random_state=rnd_st), [int(section1*len(mod_tmp)), int(section2*len(mod_tmp))])\n",
" # Concatenate the current cluster's samples to the overall training, validation, and test sets\n",
" I_T,I_V,I_TE=pd.concat([I_T,I_train_tmp],axis=0),pd.concat([I_V,I_v_tmp],axis=0),pd.concat([I_TE,I_test_tmp],axis=0)\n",
" B_T,B_V,B_TE=pd.concat([B_T,B_train_tmp],axis=0),pd.concat([B_V,B_v_tmp],axis=0),pd.concat([B_TE,B_test_tmp],axis=0)\n",
" M_T,M_V,M_TE=pd.concat([M_T,M_train_tmp],axis=0),pd.concat([M_V,M_v_tmp],axis=0),pd.concat([M_TE,M_test_tmp],axis=0)\n",
" # Return the final stratified training, validation, and test sets for Identities, Behaviors, and Modifiers\n",
" return(I_T,I_V,I_TE,B_T,B_V,B_TE,M_T,M_V,M_TE)\n",
"# Call the stratified sampling function to split Identities, Behaviors, and Modifiers into training, validation, and test sets\n",
"I_train, I_v,I_test,B_train, B_v,B_test,M_train, M_v,M_test=Stratified_sample()\n",
"# Print the total number of behaviors and the number of samples in the train, validation, and test sets for Behaviors\n",
"print('Behavior total',len(Behaviors),', Train:',len(B_train),', Validation: ',len(B_v),', Test:',len(B_test))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Standardizing EPA Values Across Training, Validation, and Test Sets\n",
"To ensure consistency in the scale of EPA (Evaluation, Potency, Activity) values across the training, validation, and test sets for Identities, Behaviors, and Modifiers, we apply standardization. Standardization ensures that the EPA values for each concept are normalized to have a mean of 0 and a standard deviation of 1. This is particularly useful for machine learning algorithms that are sensitive to feature scaling.\n",
"- StandardScaler is used to fit the training data and apply the same transformation to the test and validation sets.\n",
"- Separate scalers are used for Behaviors, Modifiers, and Identities to account for potential differences in EPA value distributions."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" index_in_dic term E P \\\n",
"516 516 landlord -0.202782 1.051292 \n",
"701 701 radio and television announcer 0.194633 0.485202 \n",
"355 355 fan 0.777936 0.127671 \n",
"\n",
" A E2 P2 A2 term2 len_Bert \\\n",
"516 0.277485 0.14 2.00 0.67 landlord 1 \n",
"701 2.108632 0.76 1.24 2.44 radio_and_television_announcer 4 \n",
"355 1.901723 1.67 0.76 2.24 fan 1 \n",
"\n",
" cluster \n",
"516 0 \n",
"701 0 \n",
"355 0 "
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Importing StandardScaler from scikit-learn for standardizing the EPA values\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# Creating copies of the training, validation, and test sets for Behaviors, Modifiers, and Identities\n",
"n_B_train,n_B_test,n_B_v,n_M_train = B_train.copy(),B_test.copy(),B_v.copy(),M_train.copy()\n",
"n_M_test,n_M_v,n_I_train,n_I_test,n_I_v = M_test.copy(),M_v.copy(),I_train.copy(),I_test.copy(),I_v.copy()\n",
"\n",
"scaler_B,scaler_M,scaler_I = StandardScaler(),StandardScaler(),StandardScaler()\n",
"\n",
"\n",
"n_B_train[['E','P','A']] = scaler_B.fit_transform(B_train[['E','P','A']])\n",
"n_B_test[['E','P','A']] = scaler_B.transform(B_test[['E','P','A']])\n",
"n_B_v[['E','P','A']] = scaler_B.transform(B_v[['E','P','A']])\n",
"n_M_train[['E','P','A']] = scaler_M.fit_transform(M_train[['E','P','A']])\n",
"n_M_test[['E','P','A']] = scaler_M.transform(M_test[['E','P','A']])\n",
"n_M_v[['E','P','A']] = scaler_M.transform(M_v[['E','P','A']])\n",
"n_I_train[['E','P','A']] = scaler_I.fit_transform(I_train[['E','P','A']])\n",
"n_I_test[['E','P','A']] = scaler_I.transform(I_test[['E','P','A']])\n",
"n_I_v[['E','P','A']] = scaler_I.transform(I_v[['E','P','A']])\n",
"\n",
"n_I_train.head(3)"
]
},
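{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because the scalers are fitted only on the training sets, they can later map model outputs back to the original EPA scale via inverse_transform. A minimal, self-contained sketch of this round-trip on synthetic data (illustrative only):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch of the StandardScaler round-trip used for EPA values: fit on training\n",
"# data, transform, then recover the original scale with inverse_transform.\n",
"# The data here is synthetic, standing in for the E, P, A columns.\n",
"import numpy as np\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"rng = np.random.default_rng(42)\n",
"epa_train = rng.normal(loc=0.5, scale=1.3, size=(100, 3))\n",
"sc = StandardScaler().fit(epa_train)\n",
"z = sc.transform(epa_train)           # standardized: mean 0, std 1 per column\n",
"recovered = sc.inverse_transform(z)   # back on the original EPA scale\n",
"print(np.allclose(recovered, epa_train))"
]
},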
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Loading Impression Change and Amalgamation Equations\n",
"\n",
"In this section, we load impression change data and amalgamation equations that are used to calculate the deflection of events in affect control theory. We define a function called next_state, which computes the deflection of an Actor-Behavior-Object (ABO) event using the impression change equations.\n",
"\n",
"- m_ABO and f_ABO: These CSV files contain impression change equations for male and female actors, respectively.\n",
"- Amalgamation Equations: We define matrices that control how initial states are combined with modifiers to compute the final state of a given actor, behavior, or object.\n",
"- The function amalgamate merges the initial state and modifier into a final state using the amalgamation equations.\n",
"- The function next_state calculates the change in state (impression) based on input events using the impression change equations.\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Loading male (m_ABO) and female (f_ABO) impression change equations into DataFrames\n",
"m_ABO = pd.read_csv(\"D:/ACT/data/m_ABO.csv\", sep='\\s+', header=None).set_axis(['_','EA', 'PA', 'AA','EB', 'PB', 'AB','EO', 'PO', 'AO'], axis=1)\n",
"f_ABO = pd.read_csv(\"D:/ACT/data/f_ABO.csv\", sep='\\s+', header=None).set_axis(['_','EA', 'PA', 'AA','EB', 'PB', 'AB','EO', 'PO', 'AO'], axis=1)\n",
"\n",
"# Defining amalgamation matrices for female and male (used in the amalgamation function)\n",
"amalgamate_f=np.array([[-0.36,.5,0],[-.23,.46,0],[0,.12,0],[-.17,0,.32],[.1,0,.55],[0,0,.05],[-.22,0,0],[.44,-.05,0],[.66,.02,.03]])\n",
"amalgamate_m=np.array([[-0.36,.5,0],[-.23,.46,0],[0,.12,0],[-.17,0,.32],[0,0,.62],[0,0,.05],[-.22,0,0],[.44,-.05,0],[.66,.02,.03]])\n",
"\n",
"# Function to calculate the combination of initial identity and modifiers \n",
"def amalgamate(S_init,S_modif,amalgamate_i=amalgamate_f):\n",
" # Unpacking initial (S_init) and modifier (S_modif) states\n",
" Se,Sp,Sa,Me,Mp,Ma = S_init[0],S_init[1],S_init[2],S_modif[0],S_modif[1],S_modif[2]\n",
" rowss=[1]+S_modif+S_init+[Se*Me, Sa*Me]\n",
" # Applying the matrix to compute the final state\n",
" S_final=np.dot(amalgamate_i.transpose(),rowss)\n",
" return(S_final.transpose())\n",
" \n",
"# Function to calculate the next state (impression) after an event based on impression change equations\n",
"def next_state(A_init,B_init,O_init,ABO):\n",
" #Unpacking the initial state vectors for Actor (A), Behavior (B), and Object (O)\n",
" Ae,Ap,Aa,Be,Bp,Ba,Oe,Op,Oa = A_init[0],A_init[1],A_init[2],B_init[0],B_init[1],B_init[2],O_init[0],O_init[1],O_init[2]\n",
" rowss=np.concatenate(([1],A_init,B_init,O_init,[Ae*Be,Ae*Op,Ap*Bp,Aa*Ba,Be*Oe,Be*Op,Bp*Oe,Bp*Op,Ae*Be*Oe,Ae*Be*Op]))\n",
" # Applying the impression change equations (ABO matrix) to calculate the next state\n",
" impressed=np.dot(ABO.iloc[0:20,1:10].transpose(),rowss)\n",
" # Compute the deflection as the difference between the impressed state and the initial state\n",
" output=impressed-np.concatenate((A_init,B_init,O_init))#[A_final,B_final,O_final]\n",
" return(output) "
]
},
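{
"cell_type": "markdown",
"metadata": {},
"source": [
"For intuition, here is a self-contained example of the amalgamation step with made-up EPA values. The matrix and function are repeated from the cell above so the snippet runs on its own; in the notebook itself the inputs come from the dictionaries."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Repeated from the previous cell so this example is self-contained\n",
"amalgamate_f = np.array([[-0.36,.5,0],[-.23,.46,0],[0,.12,0],[-.17,0,.32],[.1,0,.55],[0,0,.05],[-.22,0,0],[.44,-.05,0],[.66,.02,.03]])\n",
"\n",
"def amalgamate(S_init, S_modif, amalgamate_i=amalgamate_f):\n",
"    Se, Sp, Sa = S_init\n",
"    Me, Mp, Ma = S_modif\n",
"    rowss = [1] + S_modif + S_init + [Se*Me, Sa*Me]\n",
"    return np.dot(amalgamate_i.transpose(), rowss)\n",
"\n",
"# Hypothetical EPA profiles (not taken from the dictionaries)\n",
"S_mod = [-1.0, 0.5, 1.2]   # modifier\n",
"S_id = [1.5, 0.8, 0.3]     # identity\n",
"combined = amalgamate(S_id, S_mod)\n",
"print(combined)  # a single EPA triple for the modified identity"
]
},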
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Tokenizing a Set of Texts for BERT\n",
"This function, preprocessing_for_bert, takes an array of texts and prepares them for input into a pretrained BERT model. The preprocessing includes tokenization, adding special tokens ([CLS] and [SEP]), truncating or padding the sentences to a maximum length, and generating attention masks. These preprocessing steps ensure that the text data is in the correct format to be fed into BERT for further tasks such as classification or sequence prediction.\n",
"\n",
"Key steps performed by the function:\n",
"\n",
" - Tokenization: Convert each sentence into BERT-compatible tokens.\n",
" - Special Tokens: Add [CLS] at the start and [SEP] at the end of each sentence.\n",
" - Padding/Truncation: Ensure sentences are all the same length by truncating longer ones and padding shorter ones.\n",
" - Attention Masks: Generate masks that indicate which tokens should be attended to (non-padded tokens). "
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Create a function to tokenize a set of texts\n",
"def preprocessing_for_bert(data,MAX_LEN=40):\n",
" \"\"\"\n",
" Perform required preprocessing steps for pretrained BERT.\n",
"\n",
" @param data (np.array): Array of texts to be processed.\n",
" @param MAX_LEN (int): Maximum length to which the sentences will be truncated/padded.\n",
" @return input_ids (torch.Tensor): Tensor of token IDs to be fed to a model.\n",
" @return attention_masks (torch.Tensor): Tensor of indices specifying which\n",
" tokens should be attended to by the model.\n",
" \"\"\"\n",
" \n",
" # Initialize empty lists to store token IDs and attention masks\n",
" input_ids = []\n",
" attention_masks = []\n",
"\n",
" # Loop through each sentence in the data array\n",
" for sent in data:\n",
" # Use BERT's tokenizer `encode_plus` method to:\n",
" # - Tokenize the sentence\n",
" # - Add `[CLS]` at the start and `[SEP]` at the end\n",
" # - Truncate or pad the sentence to the specified max length\n",
" # - Map tokens to their corresponding IDs\n",
" # - Generate the attention mask\n",
" encoded_sent = tokenizer.encode_plus(\n",
" text=sent, # The sentence to preprocess\n",
" add_special_tokens=True, # Add special tokens like `[CLS]` and `[SEP]`\n",
" max_length=MAX_LEN, # Truncate/pad the sentence to the max length\n",
" padding='max_length', # Pad the sentence to the max length if needed\n",
" return_attention_mask=True # Return the attention mask\n",
" )\n",
" \n",
" # Append the token IDs and attention mask to their respective lists\n",
" input_ids.append(encoded_sent.get('input_ids'))\n",
" attention_masks.append(encoded_sent.get('attention_mask'))\n",
"\n",
"    # Convert the single tokenized sentence to PyTorch tensors; indexing with [0]\n",
"    # unwraps the one-element lists, since this helper is called on one sentence at a time\n",
"    input_ids = torch.tensor(input_ids[0])\n",
"    attention_masks = torch.tensor(attention_masks[0])\n",
" # Return the token IDs and attention masks as outputs\n",
" return input_ids, attention_masks"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Generating All Possible ABO Events in the Training Set\n",
"\n",
"In this step, we generate all possible Actor-Behavior-Object (ABO) event combinations from the training set. Using the indices of the training set for Identities (Actors and Objects) and Behaviors, we create every possible event where each Actor and Object pair is associated with a Behavior.\n",
"\n",
" - ABO Event Generation: By taking the Cartesian product of the indices for the Actor (Identity), Behavior, and Object (Identity) columns in the training set, we generate all possible events.\n",
" - The total number of events is calculated based on these combinations."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using the training set we can make 356353569 events\n"
]
}
],
"source": [
"import itertools\n",
"\n",
"# Create index lists for Actors (Identities), Behaviors, and Objects (Identities) from the training set\n",
"a = [list(range(len(I_train))), list(range(len(B_train))), list(range(len(I_train)))]\n",
"# Generate all possible ABO events by taking the Cartesian product of the indices\n",
"qq=list(itertools.product(*a))\n",
"\n",
"# Output the total number of events generated from the training set\n",
"print(f'Using the training set we can make {len(qq)} events')"
]
},
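{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check on the count printed above: the Cartesian product has |Identities| × |Behaviors| × |Identities| elements. With the 649 training behaviors reported earlier, 356,353,569 = 741² × 649, implying 741 training identities."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import itertools\n",
"\n",
"# The size of a Cartesian product is the product of the list lengths (toy example)\n",
"toy = list(itertools.product(range(2), range(3), range(2)))\n",
"assert len(toy) == 2 * 3 * 2\n",
"\n",
"# Consistency of the full event count with the reported training-set sizes\n",
"assert 741**2 * 649 == 356353569"
]
},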
{
"cell_type": "markdown",
"metadata": {
"id": "P-NoEu_ESPfr"
},
"source": [
"\n",
"#### Calculating Deflection for All Possible ABO Events\n",
"\n",
"In this section, we calculate the deflection for every possible event generated from the training set. Deflection is a measure used in Affect Control Theory (ACT) to quantify the discrepancy between expected and actual outcomes in Actor-Behavior-Object (ABO) events. This process can be computationally intensive, so we save the results to a file for later use.\n",
"\n",
"This section draws from [a tutorial on basic regression with TensorBoard](https://github.com/thoo/trax-tutorial/blob/master/basic_regression_tensorboard.ipynb), which guided the implementation of the event processing loop.\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Commented-out: this pass computes the deflection for every possible event and saves the results for later use\n",
"\n",
"# I_train_array=np.array(I_train[['E', 'P', 'A']]) \n",
"# B_train_array=np.array(B_train[['E', 'P', 'A']]) \n",
"# from datetime import datetime\n",
"# print('Process started at',datetime.now() )\n",
"# deflects=[]\n",
"# counter_tmp=0\n",
"# #Loop over all possible ABO events in the training set\n",
"# for x in qq:\n",
" # #Calculate deflection using the next_state function for each event \n",
"# w=next_state(I_train_array[x[0]],B_train_array[x[1]],I_train_array[x[2]],m_ABO)\n",
"# counter_tmp=counter_tmp+1 # Increment the counter\n",
"# # Print progress every 10 million samples\n",
"# if (counter_tmp % 10000000 == 0 ) : print(counter_tmp, ' sample processed at:',datetime.now() )\n",
"# deflects.append(w)\n",
"# print('Process ended at',datetime.now() )\n",
"# df_deflect=pd.DataFrame(np.array(deflects)).set_axis(['EA', 'PA', 'AA','EB', 'PB', 'AB','EO', 'PO', 'AO'], axis=1, inplace=False)\n",
"# df_deflect.to_pickle('D:/tmp/deflects_all_words_df.pickle')\n",
"# df_deflect.describe().apply(lambda series: series.apply(\"{0:.2f}\".format)).set_axis(['EA', 'PA', 'AA','EB', 'PB', 'AB','EO', 'PO', 'AO'], axis=1, inplace=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Observing Event Frequency and Class Distribution\n",
"In this section, we classify all possible Actor-Behavior-Object (ABO) events based on the sign of their deflection values. The sign indicates the direction of deflection, allowing us to group events into 508 distinct classes. These event classes are used in the training process of the model.\n",
"\n",
"We also observe the frequency of events within each class, noting that the events are not uniformly distributed across these classes.\n"
]
},
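{
"cell_type": "markdown",
"metadata": {},
"source": [
"The class-ID encoding can be sketched in isolation: each of the 9 deflection components contributes one bit (1 if the deflection is positive, 0 otherwise), and the bit string is read as a base-2 integer. This allows at most 2⁹ = 512 classes, of which 508 occur in the data. The snippet below mirrors that encoding on a hypothetical deflection vector:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"sign = lambda x: math.copysign(1, x)\n",
"\n",
"def deflection_class(deflection):\n",
"    # One bit per component: 1 if the deflection is positive, else 0\n",
"    bits = [int(0.5 * sign(d) + 0.5) for d in deflection]\n",
"    return int(''.join(str(b) for b in bits), 2)\n",
"\n",
"# Hypothetical 9-component deflection vector (EA..AO), for illustration only\n",
"print(deflection_class([0.3, -1.2, 0.7, -0.1, 0.5, 2.5, -0.4, 0.9, -2.0]))"
]
},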
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"112 23716303\n",
"224 18421833\n",
"392 15305386\n",
"240 12284697\n",
"56 11802293\n",
" ... \n",
"178 6\n",
"306 5\n",
"290 5\n",
"162 3\n",
"355 2\n",
"Length: 508, dtype: int64"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Import the math library for mathematical operations\n",
"import math as math\n",
"# Define a lambda function to return the sign of a number (1 for positive, -1 for negative)\n",
"sign = lambda x: math.copysign(1, x)\n",
"\n",
"# Commented-out: classify each event by the signs of its deflection values and save the result.\n",
"# Each deflection component maps to a bit (1 if positive, 0 if negative); the bits combine into an integer class ID.\n",
"# lst=list(df_deflect.columns)\n",
"# event_type=df_deflect.apply(lambda row: int(\"\".join(str(x) for x in [int(0.5*sign(row[x])+0.5) for x in lst]), 2), axis=1)\n",
"# event_type.to_pickle('D:/tmp/all_event_types.pickle')\n",
"\n",
"# Load the previously saved event types from the pickle file\n",
"event_type=pd.read_pickle('D:/tmp/all_event_types.pickle')\n",
"\n",
"# Print the number of unique event types (classes) created based on deflection sign\n",
"print(len(list(event_type.unique())))\n",
"\n",
"# Display the frequency of events within each class (value counts)\n",
"event_type.value_counts()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Generating Event Classes Based on Deflection Signs\n",
"\n",
"In this section, we process the deflection data and group events by their deflection signs. The deflection signs for each event are converted into a binary string representation, which helps classify the events. These classified event groups can then be iterated over during training.\n",
"\n",
"The function def_class_gen generates groups of events that share the same deflection class, based on the sign of the deflection values. This function uses a generator to yield groups of events as needed, optimizing memory usage."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# Commented-out: an alternative pass that attaches deflection signs directly to the deflection DataFrame\n",
"\n",
"# lst=list(df_deflect.columns)\n",
"# df_deflect['sign'] = df_deflect.apply(lambda row: [sign(row[x]) for x in lst], axis=1)\n",
"# df_deflect['signs'] = [''.join(map(str, l)) for l in df_deflect['sign']]\n",
"# del df_deflect\n",
"\n",
"# #Import permutations to handle event combinations\n",
"from itertools import permutations\n",
"# Function to generate groups of events based on their deflection class\n",
"def def_class_gen():\n",
" while True: ## Infinite loop to continuously yield groups of events\n",
" # Iterate over each unique event type (deflection class)\n",
" for w in list(event_type.unique()):\n",
" # Yield the list of events that belong to the current deflection class\n",
" yield [qq[x] for x in (event_type.loc[(event_type==w),].index.tolist())]\n",
" \n",
"# Initialize the generator to start producing event groups by class \n",
"g = def_class_gen()"
]
},
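{
"cell_type": "markdown",
"metadata": {},
"source": [
"A toy version of this class-cycling generator, using a plain list of labels instead of event_type, shows the yield pattern: it walks through the classes once, then starts over."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy analogue of def_class_gen: endlessly cycle over groups sharing a label\n",
"def toy_class_gen(labels, items):\n",
"    while True:\n",
"        for w in sorted(set(labels)):\n",
"            yield [it for it, lab in zip(items, labels) if lab == w]\n",
"\n",
"tg = toy_class_gen(['a', 'b', 'a'], [10, 20, 30])\n",
"groups = [next(tg) for _ in range(3)]\n",
"print(groups)  # the 'a' group, the 'b' group, then the 'a' group again"
]
},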
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Generating Random ABO Events for Model Training\n",
"\n",
"In this section, we define a function gnrtr to generate random Actor-Behavior-Object (ABO) events from the training set. This function selects random entries from the Identity, Behavior, and Modifier datasets and concatenates them into a sentence that represents the event. The function also retrieves the corresponding EPA values and indices for these events. These generated events will be used as input for model training.\n",
"\n",
"The function yields the following:\n",
"\n",
" - inputs: BERT tokenized sentences.\n",
" - masks: Attention masks for BERT.\n",
" - ys: EPA values (evaluation, potency, and activity) for the selected event.\n",
" - indexx: The indices of the selected components in the respective datasets (Modifier, Identity, Behavior)."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# Define a function to generate random ABO (Actor-Behavior-Object) events\n",
"def gnrtr(Identity,Behavior,Modifier):\n",
" # Randomly select a combination of indices for Identity (Actor, Object), Behavior, and Modifier\n",
" qq=sample(next(g),1)[0]\n",
" # Retrieve the two identities (Actor and Object) and the behavior from the training set based on the randomly selected index\n",
" ident1=Identity.loc[Identity.index_in_dic==list(n_I_train.index_in_dic)[qq[0]],] \n",
" ident2=Identity.loc[Identity.index_in_dic==list(n_I_train.index_in_dic)[qq[2]],] \n",
" behav=Behavior.loc[Behavior.index_in_dic==list(n_B_train.index_in_dic)[qq[1]],] \n",
" # Randomly select two modifiers (one for the actor and one for the object)\n",
" modif1,modif2=Modifier.sample(axis = 0),Modifier.sample(axis = 0)\n",
" # Extract the terms (strings) for the actor, behavior, and object components\n",
" id1,id2,beh,mod1,mod2=list(ident1.term),list(ident2.term),list(behav.term),list(modif1.term),list(modif2.term)\n",
" # Combine the terms into a single sentence representing the event\n",
" sents=' '.join(map(str, (mod1+id1+beh+mod2+id2)))\n",
"\n",
" # Concatenate the EPA values for the modifier, actor, behavior, and object into a single array\n",
" values=np.concatenate([(modif1[['E','P','A']]).to_numpy(),\n",
" (ident1[['E','P','A']]).to_numpy(),\n",
" (behav[['E','P','A']]).to_numpy(),\n",
" (modif2[['E','P','A']]).to_numpy(),\n",
" (ident2[['E','P','A']]).to_numpy()], axis=1)[0]\n",
" # Retrieve the indices of the selected components from the respective datasets\n",
" indexx=torch.tensor([[(modif1['index_in_dic']).to_numpy()][0][0],\n",
" [(ident1['index_in_dic']).to_numpy()][0][0],\n",
" [(behav['index_in_dic']).to_numpy()][0][0],\n",
" [(modif2['index_in_dic']).to_numpy()][0][0],\n",
" [(ident2['index_in_dic']).to_numpy()][0][0]])\n",
" # Convert the EPA values to a PyTorch tensor\n",
" ys= torch.tensor(values)\n",
" # Tokenize the sentence and generate the attention mask using the BERT tokenizer \n",
" inputs, masks = preprocessing_for_bert([sents])\n",
" # Yield the inputs (tokenized sentence), masks (attention mask), EPA values (ys), and component indices (indexx) \n",
" yield inputs, masks, ys,indexx "
]
},
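{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the sentence and target construction in `gnrtr` concrete, here is a minimal, self-contained sketch of the same pattern using small hypothetical one-row DataFrames (the terms and EPA numbers below are made up purely for illustration):\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Hypothetical one-row frames standing in for the sampled rows\n",
"modif1 = pd.DataFrame({'term': ['angry'], 'E': [-1.2], 'P': [0.8], 'A': [1.5]})\n",
"ident1 = pd.DataFrame({'term': ['doctor'], 'E': [1.9], 'P': [1.6], 'A': [0.4]})\n",
"behav = pd.DataFrame({'term': ['helps'], 'E': [2.1], 'P': [1.2], 'A': [0.6]})\n",
"modif2 = pd.DataFrame({'term': ['tired'], 'E': [-0.7], 'P': [-0.9], 'A': [-1.4]})\n",
"ident2 = pd.DataFrame({'term': ['student'], 'E': [0.8], 'P': [-0.2], 'A': [0.7]})\n",
"\n",
"# Join the five terms into one event sentence, as gnrtr does\n",
"sents = ' '.join(list(modif1.term) + list(ident1.term) + list(behav.term)\n",
"                 + list(modif2.term) + list(ident2.term))\n",
"\n",
"# Concatenate the five EPA triples into a single 15-element target vector\n",
"values = np.concatenate([f[['E', 'P', 'A']].to_numpy()\n",
"                         for f in (modif1, ident1, behav, modif2, ident2)], axis=1)[0]\n",
"\n",
"print(sents)         # angry doctor helps tired student\n",
"print(values.shape)  # (15,)\n",
"```"
]
},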
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Generating Random ABO Events for Model Training\n",
"\n",
"For the validation and test sets, we don't need to classify events based on deflection sign or event classes. Instead, we can generate random Actor-Behavior-Object (ABO) events without considering class structure. The function gnrtr2 is used to randomly sample events from the Identity, Behavior, and Modifier datasets, and constructs sentences and their associated EPA values.\n",
"\n",
"Additionally, we define a dta_ldr2 function to load batches of data using PyTorch's DataLoader. This helps in efficiently handling batches during model validation and testing.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"# Import DataLoader from PyTorch to handle data batching\n",
"from torch.utils.data import DataLoader \n",
"\n",
"# Define a function to generate random ABO (Actor-Behavior-Object) events for validation and test sets\n",
"def gnrtr2(Identity,Behavior,Modifier):\n",
" # Randomly sample identities (Actor and Object), a behavior, and two modifiers from the datasets \n",
" ident1,ident2,behav=Identity.sample(axis = 0),Identity.sample(axis = 0),Behavior.sample(axis = 0)\n",
" modif1,modif2=Modifier.sample(axis = 0),Modifier.sample(axis = 0)\n",
" id1,id2,beh,mod1,mod2=list(ident1.term),list(ident2.term),list(behav.term),list(modif1.term),list(modif2.term)\n",
" # Combine the terms into a single sentence representing the event\n",
" sents=' '.join(map(str, (mod1+id1+beh+mod2+id2)))\n",
"\n",
" # Create an array of EPA values for the event, concatenating modifier, actor, behavior, and object values\n",
" values=np.concatenate([(modif1[['E','P','A']]).to_numpy(),\n",
" (ident1[['E','P','A']]).to_numpy(),\n",
" (behav[['E','P','A']]).to_numpy(),\n",
" (modif2[['E','P','A']]).to_numpy(),\n",
" (ident2[['E','P','A']]).to_numpy()], axis=1)[0]\n",
" # Create a tensor of indices for each component of the event\n",
" indexx=torch.tensor([[(modif1['index_in_dic']).to_numpy()][0][0],\n",
" [(ident1['index_in_dic']).to_numpy()][0][0],\n",
" [(behav['index_in_dic']).to_numpy()][0][0],\n",
" [(modif2['index_in_dic']).to_numpy()][0][0],\n",
" [(ident2['index_in_dic']).to_numpy()][0][0]])\n",
" # Convert the EPA values to a PyTorch tensor\n",
" ys= torch.tensor(values)\n",
" # Tokenize the sentence and generate an attention mask using the BERT tokenizer\n",
" inputs, masks = preprocessing_for_bert([sents])\n",
" # Yield the tokenized inputs, attention mask, EPA values, and component indices \n",
" yield inputs, masks, ys,indexx #torch.tensor(sents),\n",
"\n",
"# Define a function to load batches of ABO events for validation or testing\n",
"# The batch size is set to 64 by default, but it can be modified\n",
"def dta_ldr2(I,B,M,batch_size=64):\n",
" dt_ldr= [x for x in DataLoader([next(gnrtr2(I,B,M)) for x in range(batch_size)], batch_size=batch_size)][0]\n",
" return(dt_ldr)"
]
},
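{
"cell_type": "markdown",
"metadata": {},
"source": [
"The batching pattern in `dta_ldr2` can be sketched in isolation. This is a minimal, self-contained example with stand-in tensor pairs (not the real `gnrtr2` output) that shows how `DataLoader` collates a list of per-event samples into one batch:\n",
"\n",
"```python\n",
"import torch\n",
"from torch.utils.data import DataLoader\n",
"\n",
"# Stand-in samples mimicking the per-event tuples; here just pairs of tensors\n",
"samples = [(torch.full((4,), i), torch.tensor([float(i)] * 3)) for i in range(8)]\n",
"\n",
"# Collect batch_size samples and take the single batch DataLoader produces;\n",
"# next(iter(...)) is equivalent to [x for x in ...][0]\n",
"xs, ys = next(iter(DataLoader(samples, batch_size=8)))\n",
"\n",
"print(xs.shape)  # torch.Size([8, 4])\n",
"print(ys.shape)  # torch.Size([8, 3])\n",
"```"
]
},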
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Fine-Tuning BERT for Regression Tasks\n",
"In this section, we define functions to handle batch loading, build a BERT-based regressor, and train the model. The BertRegressor class is a custom neural network for regression tasks that fine-tunes a pre-trained BERT model. We also define helper functions to initialize the model, set the seed for reproducibility, and train and evaluate the model.\n",
"\n",
"The training process involves loading batches of data, computing predictions, calculating the loss, and updating the model's weights. Evaluation is performed after a specified number of batches to track the model's performance on validation data.\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# Import required modules from PyTorch\n",
"import torch.nn as nn\n",
"\n",
"\n",
"# Function to load batches of data for training\n",
"def dta_ldr(I,B,M,batch_size=32): # For fine-tuning BERT, the BERT authors recommend a batch size of 16 or 32.\n",
" # Generate a batch of ABO events and return the first batch from DataLoader\n",
" dt_ldr= [x for x in DataLoader([next(gnrtr(I,B,M)) for x in range(batch_size)], batch_size=batch_size)][0]\n",
" return(dt_ldr)\n",
"# Define a custom BERT model for regression tasks\n",
"class BertRegressor(nn.Module):\n",
" \"\"\"Bert Model for Regression Tasks.\n",
" \"\"\"\n",
" def __init__(self, freeze_bert=False):\n",
" \"\"\"\n",
" Initializes the BERT-based regressor.\n",
"\n",
" @param freeze_bert (bool): Set to `True` if you want to freeze BERT layers during training.\n",
" \"\"\"\n",
" \n",
" super(BertRegressor, self).__init__()\n",
" # Define the hidden size of BERT, hidden size for the regressor, and output size\n",
" D_in, H, D_out = 1024, 120, 15\n",
"\n",
" # Load the pre-trained BERT model\n",
" self.bert = BertModel.from_pretrained('bert-large-uncased')\n",
"\n",
" # Define a feed-forward regressor with dropout and ReLU activation layers\n",
" self.regressor = nn.Sequential(\n",
" nn.Dropout(0.4),\n",
" nn.Linear(D_in, H),\n",
" nn.Dropout(0.3),\n",
" nn.ReLU(),\n",
" nn.Dropout(0.3),\n",
" nn.ReLU(),\n",
" nn.Dropout(0.3),\n",
" nn.Linear(H, D_out)\n",
" )\n",
"\n",
" # Optionally freeze the BERT model during training\n",
" if freeze_bert:\n",
" for param in self.bert.parameters():\n",
" param.requires_grad = False\n",
" \n",
" def forward(self, input_ids, attention_mask):\n",
" \"\"\"\n",
" Forward pass for the BERT regressor model.\n",
" \n",
" @param input_ids (torch.Tensor): Input tensor of token IDs.\n",
" @param attention_mask (torch.Tensor): Attention mask tensor for BERT.\n",
" @return predictions (torch.Tensor): Output predictions for the regression task.\n",
" \"\"\"\n",
" \n",
" # Pass input through the BERT model and get the last hidden state of the [CLS] token\n",
" outputs = self.bert(input_ids=input_ids,\n",
" attention_mask=attention_mask) \n",
" last_hidden_state_cls = outputs.pooler_output \n",
"\n",
" # Pass the hidden state through the regressor to get predictions\n",
" predictions = self.regressor(last_hidden_state_cls) \n",
"\n",
" return predictions \n",
"from transformers import AdamW, get_linear_schedule_with_warmup\n",
"\n",
"# Function to initialize the BERT model, optimizer, and learning rate scheduler\n",
"def initialize_model(epochs=4):\n",
" \"\"\"Initialize the Bert Regressor, optimizer, and learning rate scheduler.\n",
" \"\"\"\n",
" # Instantiate the BERT regressor\n",
" bert_regressor = BertRegressor(freeze_bert=False)\n",
"\n",
" # Move the model to the GPU\n",
" bert_regressor.to(device)\n",
"\n",
" # Define the optimizer (AdamW) with learning rate and weight decay\n",
" optimizer = torch.optim.AdamW(bert_regressor.parameters(),\n",
" lr=2e-5, # Smaller LR\n",
" eps=1e-8, # Default epsilon value\n",
" weight_decay =0.001 # Decoupled weight decay to apply.\n",
" )\n",
"\n",
"\n",
"\n",
" # Define the total number of training steps (based on batch size and epochs)\n",
" total_steps = 100000#len(train_dataloader) * epochs\n",
"\n",
" # Set up the learning rate scheduler\n",
" scheduler = get_linear_schedule_with_warmup(optimizer,\n",
" num_warmup_steps=0, # Default value\n",
" num_training_steps=total_steps)\n",
" return bert_regressor, optimizer, scheduler\n",
"import random\n",
"import time\n",
"\n",
" \n",
"# Define the Mean Squared Error loss function\n",
"loss_fn = nn.MSELoss()\n",
"\n",
"# Function to set random seeds for reproducibility\n",
"def set_seed(seed_value=42):\n",
" \"\"\"Set seed for reproducibility.\n",
" \"\"\"\n",
" random.seed(seed_value)\n",
" np.random.seed(seed_value)\n",
" torch.manual_seed(seed_value)\n",
" torch.cuda.manual_seed_all(seed_value)\n",
"\n",
"# Function to train the BERT regressor model\n",
"def train(model, I_trn=n_I_train,B_trn=n_B_train,M_trn=n_M_train,batch_size_trn=32,\n",
" I_tst=n_I_test,B_tst=n_B_test,M_tst=n_M_test,batch_size_tst=32, batch_size=50,batch_epochs=400, evaluation=False):\n",
" \"\"\"Train the Bert Regressor model.\n",
" \"\"\"\n",
" print(\"Start training...\\n\")\n",
" # =======================================\n",
" # Training\n",
" # =======================================\n",
" # Print the header of the result table\n",
" print(f\" {'Batch':^5} | {'Train Loss':^12} | {'Val Loss':^10} | {'Elapsed':^9}\")\n",
" print(\"-\"*50)\n",
" t0_batch = time.time() # Start timer for batch\n",
" batch_loss, batch_counts = 0, 0 # Initialize batch loss and counts\n",
" # Set model to training mode\n",
" model.train()\n",
" # For each batch of training data...\n",
" for batch in range(batch_epochs): #298\n",
" batch_counts +=1\n",
" if ((batch==(704))):break #457#383#1451#246\n",
"# if val_loss<0.3: break\n",
" # Load batch data onto GPU\n",
" b_input_ids, b_attn_mask, b_ys,_ = tuple(t.to(device) for t in dta_ldr(I=I_trn,B=B_trn,M=M_trn,batch_size=batch_size_trn))\n",
" \n",
" model.zero_grad() # Clear previous gradients\n",
" # Perform forward pass to get predictions\n",
" preds = model(b_input_ids, b_attn_mask)\n",
" # Compute loss between predictions and ground trut \n",
" loss = loss_fn(preds.float(), b_ys.float())\n",
" batch_loss += loss.item()\n",
" # Perform a backward pass to calculate gradients\n",
" loss.backward() # Perform backward pass to compute gradients\n",
" # Clip gradients to avoid exploding gradients\n",
" torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
" # Update parameters and the learning rate\n",
" optimizer.step() # Update model weights\n",
" scheduler.step() # Update learning rate\n",
"\n",
"\n",
" # Print every 50 batches\n",
" if (batch_counts % 50 == 0 and batch_counts != 0) : #or(batch>585)\n",
" # Calculate elapsed time \n",
" time_elapsed = time.time() - t0_batch\n",
"\n",
" # Print training and validation results\n",
" val_loss = evaluate(model, Ie=I_tst,Be=B_tst,Me=M_tst,batch_size_e=batch_size_tst)\n",
" print(f\"{batch+ 1:^7}|{batch_loss / batch_counts:^12.6f} | {val_loss:^10.6f} | {time_elapsed:^9.2f}\") #| {step:^7}\n",
" # After the completion of each training epoch, measure the model's performance\n",
" # on our validation set.\n",
" print(\"-\"*50)\n",
"# print(batch)\n",
"\n",
"# if (batch<586):\n",
"# # Reset batch tracking variables\n",
"# batch_loss, batch_counts = 0, 0\n",
"# t0_batch = time.time()\n",
"# # Reset batch tracking variables\n",
" batch_loss, batch_counts = 0, 0\n",
" t0_batch = time.time()\n",
"\n",
"\n",
" # =======================================\n",
" # Evaluation\n",
" # =======================================\n",
" if evaluation == True:\n",
" # After the completion of each training epoch, measure the model's performance\n",
" # on our validation set.\n",
" val_loss = evaluate(model, Ie=I_tst,Be=B_tst,Me=M_tst,batch_size_e=batch_size_tst)\n",
" if val_loss<0.31: \n",
" print('\\n Consider this one with val:', val_loss,' at:',batch,'\\n')\n",
" print(\"-\"*50)\n",
"\n",
" \n",
" # Calculate the average loss over the entire training data\n",
"# avg_train_loss = total_loss / (batch_size*batch_epochs)\n",
"\n",
" val_loss = evaluate(model, Ie=I_tst,Be=B_tst,Me=M_tst,batch_size_e=batch_size_tst)\n",
" print(f\"{batch+ 1:^7}|{batch_loss / batch_counts:^12.6f} | {val_loss:^10.6f} | {time_elapsed:^9.2f}\") #| {step:^7} \n",
" print(\"Training complete!\")\n",
"\n",
"\n",
"def evaluate(model, Ie,Be,Me,batch_size_e):\n",
" \"\"\"Evaluate model performance on the validation set.\n",
" \"\"\"\n",
" # Set model to evaluation mode\n",
" model.eval()\n",
"\n",
" # Initialize list to store validation losses\n",
" val_loss = []\n",
"\n",
" # Load validation data and compute predictions\n",
" for batch in range(1):\n",
" # Load batch to GPU\n",
" b_input_ids, b_attn_mask, b_ys,_ = tuple(t.to(device) for t in dta_ldr2(Ie,Be,Me,batch_size_e))\n",
"\n",
" with torch.no_grad(): # Disable gradient computation for validation\n",
" preds = model(b_input_ids, b_attn_mask)\n",
"\n",
" # Compute loss\n",
" loss = loss_fn(preds, b_ys)\n",
" val_loss.append(loss.item()) # Append loss to list\n",
"\n",
"\n",
" # Compute average validation loss\n",
" val_loss = np.mean(val_loss)\n",
"\n",
" return val_loss\n"
]
},
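{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick shape check, the regression head can be exercised on its own, without downloading BERT. The sketch below uses a simplified head (the notebook's version interleaves extra Dropout/ReLU layers) and a fake batch of pooled outputs, just to show that 1024-dim [CLS] vectors map to 15 outputs -- five EPA triples, one per event component:\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"# Simplified regression head: 1024-dim pooled vector -> 15 EPA outputs\n",
"D_in, H, D_out = 1024, 120, 15\n",
"head = nn.Sequential(nn.Dropout(0.4), nn.Linear(D_in, H), nn.ReLU(), nn.Linear(H, D_out))\n",
"\n",
"pooled = torch.randn(32, D_in)  # fake batch of pooled BERT outputs\n",
"preds = head(pooled)\n",
"print(preds.shape)  # torch.Size([32, 15])\n",
"```"
]
},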
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"rnd_st=42\n",
"np.random.seed(rnd_st)\n",
"random.seed(rnd_st)\n",
"torch.manual_seed(rnd_st)\n",
"torch.cuda.manual_seed(rnd_st)\n",
"# Running on the CuDNN backend\n",
"torch.backends.cudnn.deterministic = True\n",
"torch.backends.cudnn.benchmark = False"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Process started at 2022-10-25 17:35:14.774299\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at bert-large-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']\n",
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Start training...\n",
"\n",
" Batch | Train Loss | Val Loss | Elapsed \n",
"--------------------------------------------------\n",
" 50 | 0.929968 | 0.967893 | 1656.88 \n",
"--------------------------------------------------\n",
" 100 | 0.783472 | 0.858978 | 2323.24 \n",
"--------------------------------------------------\n",
" 150 | 0.622484 | 0.707333 | 2316.32 \n",
"--------------------------------------------------\n",
" 200 | 0.500507 | 0.601719 | 2363.26 \n",
"--------------------------------------------------\n",
" 250 | 0.356392 | 0.495498 | 2201.78 \n",
"--------------------------------------------------\n",
" 300 | 0.233981 | 0.419157 | 2202.94 \n",
"--------------------------------------------------\n",
" 350 | 0.152351 | 0.370178 | 2216.25 \n",
"--------------------------------------------------\n",
" 400 | 0.098886 | 0.359047 | 2171.53 \n",
"--------------------------------------------------\n",
" 450 | 0.072336 | 0.343374 | 2266.49 \n",
"--------------------------------------------------\n",
" 500 | 0.056853 | 0.348072 | 2349.00 \n",
"--------------------------------------------------\n",
" 550 | 0.044879 | 0.333341 | 2381.07 \n",
"--------------------------------------------------\n",
" 600 | 0.033761 | 0.325401 | 2323.50 \n",
"--------------------------------------------------\n",
"\n",
" Consider this one with val: 0.3066757112455284 at: 632 \n",
"\n",
"--------------------------------------------------\n",
"\n",
" Consider this one with val: 0.30767131862777475 at: 646 \n",
"\n",
"--------------------------------------------------\n",
" 650 | 0.026053 | 0.318807 | 2281.13 \n",
"--------------------------------------------------\n",
" 700 | 0.021020 | 0.314143 | 2328.34 \n",
"--------------------------------------------------\n",
" 705 | 0.016111 | 0.313417 | 2328.34 \n",
"Training complete!\n",
"Process ended at 2022-10-26 02:24:05.768346\n"
]
}
],
"source": [
"print('Process started at',datetime.now() )\n",
"# set_seed(42) # Set seed for reproducibility\n",
"bert_regressor, optimizer, scheduler = initialize_model(epochs=2)\n",
"train(bert_regressor, I_trn=n_I_train,B_trn=n_B_train,M_trn=n_M_train,batch_size_trn=64,\n",
" I_tst=n_I_test,B_tst=n_B_test,M_tst=n_M_test,batch_size_tst=1024, batch_epochs=705, evaluation=True)\n",
"print('Process ended at',datetime.now() )\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# print('Process started at',datetime.now() )\n",
"# # set_seed(42) # Set seed for reproducibility\n",
"# bert_regressor, optimizer, scheduler = initialize_model(epochs=2)\n",
"# train(bert_regressor, I_trn=n_I_train,B_trn=n_B_train,M_trn=n_M_train,batch_size_trn=64,\n",
"# I_tst=n_I_test,B_tst=n_B_test,M_tst=n_M_test,batch_size_tst=1024, batch_epochs=800, evaluation=True)\n",
"# print('Process ended at',datetime.now() )\n"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"# torch.save(bert_regressor.state_dict(), \"E:/ACT/torch_models/BERTNN_model\") "
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at bert-large-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight']\n",
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
]
},
{
"data": {
"text/plain": [
"BertRegressor(\n",
" (bert): BertModel(\n",
" (embeddings): BertEmbeddings(\n",
" (word_embeddings): Embedding(30522, 1024, padding_idx=0)\n",
" (position_embeddings): Embedding(512, 1024)\n",
" (token_type_embeddings): Embedding(2, 1024)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (encoder): BertEncoder(\n",
" (layer): ModuleList(\n",
" (0): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (key): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (value): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
" (intermediate_act_fn): GELUActivation()\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (1): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (key): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (value): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
" (intermediate_act_fn): GELUActivation()\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (2): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (key): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (value): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
" (intermediate_act_fn): GELUActivation()\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (3): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (key): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (value): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
" (intermediate_act_fn): GELUActivation()\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (4): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (key): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (value): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
" (intermediate_act_fn): GELUActivation()\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (5): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (key): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (value): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
" (intermediate_act_fn): GELUActivation()\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (6): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (key): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (value): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
" (intermediate_act_fn): GELUActivation()\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (7): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (key): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (value): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
" (intermediate_act_fn): GELUActivation()\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (8): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (key): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (value): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
" (intermediate_act_fn): GELUActivation()\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (9): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (key): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (value): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
" (intermediate_act_fn): GELUActivation()\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (10): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (key): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (value): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
" (intermediate_act_fn): GELUActivation()\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (11): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (key): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (value): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
" (intermediate_act_fn): GELUActivation()\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (12): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (key): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (value): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
" (intermediate_act_fn): GELUActivation()\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (13): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (key): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (value): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
" (intermediate_act_fn): GELUActivation()\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (14): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (key): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (value): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
" (intermediate_act_fn): GELUActivation()\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (15): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (key): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (value): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
" (intermediate_act_fn): GELUActivation()\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (16): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (key): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (value): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
" (intermediate_act_fn): GELUActivation()\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (17): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (key): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (value): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
" (intermediate_act_fn): GELUActivation()\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (18): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (key): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (value): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
" (intermediate_act_fn): GELUActivation()\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (19): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (key): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (value): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
" (intermediate_act_fn): GELUActivation()\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (20): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (key): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (value): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
" (intermediate_act_fn): GELUActivation()\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (21): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (key): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (value): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
" (intermediate_act_fn): GELUActivation()\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (22): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (key): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (value): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
" (intermediate_act_fn): GELUActivation()\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (23): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (key): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (value): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
" (intermediate_act_fn): GELUActivation()\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
" (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (pooler): BertPooler(\n",
" (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (activation): Tanh()\n",
" )\n",
" )\n",
" (regressor): Sequential(\n",
" (0): Dropout(p=0.4, inplace=False)\n",
" (1): Linear(in_features=1024, out_features=120, bias=True)\n",
" (2): Dropout(p=0.3, inplace=False)\n",
" (3): ReLU()\n",
" (4): Dropout(p=0.3, inplace=False)\n",
" (5): ReLU()\n",
" (6): Dropout(p=0.3, inplace=False)\n",
" (7): Linear(in_features=120, out_features=15, bias=True)\n",
" )\n",
")"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bert_regressor = BertRegressor()\n",
"bert_regressor.load_state_dict(torch.load(\"E:/ACT/torch_models/BERTNN_model\"))\n",
"# MABMO_Cluster_deflection_training_impression_change_paper_present_tense_2RELU_pooler_check\n",
"bert_regressor.eval()"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn.functional as F\n",
"\n",
"def bert_predict(model, test_dataloader):\n",
    "    \"\"\"Perform a forward pass of the trained BERT model to generate\n",
    "    predictions on the test set.\n",
    "    \"\"\"\n",
    "    # Put the model into evaluation mode; dropout layers are disabled\n",
    "    # at test time.\n",
" model.eval()\n",
" all_preds = []\n",
    "    # The loader here yields a single batch of tensors, so the loop runs once\n",
" for batch in range(1):\n",
    "        # Move the batch tensors to the current device (CPU/GPU)\n",
" b_input_ids, b_attn_mask = tuple(t.to(device) for t in test_dataloader)[:2]\n",
"\n",
" # Compute predictions\n",
" with torch.no_grad():\n",
    "            preds = model(b_input_ids, b_attn_mask)\n",
" all_preds.append(preds)\n",
" \n",
" # Concatenate predictions from each batch\n",
" all_preds = torch.cat(all_preds, dim=0)\n",
"\n",
" return all_preds"
]
},
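{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch (not executed here, and assuming `dta_ldr2`, `bert_regressor`, `device`, and the `n_I_v`/`n_B_v`/`n_M_v` splits are defined as above), `bert_predict` returns 15 values per event: five EPA (Evaluation, Potency, Activity) triples, matching the regressor's final `Linear(120, 15)` layer:\n",
"\n",
"```python\n",
"q = dta_ldr2(I=n_I_v, B=n_B_v, M=n_M_v, batch_size=8)\n",
"preds = bert_predict(bert_regressor.to(device), q)\n",
"preds.shape  # expected: torch.Size([8, 15])\n",
"```"
]
},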
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Evaluation on the validation set"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"def out_df(data,predictions,df_beh=Behaviors,df_ident=Identities,df_mod=Modifiers):\n",
" df2=pd.concat([pd.DataFrame(scaler_M.inverse_transform(predictions[:,0:3].cpu())),\n",
" pd.DataFrame(scaler_M.inverse_transform(data[2][:,0:3])),\n",
" pd.DataFrame(scaler_I.inverse_transform(predictions[:,3:6].cpu())),\n",
" pd.DataFrame(scaler_I.inverse_transform(data[2][:,3:6])),\n",
" pd.DataFrame(scaler_B.inverse_transform(predictions[:,6:9].cpu())),\n",
" pd.DataFrame(scaler_B.inverse_transform(data[2][:,6:9])),\n",
" pd.DataFrame(scaler_M.inverse_transform(predictions[:,9:12].cpu())),\n",
" pd.DataFrame(scaler_M.inverse_transform(data[2][:,9:12])),\n",
" pd.DataFrame(scaler_I.inverse_transform(predictions[:,12:15].cpu())),\n",
" pd.DataFrame(scaler_I.inverse_transform(data[2][:,12:15])),pd.DataFrame(np.array(data[3]))\n",
" ],axis=1).set_axis(['EEMA', 'EPMA', 'EAMA','EM1', 'PM1', 'AM1',\n",
" 'EEA', 'EPA', 'EAA','EA', 'PA', 'AA',\n",
" 'EEB', 'EPB', 'EAB','EB', 'PB', 'AB',\n",
" 'EEMO', 'EPMO', 'EAMO','EM2', 'PM2', 'AM2',\n",
" 'EEO', 'EPO', 'EAO','EO', 'PO', 'AO',\n",
    "                          'idx_ModA','idx_Act','idx_Beh','idx_ModO','idx_Obj'], axis=1)\n",
" df2=pd.merge(df2, df_mod[['term','index_in_dic']], left_on= ['idx_ModA'], right_on = [\"index_in_dic\"], \n",
" how='left').rename(columns={\"term\": 'ModA'}).drop(['index_in_dic'], axis=1)\n",
" df2=pd.merge(df2, df_ident[['term','index_in_dic']], left_on= ['idx_Act'], right_on = [\"index_in_dic\"], \n",
" how='left').rename(columns={\"term\": 'Actor'}).drop(['index_in_dic'], axis=1)\n",
" df2=pd.merge(df2, df_beh[['term','index_in_dic']], left_on= ['idx_Beh'], right_on = [\"index_in_dic\"], \n",
" how='left').rename(columns={\"term\": 'Behavior'}).drop(['index_in_dic'], axis=1)\n",
" df2=pd.merge(df2, df_mod[['term','index_in_dic']], left_on= ['idx_ModO'], right_on = [\"index_in_dic\"], \n",
" how='left').rename(columns={\"term\": 'ModO'}).drop(['index_in_dic'], axis=1)\n",
" df2=pd.merge(df2, df_ident[['term','index_in_dic']], left_on= ['idx_Obj'], right_on = [\"index_in_dic\"], \n",
" how='left').rename(columns={\"term\": 'Object'}).drop(['index_in_dic'], axis=1)\n",
"\n",
" df2=df2[['EEMA','EPMA', 'EAMA', 'EEA', 'EPA', 'EAA', 'EEB', 'EPB', 'EAB','EEMO', 'EPMO', 'EAMO', 'EEO', 'EPO', 'EAO','EM1', 'PM1', 'AM1','EA', 'PA', 'AA', 'EB', 'PB','AB', 'EM2', 'PM2', 'AM2', 'EO',\n",
" 'PO', 'AO', 'ModA','Actor','Behavior', 'ModO', 'Object']]\n",
    "    return df2"
]
},
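{
"cell_type": "markdown",
"metadata": {},
"source": [
"The merges in `out_df` all follow one pattern: a left join of an index column against `index_in_dic`, renaming `term` to the role name and dropping the join key. A toy illustration of that pattern (hypothetical data, not part of the pipeline):\n",
"\n",
"```python\n",
"import pandas as pd\n",
"\n",
"df = pd.DataFrame({'idx_Act': [1, 0]})\n",
"lookup = pd.DataFrame({'term': ['friend', 'teacher'], 'index_in_dic': [0, 1]})\n",
"out = (pd.merge(df, lookup, left_on='idx_Act', right_on='index_in_dic', how='left')\n",
"         .rename(columns={'term': 'Actor'})\n",
"         .drop('index_in_dic', axis=1))\n",
"out['Actor'].tolist()  # ['teacher', 'friend']\n",
"```"
]
},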
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"def get_output(I_b=n_I_v,B_b=n_B_v,M_b=n_M_v,batch_sz=3000,batch_num=10):\n",
" df=pd.DataFrame()\n",
" for i in range(batch_num):\n",
" q=dta_ldr2(I=I_b,B=B_b,M=M_b,batch_size=batch_sz)\n",
" preds = bert_predict(bert_regressor.to(device), q)\n",
" df2=out_df(data=q,predictions=preds)\n",
" df=pd.concat([df,df2],axis=0)\n",
    "    return df"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\mosta\\AppData\\Local\\Temp\\ipykernel_20012\\3313474982.py:2: FutureWarning: DataFrame.set_axis 'inplace' keyword is deprecated and will be removed in a future version. Use `obj = obj.set_axis(..., copy=False)` instead\n",
"  df2=pd.concat([pd.DataFrame(scaler_M.inverse_transform(predictions[:,0:3].cpu())),\n"
]
}
],
"source": [
"df_val=get_output(I_b=n_I_v,B_b=n_B_v,M_b=n_M_v,batch_sz=3000,batch_num=10)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"from scipy.stats import pearsonr\n",
"def pval_cor(df):\n",
" rho = df.corr()\n",
" pval = df.corr(method=lambda x, y: pearsonr(x, y)[1]) - np.eye(*rho.shape)\n",
" p = pval.applymap(lambda x: ''.join(['*' for t in [0.001,0.01,0.05] if x<=t]))\n",
    "    return rho.round(3).astype(str) + p"
]
},
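{
"cell_type": "markdown",
"metadata": {},
"source": [
"`pval_cor` relies on `DataFrame.corr` accepting a callable `method`: passing `lambda x, y: pearsonr(x, y)[1]` fills the matrix with p-values instead of correlations. Because pandas fixes the diagonal of a callable-method result at 1.0, subtracting the identity zeroes the diagonal so it never earns significance stars. A small check of the idea (toy data, hypothetical):\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pandas as pd\n",
"from scipy.stats import pearsonr\n",
"\n",
"df = pd.DataFrame({'x': [1.0, 2.0, 3.0, 4.0], 'y': [1.1, 1.9, 3.2, 3.9]})\n",
"pval = df.corr(method=lambda x, y: pearsonr(x, y)[1]) - np.eye(2)\n",
"pval.loc['x', 'x']  # 0.0 on the diagonal\n",
"```"
]
},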
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"# This function calculates p-values for the Spearman correlation.\n",
"# It was added at Christophe Blaison's request.\n",
" \n",
"from scipy.stats import spearmanr\n",
"\n",
"def pval_cor_spearman(df):\n",
" # Calculate Spearman correlation matrix\n",
    "    rho = df.corr(method='spearman', numeric_only=True)\n",
" \n",
" # Initialize p-value matrix with the same shape and column/index labels as rho\n",
" pval = pd.DataFrame(np.zeros_like(rho.values), columns=rho.columns, index=rho.index)\n",
" \n",
" # Iterate through pairs of columns in rho, calculating p-values with spearmanr\n",
" for i, col in enumerate(rho.columns):\n",
" for j, row in enumerate(rho.index):\n",
" if i != j: # Avoid diagonal elements\n",
" _, p = spearmanr(df[col], df[row])\n",
" pval.iloc[i, j] = p\n",
" else:\n",
" # Optionally, set diagonal elements to NaN or leave them as 0,\n",
" # since they represent correlation of a variable with itself\n",
" pval.iloc[i, j] = np.nan\n",
"\n",
" # Apply significance stars to p-values\n",
" p_stars = pval.applymap(lambda x: ''.join(['*' for t in [0.001, 0.01, 0.05] if x <= t]))\n",
"\n",
" # Combine rho and p_stars, ensuring the output is rounded and formatted as strings\n",
" result = rho.round(3).astype(str) + p_stars\n",
" \n",
" return result\n"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\mosta\\AppData\\Local\\Temp\\ipykernel_20012\\3435180724.py:8: FutureWarning: The default value of numeric_only in DataFrame.corr is deprecated. In a future version, it will default to False. Select only valid columns or specify the value of numeric_only to silence this warning.\n",
" rho = df.corr(method='spearman')\n"
]
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" EEMA | \n",
" EPMA | \n",
" EAMA | \n",
" EEA | \n",
" EPA | \n",
" EAA | \n",
" EEB | \n",
" EPB | \n",
" EAB | \n",
" EEMO | \n",
" ... | \n",
" AA | \n",
" EB | \n",
" PB | \n",
" AB | \n",
" EM2 | \n",
" PM2 | \n",
" AM2 | \n",
" EO | \n",
" PO | \n",
" AO | \n",
"
\n",
" \n",
" \n",
" \n",
" | EM1 | \n",
" 0.902*** | \n",
" 0.788*** | \n",
" 0.006 | \n",
" 0.002 | \n",
" 0.014* | \n",
" -0.007 | \n",
" 0.012* | \n",
" -0.02*** | \n",
" 0.017** | \n",
" -0.027*** | \n",
" ... | \n",
" -0.0 | \n",
" 0.002 | \n",
" -0.008 | \n",
" -0.0 | \n",
" 0.001 | \n",
" 0.002 | \n",
" -0.004 | \n",
" -0.001 | \n",
" -0.0 | \n",
" 0.007 | \n",
"
\n",
" \n",
" | PM1 | \n",
" 0.792*** | \n",
" 0.89*** | \n",
" 0.271*** | \n",
" 0.0 | \n",
" 0.019** | \n",
" -0.003 | \n",
" 0.008 | \n",
" -0.013* | \n",
" 0.008 | \n",
" -0.031*** | \n",
" ... | \n",
" 0.006 | \n",
" 0.003 | \n",
" -0.0 | \n",
" 0.0 | \n",
" -0.002 | \n",
" -0.002 | \n",
" -0.0 | \n",
" -0.004 | \n",
" -0.002 | \n",
" 0.007 | \n",
"
\n",
" \n",
" | AM1 | \n",
" 0.022*** | \n",
" 0.172*** | \n",
" 0.82*** | \n",
" -0.007 | \n",
" 0.003 | \n",
" -0.002 | \n",
" -0.005 | \n",
" -0.006 | \n",
" -0.005 | \n",
" -0.007 | \n",
" ... | \n",
" -0.001 | \n",
" 0.002 | \n",
" 0.004 | \n",
" -0.006 | \n",
" -0.005 | \n",
" -0.006 | \n",
" 0.002 | \n",
" -0.002 | \n",
" 0.003 | \n",
" 0.005 | \n",
"
\n",
" \n",
" | EA | \n",
" 0.005 | \n",
" 0.015** | \n",
" -0.021*** | \n",
" 0.862*** | \n",
" 0.406*** | \n",
" 0.072*** | \n",
" 0.017** | \n",
" 0.025*** | \n",
" 0.001 | \n",
" 0.009 | \n",
" ... | \n",
" 0.142*** | \n",
" 0.018** | \n",
" 0.007 | \n",
" -0.004 | \n",
" -0.002 | \n",
" -0.003 | \n",
" -0.004 | \n",
" -0.004 | \n",
" -0.001 | \n",
" 0.001 | \n",
"
\n",
" \n",
" | PA | \n",
" -0.003 | \n",
" 0.019** | \n",
" -0.026*** | \n",
" 0.439*** | \n",
" 0.814*** | \n",
" 0.271*** | \n",
" 0.002 | \n",
" -0.001 | \n",
" -0.015** | \n",
" 0.003 | \n",
" ... | \n",
" 0.296*** | \n",
" 0.005 | \n",
" 0.0 | \n",
" -0.004 | \n",
" 0.007 | \n",
" 0.006 | \n",
" 0.005 | \n",
" -0.0 | \n",
" -0.003 | \n",
" -0.005 | \n",
"
\n",
" \n",
" | AA | \n",
" -0.014* | \n",
" -0.0 | \n",
" -0.014* | \n",
" 0.105*** | \n",
" 0.174*** | \n",
" 0.65*** | \n",
" -0.01 | \n",
" -0.002 | \n",
" -0.027*** | \n",
" 0.002 | \n",
" ... | \n",
" 1.0 | \n",
" 0.003 | \n",
" 0.004 | \n",
" -0.008 | \n",
" 0.002 | \n",
" 0.001 | \n",
" -0.009 | \n",
" -0.002 | \n",
" -0.004 | \n",
" -0.002 | \n",
"
\n",
" \n",
" | EB | \n",
" 0.039*** | \n",
" 0.041*** | \n",
" 0.05*** | \n",
" 0.007 | \n",
" 0.023*** | \n",
" 0.033*** | \n",
" 0.89*** | \n",
" 0.516*** | \n",
" -0.238*** | \n",
" -0.0 | \n",
" ... | \n",
" 0.003 | \n",
" 1.0 | \n",
" 0.525*** | \n",
" -0.207*** | \n",
" 0.001 | \n",
" 0.002 | \n",
" -0.0 | \n",
" -0.005 | \n",
" -0.005 | \n",
" 0.005 | \n",
"
\n",
" \n",
" | PB | \n",
" 0.005 | \n",
" 0.025*** | \n",
" 0.031*** | \n",
" 0.008 | \n",
" 0.007 | \n",
" 0.009 | \n",
" 0.408*** | \n",
" 0.691*** | \n",
" 0.042*** | \n",
" 0.01 | \n",
" ... | \n",
" 0.004 | \n",
" 0.525*** | \n",
" 1.0 | \n",
" 0.148*** | \n",
" 0.01 | \n",
" 0.012* | \n",
" 0.002 | \n",
" -0.014* | \n",
" -0.008 | \n",
" 0.007 | \n",
"
\n",
" \n",
" | AB | \n",
" -0.006 | \n",
" 0.001 | \n",
" -0.017** | \n",
" 0.007 | \n",
" 0.001 | \n",
" -0.012* | \n",
" -0.247*** | \n",
" 0.118*** | \n",
" 0.726*** | \n",
" 0.015** | \n",
" ... | \n",
" -0.008 | \n",
" -0.207*** | \n",
" 0.148*** | \n",
" 1.0 | \n",
" 0.009 | \n",
" 0.006 | \n",
" -0.004 | \n",
" -0.011 | \n",
" -0.006 | \n",
" 0.001 | \n",
"
\n",
" \n",
" | EM2 | \n",
" 0.002 | \n",
" -0.006 | \n",
" 0.017** | \n",
" 0.004 | \n",
" 0.005 | \n",
" -0.007 | \n",
" -0.003 | \n",
" 0.004 | \n",
" 0.012* | \n",
" 0.892*** | \n",
" ... | \n",
" 0.002 | \n",
" 0.001 | \n",
" 0.01 | \n",
" 0.009 | \n",
" 1.0 | \n",
" 0.835*** | \n",
" 0.087*** | \n",
" 0.006 | \n",
" -0.001 | \n",
" -0.007 | \n",
"
\n",
" \n",
" | PM2 | \n",
" -0.001 | \n",
" 0.002 | \n",
" 0.016** | \n",
" 0.001 | \n",
" 0.003 | \n",
" -0.003 | \n",
" -0.004 | \n",
" -0.007 | \n",
" -0.0 | \n",
" 0.791*** | \n",
" ... | \n",
" 0.001 | \n",
" 0.002 | \n",
" 0.012* | \n",
" 0.006 | \n",
" 0.835*** | \n",
" 1.0 | \n",
" 0.328*** | \n",
" 0.008 | \n",
" -0.003 | \n",
" -0.007 | \n",
"
\n",
" \n",
" | AM2 | \n",
" -0.007 | \n",
" 0.002 | \n",
" 0.013* | \n",
" -0.011 | \n",
" 0.001 | \n",
" -0.002 | \n",
" -0.011 | \n",
" -0.009 | \n",
" -0.043*** | \n",
" 0.004 | \n",
" ... | \n",
" -0.009 | \n",
" -0.0 | \n",
" 0.002 | \n",
" -0.004 | \n",
" 0.087*** | \n",
" 0.328*** | \n",
" 1.0 | \n",
" -0.002 | \n",
" -0.011* | \n",
" 0.005 | \n",
"
\n",
" \n",
" | EO | \n",
" -0.026*** | \n",
" -0.012* | \n",
" -0.008 | \n",
" 0.002 | \n",
" -0.012* | \n",
" -0.001 | \n",
" -0.009 | \n",
" 0.001 | \n",
" -0.023*** | \n",
" 0.007 | \n",
" ... | \n",
" -0.002 | \n",
" -0.005 | \n",
" -0.014* | \n",
" -0.011 | \n",
" 0.006 | \n",
" 0.008 | \n",
" -0.002 | \n",
" 1.0 | \n",
" 0.453*** | \n",
" 0.144*** | \n",
"
\n",
" \n",
" | PO | \n",
" -0.016** | \n",
" -0.014* | \n",
" -0.003 | \n",
" 0.011 | \n",
" -0.018** | \n",
" -0.012* | \n",
" -0.015** | \n",
" -0.011 | \n",
" -0.003 | \n",
" -0.004 | \n",
" ... | \n",
" -0.004 | \n",
" -0.005 | \n",
" -0.008 | \n",
" -0.006 | \n",
" -0.001 | \n",
" -0.003 | \n",
" -0.011* | \n",
" 0.453*** | \n",
" 1.0 | \n",
" 0.291*** | \n",
"
\n",
" \n",
" | AO | \n",
" 0.016** | \n",
" 0.012* | \n",
" -0.002 | \n",
" 0.014* | \n",
" 0.003 | \n",
" 0.007 | \n",
" 0.005 | \n",
" 0.002 | \n",
" 0.003 | \n",
" -0.01 | \n",
" ... | \n",
" -0.002 | \n",
" 0.005 | \n",
" 0.007 | \n",
" 0.001 | \n",
" -0.007 | \n",
" -0.007 | \n",
" 0.005 | \n",
" 0.144*** | \n",
" 0.291*** | \n",
" 1.0 | \n",
"
\n",
" \n",
"
\n",
"
15 rows × 30 columns
\n",
"
"
],
"text/plain": [
" EEMA EPMA EAMA EEA EPA EAA EEB \\\n",
"EM1 0.902*** 0.788*** 0.006 0.002 0.014* -0.007 0.012* \n",
"PM1 0.792*** 0.89*** 0.271*** 0.0 0.019** -0.003 0.008 \n",
"AM1 0.022*** 0.172*** 0.82*** -0.007 0.003 -0.002 -0.005 \n",
"EA 0.005 0.015** -0.021*** 0.862*** 0.406*** 0.072*** 0.017** \n",
"PA -0.003 0.019** -0.026*** 0.439*** 0.814*** 0.271*** 0.002 \n",
"AA -0.014* -0.0 -0.014* 0.105*** 0.174*** 0.65*** -0.01 \n",
"EB 0.039*** 0.041*** 0.05*** 0.007 0.023*** 0.033*** 0.89*** \n",
"PB 0.005 0.025*** 0.031*** 0.008 0.007 0.009 0.408*** \n",
"AB -0.006 0.001 -0.017** 0.007 0.001 -0.012* -0.247*** \n",
"EM2 0.002 -0.006 0.017** 0.004 0.005 -0.007 -0.003 \n",
"PM2 -0.001 0.002 0.016** 0.001 0.003 -0.003 -0.004 \n",
"AM2 -0.007 0.002 0.013* -0.011 0.001 -0.002 -0.011 \n",
"EO -0.026*** -0.012* -0.008 0.002 -0.012* -0.001 -0.009 \n",
"PO -0.016** -0.014* -0.003 0.011 -0.018** -0.012* -0.015** \n",
"AO 0.016** 0.012* -0.002 0.014* 0.003 0.007 0.005 \n",
"\n",
" EPB EAB EEMO ... AA EB PB \\\n",
"EM1 -0.02*** 0.017** -0.027*** ... -0.0 0.002 -0.008 \n",
"PM1 -0.013* 0.008 -0.031*** ... 0.006 0.003 -0.0 \n",
"AM1 -0.006 -0.005 -0.007 ... -0.001 0.002 0.004 \n",
"EA 0.025*** 0.001 0.009 ... 0.142*** 0.018** 0.007 \n",
"PA -0.001 -0.015** 0.003 ... 0.296*** 0.005 0.0 \n",
"AA -0.002 -0.027*** 0.002 ... 1.0 0.003 0.004 \n",
"EB 0.516*** -0.238*** -0.0 ... 0.003 1.0 0.525*** \n",
"PB 0.691*** 0.042*** 0.01 ... 0.004 0.525*** 1.0 \n",
"AB 0.118*** 0.726*** 0.015** ... -0.008 -0.207*** 0.148*** \n",
"EM2 0.004 0.012* 0.892*** ... 0.002 0.001 0.01 \n",
"PM2 -0.007 -0.0 0.791*** ... 0.001 0.002 0.012* \n",
"AM2 -0.009 -0.043*** 0.004 ... -0.009 -0.0 0.002 \n",
"EO 0.001 -0.023*** 0.007 ... -0.002 -0.005 -0.014* \n",
"PO -0.011 -0.003 -0.004 ... -0.004 -0.005 -0.008 \n",
"AO 0.002 0.003 -0.01 ... -0.002 0.005 0.007 \n",
"\n",
" AB EM2 PM2 AM2 EO PO AO \n",
"EM1 -0.0 0.001 0.002 -0.004 -0.001 -0.0 0.007 \n",
"PM1 0.0 -0.002 -0.002 -0.0 -0.004 -0.002 0.007 \n",
"AM1 -0.006 -0.005 -0.006 0.002 -0.002 0.003 0.005 \n",
"EA -0.004 -0.002 -0.003 -0.004 -0.004 -0.001 0.001 \n",
"PA -0.004 0.007 0.006 0.005 -0.0 -0.003 -0.005 \n",
"AA -0.008 0.002 0.001 -0.009 -0.002 -0.004 -0.002 \n",
"EB -0.207*** 0.001 0.002 -0.0 -0.005 -0.005 0.005 \n",
"PB 0.148*** 0.01 0.012* 0.002 -0.014* -0.008 0.007 \n",
"AB 1.0 0.009 0.006 -0.004 -0.011 -0.006 0.001 \n",
"EM2 0.009 1.0 0.835*** 0.087*** 0.006 -0.001 -0.007 \n",
"PM2 0.006 0.835*** 1.0 0.328*** 0.008 -0.003 -0.007 \n",
"AM2 -0.004 0.087*** 0.328*** 1.0 -0.002 -0.011* 0.005 \n",
"EO -0.011 0.006 0.008 -0.002 1.0 0.453*** 0.144*** \n",
"PO -0.006 -0.001 -0.003 -0.011* 0.453*** 1.0 0.291*** \n",
"AO 0.001 -0.007 -0.007 0.005 0.144*** 0.291*** 1.0 \n",
"\n",
"[15 rows x 30 columns]"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pval_cor_spearman(df_val).iloc[15:,:]"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" EEMA EPMA EAMA EEA EPA EAA EEB \\\n",
"EM1 0.932*** 0.839*** -0.001 0.002 0.01 -0.005 0.01 \n",
"PM1 0.852*** 0.908*** 0.261*** 0.002 0.016** -0.003 0.007 \n",
"AM1 0.037*** 0.185*** 0.805*** -0.007 0.007 -0.002 -0.003 \n",
"EA 0.004 0.012* -0.021*** 0.892*** 0.4*** 0.014* 0.018** \n",
"PA -0.002 0.015** -0.022*** 0.423*** 0.817*** 0.275*** 0.002 \n",
"AA -0.011* -0.002 -0.012* 0.102*** 0.191*** 0.609*** -0.007 \n",
"EB 0.029*** 0.034*** 0.048*** 0.006 0.023*** 0.034*** 0.903*** \n",
"PB 0.002 0.019** 0.032*** 0.007 0.009 0.008 0.388*** \n",
"AB -0.004 0.0 -0.012* 0.006 0.003 -0.01 -0.276*** \n",
"EM2 0.002 -0.005 0.01 0.006 0.002 -0.009 -0.003 \n",
"PM2 -0.0 -0.0 0.011 0.003 0.001 -0.005 -0.006 \n",
"AM2 -0.003 0.003 0.01 -0.01 -0.001 -0.001 -0.012* \n",
"EO -0.021*** -0.01 -0.006 -0.001 -0.013* -0.003 -0.005 \n",
"PO -0.013* -0.012* -0.005 0.012* -0.016** -0.012* -0.014* \n",
"AO 0.012* 0.01 -0.001 0.012* 0.006 0.007 0.005 \n",
"\n",
" EPB EAB EEMO ... AA EB PB \\\n",
"EM1 -0.018** 0.017** -0.017** ... -0.001 0.002 -0.006 \n",
"PM1 -0.015* 0.01 -0.02*** ... 0.004 0.003 -0.002 \n",
"AM1 -0.005 -0.003 -0.011 ... -0.001 0.002 0.005 \n",
"EA 0.024*** 0.001 0.004 ... 0.132*** 0.019** 0.008 \n",
"PA 0.001 -0.019*** -0.001 ... 0.338*** 0.007 0.001 \n",
"AA -0.001 -0.026*** -0.0 ... 1.0*** 0.004 0.004 \n",
"EB 0.502*** -0.273*** 0.003 ... 0.004 1.0*** 0.516*** \n",
"PB 0.699*** 0.07*** 0.012* ... 0.004 0.516*** 1.0*** \n",
"AB 0.147*** 0.727*** 0.01 ... -0.006 -0.236*** 0.2*** \n",
"EM2 0.005 0.009 0.93*** ... 0.003 0.003 0.012* \n",
"PM2 -0.004 0.0 0.854*** ... 0.001 0.003 0.012* \n",
"AM2 -0.011 -0.038*** 0.03*** ... -0.007 -0.0 0.004 \n",
"EO 0.003 -0.026*** 0.006 ... -0.002 -0.003 -0.012* \n",
"PO -0.012* -0.004 -0.002 ... -0.004 -0.005 -0.008 \n",
"AO 0.005 0.002 -0.01 ... 0.001 0.006 0.007 \n",
"\n",
" AB EM2 PM2 AM2 EO PO AO \n",
"EM1 0.001 -0.0 0.001 -0.002 0.0 0.001 0.005 \n",
"PM1 0.001 -0.003 -0.002 0.001 -0.004 -0.001 0.007 \n",
"AM1 -0.001 -0.008 -0.008 0.002 -0.002 0.001 0.006 \n",
"EA -0.005 0.0 -0.001 -0.005 -0.007 -0.001 0.001 \n",
"PA -0.002 0.005 0.004 0.004 -0.0 -0.003 -0.002 \n",
"AA -0.006 0.003 0.001 -0.007 -0.002 -0.004 0.001 \n",
"EB -0.236*** 0.003 0.003 -0.0 -0.003 -0.005 0.006 \n",
"PB 0.2*** 0.012* 0.012* 0.004 -0.012* -0.008 0.007 \n",
"AB 1.0*** 0.005 0.005 -0.004 -0.007 -0.002 0.002 \n",
"EM2 0.005 1.0*** 0.899*** 0.076*** 0.002 -0.001 -0.008 \n",
"PM2 0.005 0.899*** 1.0*** 0.308*** 0.005 -0.003 -0.008 \n",
"AM2 -0.004 0.076*** 0.308*** 1.0*** -0.005 -0.011 0.006 \n",
"EO -0.007 0.002 0.005 -0.005 1.0*** 0.461*** 0.135*** \n",
"PO -0.002 -0.001 -0.003 -0.011 0.461*** 1.0*** 0.333*** \n",
"AO 0.002 -0.008 -0.008 0.006 0.135*** 0.333*** 1.0*** \n",
"\n",
"[15 rows x 30 columns]"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pval_cor(df_val).iloc[15:,:]"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\mosta\\AppData\\Local\\Temp\\ipykernel_20012\\3209831563.py:3: FutureWarning: The default value of numeric_only in DataFrame.corr is deprecated. In a future version, it will default to False. Select only valid columns or specify the value of numeric_only to silence this warning.\n",
" rho = df.corr()\n",
"C:\\Users\\mosta\\AppData\\Local\\Temp\\ipykernel_20012\\3209831563.py:4: FutureWarning: The default value of numeric_only in DataFrame.corr is deprecated. In a future version, it will default to False. Select only valid columns or specify the value of numeric_only to silence this warning.\n",
" pval = df.corr(method=lambda x, y: pearsonr(x, y)[1]) - np.eye(*rho.shape)\n"
]
},
{
"data": {
"text/plain": [
" EE EP EA E P A\n",
"EE 1.0*** 0.89*** -0.021*** 0.93*** 0.853*** 0.033***\n",
"EP 0.89*** 1.0*** 0.213*** 0.831*** 0.905*** 0.185***\n",
"EA -0.021*** 0.213*** 1.0*** -0.007 0.257*** 0.805***\n",
"E 0.93*** 0.831*** -0.007 1.0*** 0.9*** 0.078***\n",
"P 0.853*** 0.905*** 0.257*** 0.9*** 1.0*** 0.309***\n",
"A 0.033*** 0.185*** 0.805*** 0.078*** 0.309*** 1.0***"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df_mod_v=pd.concat([df_val[['EEMA', 'EPMA', 'EAMA', 'ModA','EM1', 'PM1', 'AM1']].rename(columns={'EM1':'E', 'PM1':'P', 'AM1':'A','EEMA':'EE', 'EPMA':'EP', 'EAMA':'EA','ModA': \"item\"})\n",
",df_val[['EEMO', 'EPMO', 'EAMO', 'ModO','EM2', 'PM2', 'AM2']].rename(columns={'EM2':'E', 'PM2':'P', 'AM2':'A','EEMO':'EE', 'EPMO':'EP', 'EAMO':'EA','ModO': \"item\"})],axis=0)\n",
"pval_cor(df_mod_v)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Modifier_MAE : [0.6, 0.44, 0.54]\n",
"Modifier_RMSE : [0.78, 0.58, 0.71]\n"
]
},
{
"data": {
"text/plain": [
" EE EP EA E P A\n",
"EE 1.0*** 0.895*** -0.017 0.936*** 0.858*** 0.038\n",
"EP 0.895*** 1.0*** 0.219 0.838*** 0.912*** 0.19\n",
"EA -0.017 0.219 1.0*** -0.001 0.264* 0.812***\n",
"E 0.936*** 0.838*** -0.001 1.0*** 0.901*** 0.085\n",
"P 0.858*** 0.912*** 0.264* 0.901*** 1.0*** 0.313**\n",
"A 0.038 0.19 0.812*** 0.085 0.313** 1.0***"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df_mod_v2=df_mod_v.groupby('item').mean()\n",
"diff=np.absolute(np.array(df_mod_v2[['EE','EP', 'EA']])- np.array(df_mod_v2[['E', 'P', 'A']]))\n",
"print('Modifier_MAE :', list(pd.DataFrame(diff).mean().round(decimals=2)))\n",
"print('Modifier_RMSE :',list(pd.DataFrame(np.sqrt(np.mean(diff**2,axis=0))).round(decimals=2)[0]))\n",
"\n",
"pval_cor(df_mod_v2.reset_index())"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"# df_beh_austin=pd.read_pickle('D:/Box/Emoji paper/Austin/behavior.pkl')\n",
"# beh_dict=dict(zip(Behaviors.term, Behaviors.orig_term))\n",
"# austin_list_beh=list(df_beh_austin.item)\n",
"# df_beh=df_val[['EEB', 'EPB', 'EAB','EB', 'PB', 'AB','Behavior']].groupby('Behavior').mean().reset_index()\n",
"# df_beh=df_beh.replace({'Behavior': beh_dict})\n",
"# # df_beh=df_beh.iloc[np.where(df_beh[\"Behavior\"].isin(austin_list_beh))]\n",
"# # df_beh=df_beh.merge(df_beh_austin, left_on='Behavior', right_on='item')\n",
"# pval_cor(df_beh)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" EE EP EA E P A\n",
"EE 1.0*** 0.521*** -0.288** 0.91*** 0.396*** -0.279**\n",
"EP 0.521*** 1.0*** 0.105 0.507*** 0.707*** 0.141\n",
"EA -0.288** 0.105 1.0*** -0.276** 0.059 0.731***\n",
"E 0.91*** 0.507*** -0.276** 1.0*** 0.521*** -0.237*\n",
"P 0.396*** 0.707*** 0.059 0.521*** 1.0*** 0.191\n",
"A -0.279** 0.141 0.731*** -0.237* 0.191 1.0***"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df_beh_v3=df_val[['EEB', 'EPB', 'EAB', 'Behavior','EB', 'PB', 'AB']].rename(columns={\n",
" 'EB':'E', 'PB':'P', 'AB':'A','EEB':'EE', 'EPB':'EP', 'EAB':'EA','Behavior': \"item\"}).groupby('item').mean()\n",
"pval_cor(df_beh_v3)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Beh_MAE : [0.7, 0.47, 0.48]\n",
"Beh_RMSE : [0.87, 0.62, 0.61]\n"
]
},
{
"data": {
"text/plain": [
" E P A\n",
"EE 0.91*** 0.396*** -0.279**\n",
"EP 0.507*** 0.707*** 0.141\n",
"EA -0.276** 0.059 0.731***"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"diff=np.absolute(np.array(df_beh_v3[['EE','EP', 'EA']])- np.array(df_beh_v3[['E', 'P', 'A']]))\n",
"print('Beh_MAE :',list(pd.DataFrame(diff).mean().round(decimals=2)))\n",
"print('Beh_RMSE :',list(pd.DataFrame(np.sqrt(np.mean(diff**2,axis=0))).round(decimals=2)[0]))\n",
"\n",
"pval_cor(df_beh_v3).iloc[:3,3:7]"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" EE EP EA E P A\n",
"EE 1.0*** 0.439*** 0.05*** 0.886*** 0.415*** 0.095***\n",
"EP 0.439*** 1.0*** 0.244*** 0.385*** 0.807*** 0.19***\n",
"EA 0.05*** 0.244*** 1.0*** 0.009* 0.261*** 0.598***\n",
"E 0.886*** 0.385*** 0.009* 1.0*** 0.458*** 0.133***\n",
"P 0.415*** 0.807*** 0.261*** 0.458*** 1.0*** 0.336***\n",
"A 0.095*** 0.19*** 0.598*** 0.133*** 0.336*** 1.0***"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df_ident_v=pd.concat([df_val[['EEA', 'EPA', 'EAA', 'Actor','EA', 'PA', 'AA']].rename(columns={'EA':'E', 'PA':'P', 'AA':'A','EEA':'EE', 'EPA':'EP', 'EAA':'EA',\"Actor\": \"item\"})\n",
",df_val[['EEO', 'EPO', 'EAO', 'Object','EO', 'PO', 'AO']].rename(columns={'EO':'E', 'PO':'P', 'AO':'A','EEO':'EE', 'EPO':'EP', 'EAO':'EA',\"Object\": \"item\"})],axis=0)\n",
"pval_cor(df_ident_v)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Identity_MAE : [0.53, 0.54, 0.56]\n",
"Identity_RMSE : [0.73, 0.7, 0.72]\n"
]
},
{
"data": {
"text/plain": [
" EE EP EA E P A\n",
"EE 1.0*** 0.448*** 0.044 0.894*** 0.423*** 0.094\n",
"EP 0.448*** 1.0*** 0.242* 0.394*** 0.816*** 0.186\n",
"EA 0.044 0.242* 1.0*** 0.009 0.269** 0.611***\n",
"E 0.894*** 0.394*** 0.009 1.0*** 0.462*** 0.133\n",
"P 0.423*** 0.816*** 0.269** 0.462*** 1.0*** 0.334***\n",
"A 0.094 0.186 0.611*** 0.133 0.334*** 1.0***"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df_ident_v2=df_ident_v.groupby('item').mean()\n",
"diff=np.absolute(np.array(df_ident_v2[['EE','EP', 'EA']])- np.array(df_ident_v2[['E', 'P', 'A']]))\n",
"print('Identity_MAE :',list(pd.DataFrame(diff).mean().round(decimals=2)))\n",
"print('Identity_RMSE :',list(pd.DataFrame(np.sqrt(np.mean(diff**2,axis=0))).round(decimals=2)[0]))\n",
"\n",
"pval_cor(df_ident_v2)  # df_ident_v2 is already averaged per item above"
]
},
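{
"cell_type": "markdown",
"metadata": {},
"source": [
"The MAE/RMSE step above reduces to an element-wise comparison of estimated and observed EPA profiles. A minimal self-contained sketch on toy arrays (illustrative values, not model output):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy sketch of the MAE/RMSE comparison (illustrative values only)\n",
"import numpy as np\n",
"est = np.array([[1.0, 2.0, 0.5], [0.5, 1.5, -0.2]])  # estimated E, P, A\n",
"obs = np.array([[1.2, 1.8, 0.4], [0.4, 1.7, -0.1]])  # observed E, P, A\n",
"diff = np.abs(est - obs)\n",
"print('MAE  :', list(diff.mean(axis=0).round(2)))\n",
"print('RMSE :', list(np.sqrt((diff ** 2).mean(axis=0)).round(2)))"
]
},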
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"# df_ident_v.loc[df_ident_v.item=='worker',:].describe().round(decimals=2)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"df_tst=get_output(I_b=n_I_test,B_b=n_B_test,M_b=n_M_test,batch_sz=3000,batch_num=10)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"df_trn=get_output(I_b=n_I_train,B_b=n_B_train,M_b=n_M_train,batch_sz=3000,batch_num=10)\n",
"# df_trn.corr().iloc[15:,:].round(decimals=2)\n",
"\n",
"# q=dta_ldr(I=n_I_train,B=n_B_train,M=n_M_train,batch_size=3200)\n",
"# preds = bert_predict(bert_regressor, q)\n",
"# np.mean(np.absolute(scaler_I.inverse_transform(preds.cpu())-scaler_I.inverse_transform(q[2])),axis=0).round(decimals=2)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" index_in_dic term E P A E2 P2 A2 term2 len_Bert \\\n",
"499 499 judge 1.15 2.53 -0.22 1.15 2.53 -0.22 judge 1 \n",
"\n",
" cluster \n",
"499 3 index_in_dic orig_term term E P A E2 P2 A2 \\\n",
"517 517 judge judges -1.83 0.71 0.07 -1.83 0.71 0.07 \n",
"\n",
" term2 len_Bert cluster \n",
"517 judges 1 0 \n"
]
}
],
"source": [
"print(Identities.loc[Identities.term=='judge',].round(decimals=2),\n",
" Behaviors.loc[Behaviors.term=='judges',].round(decimals=2))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### New words"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"def gen_new(Identity,Behavior,Modifier,n_df,word_type):\n",
" if word_type=='identity':\n",
" ident1=n_df.sample(axis = 0,random_state=56)\n",
" else:ident1=Identity.sample(axis = 0,random_state=6)\n",
" ident2=Identity.sample(axis = 0,random_state=6)\n",
" if word_type=='behavior':\n",
" behav=n_df.sample(axis = 0,random_state=5)\n",
" else: behav=Behavior.sample(axis = 0,random_state=5)\n",
" if word_type=='modifier':\n",
" modif1=n_df.sample(axis = 0,random_state=55)\n",
" else: modif1=Modifier.sample(axis = 0)\n",
" modif2=Modifier.sample(axis = 0,random_state=96)\n",
" id1=list(ident1.term)\n",
" id2=list(ident2.term)\n",
" beh=list(behav.term)\n",
" mod1=list(modif1.term)\n",
" mod2=list(modif2.term)\n",
"# wrdvc_ident1=gs_model.get_vector((list(ident1.trm_org))[0], norm=True)\n",
" sents=' '.join(map(str, (mod1+id1+beh+mod2+id2)))\n",
" values=np.concatenate([(modif1[['E','P','A']]).to_numpy(),\n",
" (ident1[['E','P','A']]).to_numpy(),\n",
" (behav[['E','P','A']]).to_numpy(),\n",
" (modif2[['E','P','A']]).to_numpy(),\n",
" (ident2[['E','P','A']]).to_numpy()], axis=1)[0]\n",
"# print(values)\n",
" #indexx=[(ident1['index_in_dic']).to_numpy()][0][0]\n",
" indexx=torch.tensor([[(modif1['index_in_dic']).to_numpy()][0][0],\n",
" [(ident1['index_in_dic']).to_numpy()][0][0],\n",
" [(behav['index_in_dic']).to_numpy()][0][0],\n",
" [(modif2['index_in_dic']).to_numpy()][0][0],\n",
" [(ident2['index_in_dic']).to_numpy()][0][0]])\n",
" ys= torch.tensor(values)\n",
"\n",
"\n",
" inputs, masks = preprocessing_for_bert([sents])\n",
"# data=TensorDataset(inputs, masks, ys)\n",
" \n",
" yield inputs, masks, ys,indexx #torch.tensor(sents),\n",
"def ldr_new(I,B,M,N_df,WT,batch_size=32):\n",
" dt_ldr= [x for x in DataLoader([next(gen_new(I,B,M,N_df,WT)) for x in range(batch_size)], batch_size=batch_size)][0]\n",
" return(dt_ldr)\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Use the whole dataset"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"def gen_new(Identity,Behavior,Modifier,n_df,word_type):\n",
"\n",
" modif1,modif2,ident1,ident2,behav=Modifier.sample(axis = 0),Modifier.sample(axis = 0),Identity.sample(axis = 0),Identity.sample(axis = 0),Behavior.sample(axis = 0)\n",
"\n",
" if word_type=='identity': ident1=n_df.sample(axis = 0)\n",
" if word_type=='behavior': behav=n_df.sample(axis = 0)\n",
" if word_type=='modifier': modif1=n_df.sample(axis = 0)\n",
" \n",
" id1,id2,beh,mod1,mod2=list(ident1.term),list(ident2.term),list(behav.term),list(modif1.term),list(modif2.term)\n",
"\n",
"# wrdvc_ident1=gs_model.get_vector((list(ident1.trm_org))[0], norm=True)\n",
" sents=' '.join(map(str, (mod1+id1+beh+mod2+id2)))\n",
" values=np.concatenate([(modif1[['E','P','A']]).to_numpy(),\n",
" (ident1[['E','P','A']]).to_numpy(),\n",
" (behav[['E','P','A']]).to_numpy(),\n",
" (modif2[['E','P','A']]).to_numpy(),\n",
" (ident2[['E','P','A']]).to_numpy()], axis=1)[0]\n",
"# print(values)\n",
" #indexx=[(ident1['index_in_dic']).to_numpy()][0][0]\n",
" indexx=torch.tensor([[(modif1['index_in_dic']).to_numpy()][0][0],\n",
" [(ident1['index_in_dic']).to_numpy()][0][0],\n",
" [(behav['index_in_dic']).to_numpy()][0][0],\n",
" [(modif2['index_in_dic']).to_numpy()][0][0],\n",
" [(ident2['index_in_dic']).to_numpy()][0][0]])\n",
" ys= torch.tensor(values)\n",
" inputs, masks = preprocessing_for_bert([sents])\n",
"# data=TensorDataset(inputs, masks, ys)\n",
" yield inputs, masks, ys,indexx #torch.tensor(sents),\n",
"\n",
"\n",
"def gen_alt(Identity,Behavior,Modifier,n_df,word_type):\n",
"\n",
" modif1,modif2,ident1,ident2,behav=Modifier.sample(axis = 0),Modifier.sample(axis = 0),Identity.sample(axis = 0),Identity.sample(axis = 0),Behavior.sample(axis = 0)\n",
" if word_type=='identity': ident2=n_df.sample(axis = 0)\n",
" if word_type=='behavior': behav=n_df.sample(axis = 0)\n",
" if word_type=='modifier': modif2=n_df.sample(axis = 0)\n",
" \n",
" id1,id2,beh,mod1,mod2=list(ident1.term),list(ident2.term),list(behav.term),list(modif1.term),list(modif2.term)\n",
" sents=' '.join(map(str, (mod1+id1+beh+mod2+id2)))\n",
" values=np.concatenate([(modif1[['E','P','A']]).to_numpy(),\n",
" (ident1[['E','P','A']]).to_numpy(),\n",
" (behav[['E','P','A']]).to_numpy(),\n",
" (modif2[['E','P','A']]).to_numpy(),\n",
" (ident2[['E','P','A']]).to_numpy()], axis=1)[0]\n",
" indexx=torch.tensor([[(modif1['index_in_dic']).to_numpy()][0][0],\n",
" [(ident1['index_in_dic']).to_numpy()][0][0],\n",
" [(behav['index_in_dic']).to_numpy()][0][0],\n",
" [(modif2['index_in_dic']).to_numpy()][0][0],\n",
" [(ident2['index_in_dic']).to_numpy()][0][0]])\n",
" ys= torch.tensor(values)\n",
" inputs, masks = preprocessing_for_bert([sents])\n",
"\n",
" yield inputs, masks, ys,indexx #torch.tensor(sents),\n",
" \n",
" \n",
" \n",
"def ldr_new(I,B,M,N_df,WT,batch_size=32,alt=0):\n",
"# print(len(I),len(B),N_df,batch_size,'\\n',WT)\n",
" if alt:\n",
" dt_ldr= [x for x in DataLoader([next(gen_alt(I,B,M,N_df,WT)) for x in range(batch_size)], batch_size=batch_size)][0]\n",
" else:\n",
" dt_ldr= [x for x in DataLoader([next(gen_new(I,B,M,N_df,WT)) for x in range(batch_size)], batch_size=batch_size)][0]\n",
" return(dt_ldr)\n",
"\n",
"# def ldr_alt(I,B,M,N_df,WT,batch_size=32):\n",
"# if WT=='behavior':\n",
"# dt_ldr= [x for x in DataLoader([next(gen_new(I,B,M,N_df,WT)) for x in range(batch_size)], batch_size=batch_size)][0]\n",
"# else:\n",
"# batch1=int(batch_size/2)\n",
"# batch2=batch_size-batch1\n",
"# dt_ldr1= [x for x in DataLoader([next(gen_new(I,B,M,N_df,WT)) for x in range(batch1)], batch_size=batch1)][0]\n",
"# dt_ldr2= [x for x in DataLoader([next(gen_alt(I,B,M,N_df,WT)) for x in range(batch2)], batch_size=batch2)][0]\n",
"# dt_ldr=dt_ldr1+dt_ldr2\n",
" \n",
"# return(dt_ldr)"
]
},
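{
"cell_type": "markdown",
"metadata": {},
"source": [
"The pattern in `gen_new`/`ldr_new` (sample one event, yield tensors, batch a list of samples through `DataLoader`) can be sketched in isolation. The toy generator below is a hypothetical stand-in for `gen_new`, with made-up shapes and no BERT preprocessing:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Standalone sketch of the generator + DataLoader batching pattern (toy tensors)\n",
"import torch\n",
"from torch.utils.data import DataLoader\n",
"\n",
"def toy_gen():\n",
"    # stand-in for gen_new: one (inputs, masks, ys, indexx) sample\n",
"    yield (torch.zeros(1, 8, dtype=torch.long),  # token ids\n",
"           torch.ones(1, 8, dtype=torch.long),   # attention mask\n",
"           torch.randn(15),                      # 5 words x EPA targets\n",
"           torch.arange(5))                      # dictionary indices\n",
"\n",
"batch = next(iter(DataLoader([next(toy_gen()) for _ in range(4)], batch_size=4)))\n",
"print([t.shape for t in batch])  # each tensor gains a batch dimension of 4"
]
},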
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Should we use the training set?\n",
"### 'index_in_dic' should be 4000, not 1000"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"def get_output_new(w,wt,I_b=n_I_v,B_b=n_B_v,M_b=n_M_v,batch_sz=300,batch_num=1):\n",
" \n",
" df=pd.DataFrame()\n",
" for i in range(batch_num):\n",
" new_df=pd.DataFrame({'index_in_dic':4000,'term':w,'E':10,'P':10,'A':10,'E2':10,'P2':10,'A2':10,'term2':w,'len_Bert':3}, index=[0])\n",
" q=ldr_new(I=I_b,B=B_b,M=M_b,N_df=new_df,WT=wt,batch_size=batch_sz)\n",
" preds = bert_predict(bert_regressor.to(device), q)\n",
" if wt=='identity':\n",
" df_identity=pd.concat([Identities,new_df],axis=0)\n",
" df2=out_df(data=q,predictions=preds,df_ident=df_identity)\n",
" df=pd.concat([df,df2],axis=0)\n",
" if wt=='behavior': \n",
" df_behavior=pd.concat([Behaviors,new_df],axis=0)\n",
" df2=out_df(data=q,predictions=preds,df_beh=df_behavior)\n",
" df=pd.concat([df,df2],axis=0) \n",
" if wt=='modifier': \n",
" df_modifier=pd.concat([Modifiers,new_df],axis=0)\n",
" df2=out_df(data=q,predictions=preds,df_mod=df_modifier)\n",
" df=pd.concat([df,df2],axis=0) \n",
" \n",
" return(df)\n",
"\n",
"\n",
"def get_output_agg(w,wt,I_b=n_I_v,B_b=n_B_v,M_b=n_M_v,batch_sz=300,batch_num=1):\n",
" \n",
" df=pd.DataFrame()\n",
" for i in range(batch_num):\n",
" new_df=pd.DataFrame({'index_in_dic':4000,'term':w,'E':10,'P':10,'A':10,'E2':10,'P2':10,'A2':10,'term2':w,'len_Bert':3}, index=[0])\n",
" batch1=int(batch_sz/2)\n",
" batch2=batch_sz-batch1\n",
" if wt=='behavior': \n",
" q=ldr_new(I=I_b,B=B_b,M=M_b,N_df=new_df,WT=wt,batch_size=batch_sz)\n",
" preds = bert_predict(bert_regressor.to(device), q)\n",
" df_behavior=pd.concat([Behaviors,new_df],axis=0)\n",
" df2=out_df(data=q,predictions=preds,df_beh=df_behavior)\n",
" df=pd.concat([df,df2],axis=0)[['EEB','EPB','EAB','EB','PB', 'AB','ModA', 'Actor', 'Behavior', 'ModO', 'Object']].rename(columns={'EEB':'EE','EPB':'EP','EAB':'EA','EB':'E','PB':'P', 'AB':'A'}) \n",
"\n",
" if wt=='identity':\n",
" q=ldr_new(I=I_b,B=B_b,M=M_b,N_df=new_df,WT=wt,batch_size=batch1)\n",
" preds = bert_predict(bert_regressor.to(device), q)\n",
" df_identity=pd.concat([Identities,new_df],axis=0)\n",
" df2=out_df(data=q,predictions=preds,df_ident=df_identity)\n",
" df_act=pd.concat([df,df2],axis=0)\n",
" df_act=df_act.copy()[['EEA','EPA','EAA','EA','PA', 'AA','ModA', 'Actor', 'Behavior', 'ModO', 'Object']].rename(columns={'EEA':'EE','EPA':'EP','EAA':'EA','EA':'E','PA':'P', 'AA':'A'})\n",
" q=ldr_new(I=I_b,B=B_b,M=M_b,N_df=new_df,WT=wt,batch_size=batch2,alt=1)\n",
" preds = bert_predict(bert_regressor.to(device), q)\n",
" df_identity=pd.concat([Identities,new_df],axis=0)\n",
" df2=out_df(data=q,predictions=preds,df_ident=df_identity)\n",
" df_obj=pd.concat([df,df2],axis=0)\n",
" df_obj=df_obj.copy()[['EEO','EPO','EAO','EO', 'PO', 'AO','ModA', 'Actor', 'Behavior', 'ModO', 'Object']].rename(columns={'EEO':'EE','EPO':'EP','EAO':'EA','EO':'E','PO':'P', 'AO':'A'})\n",
" df=pd.concat([df_act,df_obj],axis=0)\n",
" if wt=='modifier': \n",
" q=ldr_new(I=I_b,B=B_b,M=M_b,N_df=new_df,WT=wt,batch_size=batch1)\n",
" preds = bert_predict(bert_regressor.to(device), q)\n",
" df_modifier=pd.concat([Modifiers,new_df],axis=0)[['index_in_dic', 'term', 'E', 'P', 'A', 'E2', 'P2', 'A2']]\n",
" df2=out_df(data=q,predictions=preds,df_mod=df_modifier)\n",
" df_act=pd.concat([df,df2],axis=0) \n",
" df_act=df_act.copy()[['EEMA', 'EPMA', 'EAMA','EM1', 'PM1', 'AM1','ModA', 'Actor', 'Behavior', 'ModO', 'Object']].rename(columns={'EEMA':'EE','EPMA':'EP','EAMA':'EA','EM1':'E','PM1':'P', 'AM1':'A'})\n",
" q=ldr_new(I=I_b,B=B_b,M=M_b,N_df=new_df,WT=wt,batch_size=batch2,alt=1)\n",
" preds = bert_predict(bert_regressor.to(device), q)\n",
" df_modifier=pd.concat([Modifiers,new_df],axis=0)\n",
" df2=out_df(data=q,predictions=preds,df_mod=df_modifier)\n",
" df_obj=pd.concat([df,df2],axis=0)\n",
" df_obj=df_obj.copy()[['EEMO', 'EPMO', 'EAMO','EM2', 'PM2', 'AM2','ModA', 'Actor', 'Behavior', 'ModO', 'Object']].rename(columns={'EEMO':'EE','EPMO':'EP','EAMO':'EA','EM2':'E','PM2':'P', 'AM2':'A'})\n",
" df=pd.concat([df_act,df_obj],axis=0) \n",
" return(df)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"def get_output_agg(w,wt,I_b=n_I_v,B_b=n_B_v,M_b=n_M_v,batch_sz=300,batch_num=1):\n",
" \n",
" df=pd.DataFrame()\n",
" for i in range(batch_num):\n",
" new_df=pd.DataFrame({'index_in_dic':4000,'term':w,'E':10,'P':10,'A':10,'E2':10,'P2':10,'A2':10,'term2':w,'len_Bert':3}, index=[0])\n",
" batch1=int(batch_sz/2)\n",
" batch2=batch_sz-batch1\n",
" if wt=='behavior': \n",
" q=ldr_new(I=I_b,B=B_b,M=M_b,N_df=new_df,WT=wt,batch_size=batch_sz)\n",
" preds = bert_predict(bert_regressor.to(device), q)\n",
" df_behavior=pd.concat([Behaviors,new_df],axis=0)\n",
" df2=out_df(data=q,predictions=preds,df_beh=df_behavior)\n",
" df=pd.concat([df,df2],axis=0)[['EEB','EPB','EAB','ModA', 'Actor', 'Behavior', 'ModO', 'Object']].rename(columns={'EEB':'EE','EPB':'EP','EAB':'EA'}) \n",
"\n",
" if wt=='identity':\n",
" q=ldr_new(I=I_b,B=B_b,M=M_b,N_df=new_df,WT=wt,batch_size=batch1)\n",
" preds = bert_predict(bert_regressor.to(device), q)\n",
" df_identity=pd.concat([Identities,new_df],axis=0)\n",
" df2=out_df(data=q,predictions=preds,df_ident=df_identity)\n",
" df_act=pd.concat([df,df2],axis=0)\n",
" df_act=df_act.copy()[['EEA','EPA','EAA','ModA', 'Actor', 'Behavior', 'ModO', 'Object']].rename(columns={'EEA':'EE','EPA':'EP','EAA':'EA'})\n",
" q=ldr_new(I=I_b,B=B_b,M=M_b,N_df=new_df,WT=wt,batch_size=batch2,alt=1)\n",
" preds = bert_predict(bert_regressor.to(device), q)\n",
" df_identity=pd.concat([Identities,new_df],axis=0)\n",
" df2=out_df(data=q,predictions=preds,df_ident=df_identity)\n",
" df_obj=pd.concat([df,df2],axis=0)\n",
" df_obj=df_obj.copy()[['EEO','EPO','EAO','ModA', 'Actor', 'Behavior', 'ModO', 'Object']].rename(columns={'EEO':'EE','EPO':'EP','EAO':'EA'})\n",
" df=pd.concat([df_act,df_obj],axis=0)\n",
" if wt=='modifier': \n",
" q=ldr_new(I=I_b,B=B_b,M=M_b,N_df=new_df,WT=wt,batch_size=batch1)\n",
" preds = bert_predict(bert_regressor.to(device), q)\n",
" df_modifier=pd.concat([Modifiers,new_df],axis=0)[['index_in_dic', 'term', 'E', 'P', 'A', 'E2', 'P2', 'A2']]\n",
" df2=out_df(data=q,predictions=preds,df_mod=df_modifier)\n",
" df_act=pd.concat([df,df2],axis=0) \n",
" df_act=df_act.copy()[['EEMA', 'EPMA', 'EAMA','ModA', 'Actor', 'Behavior', 'ModO', 'Object']].rename(columns={'EEMA':'EE','EPMA':'EP','EAMA':'EA'})\n",
" q=ldr_new(I=I_b,B=B_b,M=M_b,N_df=new_df,WT=wt,batch_size=batch2,alt=1)\n",
" preds = bert_predict(bert_regressor.to(device), q)\n",
" df_modifier=pd.concat([Modifiers,new_df],axis=0)\n",
" df2=out_df(data=q,predictions=preds,df_mod=df_modifier)\n",
" df_obj=pd.concat([df,df2],axis=0)\n",
" df_obj=df_obj.copy()[['EEMO', 'EPMO', 'EAMO','ModA', 'Actor', 'Behavior', 'ModO', 'Object']].rename(columns={'EEMO':'EE','EPMO':'EP','EAMO':'EA'})\n",
" df=pd.concat([df_act,df_obj],axis=0) \n",
" return(df)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\mosta\\AppData\\Local\\Temp\\ipykernel_39756\\3313474982.py:2: FutureWarning: DataFrame.set_axis 'inplace' keyword is deprecated and will be removed in a future version. Use `obj = obj.set_axis(..., copy=False)` instead\n",
" df2=pd.concat([pd.DataFrame(scaler_M.inverse_transform(predictions[:,0:3].cpu())),\n"
]
},
{
"data": {
"text/plain": [
" EE EP EA\n",
"count 300.00 300.00 300.00\n",
"mean 3.53 2.83 0.39\n",
"std 0.33 0.15 0.10\n",
"min 1.39 1.94 -0.12\n",
"25% 3.40 2.77 0.33\n",
"50% 3.58 2.84 0.39\n",
"75% 3.71 2.92 0.45\n",
"max 4.12 3.07 0.76"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_output_agg('loves','behavior',batch_sz=300).describe().round(decimals=2)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" EE EP EA ModA Actor Behavior ModO \\\n",
"0 3.38 2.69 1.08 happy nursing assistant hugs exuberant \n",
"1 3.51 2.59 1.13 happy bigot rewards intolerant \n",
"2 3.83 2.77 0.96 happy half sister deceives bright \n",
"3 3.75 2.80 0.86 happy broad holds up unscrupulous \n",
"4 3.70 2.91 0.87 happy funeral director contacts slack \n",
".. ... ... ... ... ... ... ... \n",
"145 3.11 2.73 1.08 dutiful nobody offends happy \n",
"146 3.55 2.66 1.13 charmed minor watches happy \n",
"147 3.31 2.72 0.68 nasty half wit disregards happy \n",
"148 3.71 3.00 0.75 cool cop considers happy \n",
"149 3.40 2.67 0.75 jittery boss appeases happy \n",
"\n",
" Object \n",
"0 senior citizen \n",
"1 optical engineer \n",
"2 gossip \n",
"3 half sister \n",
"4 oddball \n",
".. ... \n",
"145 city councilor \n",
"146 philosopher \n",
"147 motel owner \n",
"148 insurance agent \n",
"149 informer \n",
"\n",
"[300 rows x 8 columns]"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# gen_new(Identity,Behavior,Modifier,n_df,word_type)\n",
"# N_tmp_df=pd.DataFrame({'index_in_dic':4000,'term':'phones','E':10,'P':10,'A':10,'E2':10,'P2':10,'A2':10,'term2':'phones','len_Bert':3}, index=[0])\n",
"# next(gen_new(Identity=n_I_train,Behavior=n_B_train,Modifier=n_M_train,n_df=N_tmp_df,word_type='behavior'))\n",
"# [next(gen_new(n_I_train,n_B_train,n_M_train,n_df=N_tmp_df,word_type='behavior')) for x in range(10)]\n",
"# [x for x in DataLoader([next(gen_new(n_I_train,n_B_train,n_M_train,n_df=N_tmp_df,word_type='behavior')) for x in range(10)],\n",
"# batch_size=10)][0]\n",
"get_output_agg('happy','modifier',I_b=n_I_train,B_b=n_B_train,M_b=n_M_train,batch_sz=300,batch_num=1).round(decimals=2) "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Did I apply the inverse transform?"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\mosta\\AppData\\Local\\Temp\\ipykernel_21944\\3200981708.py:13: FutureWarning: DataFrame.set_axis 'inplace' keyword is deprecated and will be removed in a future version. Use `obj = obj.set_axis(..., copy=False)` instead\n",
" df_out=pd.concat([pd.DataFrame(scaler_M.inverse_transform(predictions[:,0:3].cpu())),\n"
]
},
{
"data": {
"text/plain": [
" EEMA EPMA EAMA EEA EPA EAA EEB EPB EAB EEMO EPMO EAMO \\\n",
"0 -2.02 0.49 2.03 1.57 0.24 0.61 0.52 1.14 1.82 3.54 3.19 0.35 \n",
"\n",
" EEO EPO EAO \n",
"0 1.71 1.43 1.4 "
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def sent_gen(sentence):\n",
" sents=sentence\n",
" indexx=torch.tensor([1,1,1,1,1,1,1,1,1,1,1,1])\n",
" ys= torch.tensor([1,1,1,1,1,1,1,1,1,1,1,1])\n",
" inputs, masks = preprocessing_for_bert([sents])\n",
" yield inputs, masks, ys,indexx #torch.tensor(sents),\n",
"def sent_ldr(sent2,batch_size=1):\n",
" dt_ldr= [x for x in DataLoader([next(sent_gen(sent2)) for x in range(batch_size)], batch_size=batch_size)][0]\n",
" return(dt_ldr)\n",
"def EPA_sents(sent):\n",
" q=sent_ldr(sent)\n",
" predictions=bert_predict(bert_regressor.to(device), q)\n",
" df_out=pd.concat([pd.DataFrame(scaler_M.inverse_transform(predictions[:,0:3].cpu())),\n",
" pd.DataFrame(scaler_I.inverse_transform(predictions[:,3:6].cpu())),\n",
" pd.DataFrame(scaler_B.inverse_transform(predictions[:,6:9].cpu())),\n",
" pd.DataFrame(scaler_M.inverse_transform(predictions[:,9:12].cpu())),\n",
" pd.DataFrame(scaler_I.inverse_transform(predictions[:,12:15].cpu()))\n",
" ],axis=1).set_axis(['EEMA', 'EPMA', 'EAMA',\n",
" 'EEA', 'EPA', 'EAA', 'EEB', 'EPB', 'EAB',\n",
" 'EEMO', 'EPMO', 'EAMO','EEO', 'EPO', 'EAO'], axis=1, inplace=False)\n",
" return(df_out.round(decimals=2))\n",
"EPA_sents('angry student play supportive athletic')"
]
},
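{
"cell_type": "markdown",
"metadata": {},
"source": [
"`EPA_sents` relies on each scaler's `inverse_transform` to map the regressor's standardized outputs back to raw EPA units. A self-contained sketch of that round trip with a toy `StandardScaler` (not the fitted `scaler_M`/`scaler_I`/`scaler_B` from earlier cells):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy sketch of the inverse_transform step used in EPA_sents\n",
"import numpy as np\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"raw = np.array([[1.5, 2.0, -0.5], [0.5, 1.0, 0.5], [-1.0, -2.0, 1.5]])\n",
"scaler = StandardScaler().fit(raw)   # stand-in for scaler_I etc.\n",
"scaled = scaler.transform(raw)\n",
"print(np.allclose(scaler.inverse_transform(scaled), raw))  # True"
]
},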
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" EEMA EPMA EAMA EEA EPA EAA EEB EPB EAB EEMO EPMO EAMO \\\n",
"0 -1.96 0.29 2.17 1.46 1.26 1.21 0.1 1.24 1.53 3.38 3.06 0.27 \n",
"\n",
" EEO EPO EAO \n",
"0 1.79 2.46 0.65 "
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"EPA_sents('angry actress play supportive director')"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\mosta\\AppData\\Local\\Temp\\ipykernel_39756\\3200981708.py:13: FutureWarning: DataFrame.set_axis 'inplace' keyword is deprecated and will be removed in a future version. Use `obj = obj.set_axis(..., copy=False)` instead\n",
" df_out=pd.concat([pd.DataFrame(scaler_M.inverse_transform(predictions[:,0:3].cpu())),\n"
]
},
{
"data": {
"text/plain": [
" EEMA EPMA EAMA EEA EPA EAA EEB EPB EAB EEMO EPMO EAMO \\\n",
"0 2.55 2.47 0.18 2.56 2.7 0.43 2.26 2.31 0.86 2.45 2.54 2.36 \n",
"\n",
" EEO EPO EAO \n",
"0 0.42 -1.12 -0.85 "
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"EPA_sents('supportive doctor help energetic patient')"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" EEMA EPMA EAMA EEA EPA EAA EEB EPB EAB EEMO EPMO EAMO EEO \\\n",
"0 -1.78 0.82 1.99 2.3 2.6 0.37 -2.44 1.99 1.01 -2.03 -1.92 -2.08 0.13 \n",
"\n",
" EPO EAO \n",
"0 -1.25 -0.7 "
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"EPA_sents('angry doctor kill weak patient')"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" EEMA EPMA EAMA EEA EPA EAA EEB EPB EAB EEMO EPMO EAMO EEO \\\n",
"0 3.7 2.86 0.91 0.65 1.5 1.0 -2.86 2.04 1.04 -2.07 -2.14 -1.3 0.96 \n",
"\n",
" EPO EAO \n",
"0 -1.25 1.65 "
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"EPA_sents('Happy guy kill poor kid')"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" EEMA EPMA EAMA EEA EPA EAA EEB EPB EAB EEMO EPMO EAMO \\\n",
"0 0.97 0.15 0.95 1.76 1.68 0.4 0.16 1.25 1.12 -0.07 0.47 0.62 \n",
"\n",
" EEO EPO EAO \n",
"0 2.75 3.12 0.01 "
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"EPA_sents('undergraduate engineer hire talented doctor')"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" EEMA EPMA EAMA EEA EPA EAA EEB EPB EAB EEMO EPMO EAMO \\\n",
"0 3.69 2.91 1.12 0.44 -1.29 -1.2 0.54 1.07 1.03 2.96 2.92 0.09 \n",
"\n",
" EEO EPO EAO \n",
"0 2.62 2.34 0.19 "
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"EPA_sents('happy patient call good doctor')"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" EEMA EPMA EAMA EEA EPA EAA EEB EPB EAB EEMO EPMO EAMO \\\n",
"0 3.97 2.99 1.13 1.28 2.67 -0.22 -0.95 1.06 0.57 2.99 2.81 -0.17 \n",
"\n",
" EEO EPO EAO \n",
"0 2.51 2.72 0.07 "
]
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"EPA_sents('happy judge judge good doctor')"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['coach',\n",
" 'lawyer',\n",
" 'worker',\n",
" 'handyman',\n",
" 'kid',\n",
" 'host',\n",
" 'tutor',\n",
" 'roommate',\n",
" 'classmate',\n",
" 'teammate']"
]
},
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lsts=['advisor','coach','lawyer','worker','handyman','kid','host','chatbot','tutor','roommate','classmate','teammate']\n",
"[x for x in lsts if x in list(Identities.term) ]"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Empty DataFrame\n",
"Columns: [E, P, A]\n",
"Index: []"
]
},
"execution_count": 57,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\n",
"qq='fight'\n",
"B_v[B_v.term==qq][['E','P','A']]"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" index_in_dic term E P A E2 P2 A2 term2 \\\n",
"495 495 jeweler 0.85 0.65 -0.64 0.85 0.65 -0.64 jeweler \n",
"\n",
" len_Bert cluster \n",
"495 2 1 "
]
},
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Identities.loc[Identities.term=='jeweler',]"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" EEA EPA EAA\n",
"count 300.00 300.00 300.00\n",
"mean 1.19 2.61 -0.17\n",
"std 0.16 0.13 0.10\n",
"min 0.19 2.13 -0.43\n",
"25% 1.12 2.53 -0.24\n",
"50% 1.21 2.60 -0.18\n",
"75% 1.29 2.70 -0.12\n",
"max 1.58 2.98 0.31"
]
},
"execution_count": 59,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_output_new('judge','identity',I_b=n_I_v,B_b=n_B_v,M_b=n_M_v,batch_sz=300,batch_num=1\n",
" )[['EEA','EPA','EAA']].describe().round(decimals=2)"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" EE EP EA\n",
"count 300.00 300.00 300.00\n",
"mean 1.15 2.56 -0.24\n",
"std 0.15 0.15 0.12\n",
"min 0.61 1.93 -0.55\n",
"25% 1.06 2.48 -0.33\n",
"50% 1.15 2.56 -0.24\n",
"75% 1.26 2.66 -0.15\n",
"max 1.59 3.09 0.18"
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_output_agg('judge','identity',I_b=n_I_v,B_b=n_B_v,M_b=n_M_v,batch_sz=300,batch_num=1\n",
" ).describe().round(decimals=2)"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" EEB EPB EAB\n",
"count 300.00 300.00 300.00\n",
"mean -1.46 0.79 -0.06\n",
"std 0.18 0.07 0.10\n",
"min -2.33 0.57 -0.39\n",
"25% -1.57 0.74 -0.12\n",
"50% -1.47 0.80 -0.07\n",
"75% -1.34 0.84 -0.00\n",
"max -1.03 1.00 0.38"
]
},
"execution_count": 61,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_output_new('judges','behavior',I_b=n_I_v,B_b=n_B_v,M_b=n_M_v,batch_sz=300,batch_num=1\n",
" )[['EEB','EPB','EAB']].describe().round(decimals=2)"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" index_in_dic orig_term term E P A E2 P2 A2 \\\n",
"757 757 tell_off tells off -1.05 0.96 2.03 -1.05 0.96 2.03 \n",
"\n",
" term2 len_Bert cluster \n",
"757 tells off 2 4 "
]
},
"execution_count": 62,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Behaviors.loc[Behaviors.orig_term=='tell_off',]"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" EEB EPB EAB\n",
"count 300.00 300.00 300.00\n",
"mean 1.52 1.29 0.83\n",
"std 0.21 0.08 0.10\n",
"min 0.04 0.86 0.51\n",
"25% 1.40 1.25 0.77\n",
"50% 1.54 1.29 0.83\n",
"75% 1.65 1.34 0.90\n",
"max 2.08 1.50 1.11"
]
},
"execution_count": 63,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_output_new('calls','behavior',I_b=n_I_v,B_b=n_B_v,M_b=n_M_v,batch_sz=300,batch_num=1\n",
" )[['EEB','EPB','EAB']].describe().round(decimals=2)"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" EEB EPB EAB\n",
"count 300.00 300.00 300.00\n",
"mean 1.02 0.98 1.06\n",
"std 0.29 0.10 0.12\n",
"min -0.63 0.61 0.57\n",
"25% 0.90 0.92 0.98\n",
"50% 1.06 0.97 1.06\n",
"75% 1.19 1.05 1.15\n",
"max 1.65 1.30 1.40"
]
},
"execution_count": 64,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rnd_st=42\n",
"np.random.seed(rnd_st)\n",
"random.seed(rnd_st)\n",
"torch.manual_seed(rnd_st)\n",
"torch.cuda.manual_seed(rnd_st)\n",
"get_output_new('phones','behavior',I_b=n_I_v,B_b=n_B_v,M_b=n_M_v,batch_sz=300,batch_num=1\n",
" )[['EEB','EPB','EAB']].describe().round(decimals=2)"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" ModA Actor Behavior ModO \\\n",
"0 motivated Arab phones practical \n",
"1 doubtful divorcee phones cynical \n",
"2 nostalgic warden phones ambitious \n",
"3 practical best man phones clever \n",
"4 suspicious Army enlistee phones aimless \n",
".. ... ... ... ... \n",
"295 snooty craftsman phones gregarious \n",
"296 curious forest ranger phones discontented \n",
"297 skeptical patient phones able bodied \n",
"298 eager skilled worker phones intimidated \n",
"299 experienced blogger phones unimaginative \n",
"\n",
" Object EEB EPB EAB \n",
"0 divorce lawyer 0.94 1.03 1.15 \n",
"1 inmate 1.34 1.11 0.89 \n",
"2 tutor 1.02 0.89 1.11 \n",
"3 stepmother 1.03 0.93 1.16 \n",
"4 travel agent 0.99 0.98 0.96 \n",
".. ... ... ... ... \n",
"295 jury foreman 0.86 0.82 0.94 \n",
"296 retailer 1.34 0.98 1.02 \n",
"297 cheat 0.96 1.07 1.04 \n",
"298 Marine Corps enlistee 1.34 1.13 0.96 \n",
"299 Air Force officer 1.16 1.07 1.18 \n",
"\n",
"[300 rows x 8 columns]"
]
},
"execution_count": 65,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_output_new('phones','behavior',I_b=n_I_train,B_b=n_B_train,M_b=n_M_train,batch_sz=300,batch_num=1\n",
" )[[\n",
" 'ModA', 'Actor', 'Behavior', 'ModO', 'Object','EEB', 'EPB', 'EAB']].round(decimals=2)"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" ModA Actor Behavior ModO Object \\\n",
"0 stoned undertaker phones respectful dental hygienist \n",
"1 selfish dishwasher phones helpful capitalist \n",
"2 dissatisfied dope phones respectful guy \n",
"3 guilty co worker phones tormented waitress \n",
"4 tough collaborator phones straightforward pharmacist \n",
".. ... ... ... ... ... \n",
"295 neurotic racist phones idealistic flunky \n",
"296 good natured hotshot phones sad pharmacist \n",
"297 adventurous pharmacist phones quarrelsome crook \n",
"298 tormented steel worker phones trusting waitress \n",
"299 disagreeable sibling phones strange professional \n",
"\n",
" EEB EPB EAB \n",
"0 0.98 0.98 0.99 \n",
"1 0.82 1.01 1.02 \n",
"2 0.92 1.02 1.14 \n",
"3 1.41 1.15 1.04 \n",
"4 1.41 1.08 1.10 \n",
".. ... ... ... \n",
"295 1.06 0.93 0.97 \n",
"296 0.94 1.08 1.05 \n",
"297 1.12 0.93 0.94 \n",
"298 1.07 1.13 1.10 \n",
"299 1.15 0.94 1.10 \n",
"\n",
"[300 rows x 8 columns]"
]
},
"execution_count": 66,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_output_new('phones','behavior',batch_sz=300,batch_num=1)[[\n",
" 'ModA', 'Actor', 'Behavior', 'ModO', 'Object','EEB', 'EPB', 'EAB']].round(decimals=2)"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" EEB EPB EAB\n",
"count 300.00 300.00 300.00\n",
"mean 1.05 1.00 1.05\n",
"std 0.26 0.09 0.10\n",
"min -0.78 0.56 0.67\n",
"25% 0.94 0.94 0.99\n",
"50% 1.06 1.00 1.05\n",
"75% 1.20 1.06 1.11\n",
"max 1.66 1.26 1.43"
]
},
"execution_count": 67,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_output_new('phones','behavior',batch_sz=300,batch_num=1)[[\n",
" 'ModA', 'Actor', 'Behavior', 'ModO', 'Object','EEB', 'EPB', 'EAB']].describe().round(decimals=2)"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" EEB EPB EAB\n",
"count 300.00 300.00 300.00\n",
"mean -1.12 1.06 0.37\n",
"std 0.34 0.16 0.17\n",
"min -2.31 0.45 -0.03\n",
"25% -1.27 0.97 0.26\n",
"50% -1.13 1.08 0.35\n",
"75% -0.94 1.16 0.46\n",
"max 0.22 1.40 1.05"
]
},
"execution_count": 68,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_output_new('judge','behavior',I_b=n_I_v,B_b=n_B_v,M_b=n_M_v,batch_sz=300,batch_num=1\n",
" )[['EEB','EPB','EAB']].describe().round(decimals=2)"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" EEB EPB EAB\n",
"count 300.00 300.00 300.00\n",
"mean -1.97 1.09 2.59\n",
"std 0.30 0.13 0.21\n",
"min -2.67 0.54 1.28\n",
"25% -2.19 0.99 2.48\n",
"50% -1.99 1.12 2.63\n",
"75% -1.81 1.18 2.72\n",
"max -0.80 1.43 3.07"
]
},
"execution_count": 69,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_output_new(qq,'behavior',I_b=n_I_v,B_b=n_B_v,M_b=n_M_v,batch_sz=300,batch_num=1\n",
" )[['EEB','EPB','EAB']].describe().round(decimals=2)"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" EEB EPB EAB\n",
"count 300.00 300.00 300.00\n",
"mean -1.84 1.09 2.48\n",
"std 0.35 0.11 0.22\n",
"min -2.54 0.79 1.69\n",
"25% -2.08 1.02 2.34\n",
"50% -1.88 1.09 2.50\n",
"75% -1.60 1.16 2.64\n",
"max -0.80 1.40 2.99"
]
},
"execution_count": 70,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_output_new('fight','behavior',I_b=n_I_train,B_b=n_B_train,M_b=n_M_train,batch_sz=300,batch_num=1\n",
" )[['EEB','EPB','EAB']].describe().round(decimals=2)"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" EEB EPB EAB\n",
"count 300.00 300.00 300.00\n",
"mean -1.07 1.08 0.41\n",
"std 0.44 0.18 0.17\n",
"min -2.41 0.40 0.07\n",
"25% -1.28 1.00 0.30\n",
"50% -1.11 1.08 0.38\n",
"75% -0.89 1.17 0.49\n",
"max 3.21 2.72 1.27"
]
},
"execution_count": 71,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_output_new('judge','behavior',I_b=n_I_train,B_b=n_B_train,M_b=n_M_train,batch_sz=300,batch_num=1\n",
" )[['EEB','EPB','EAB']].describe().round(decimals=2)"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" EEA EPA EAA\n",
"count 300.00 300.00 300.00\n",
"mean 1.20 2.57 -0.12\n",
"std 0.12 0.16 0.10\n",
"min 0.82 2.09 -0.34\n",
"25% 1.14 2.47 -0.19\n",
"50% 1.21 2.57 -0.13\n",
"75% 1.29 2.68 -0.05\n",
"max 1.51 2.96 0.14"
]
},
"execution_count": 72,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_output_new('judge','identity',I_b=n_I_train,B_b=n_B_train,M_b=n_M_train,batch_sz=300,batch_num=1\n",
" )[['EEA','EPA','EAA']].describe().round(decimals=2)"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" EEB EPB EAB\n",
"EEB 1.0*** 0.365*** -0.55***\n",
"EPB 0.365*** 1.0*** 0.181**\n",
"EAB -0.55*** 0.181** 1.0***"
]
},
"execution_count": 73,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pval_cor(get_output_new('fight','behavior',I_b=n_I_train,B_b=n_B_train,M_b=n_M_train,batch_sz=300,batch_num=1\n",
" )[['EEB','EPB','EAB']])"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" EEB EPB EAB\n",
"count 300.00 300.00 300.00\n",
"mean -1.83 1.11 2.49\n",
"std 0.32 0.12 0.21\n",
"min -2.57 0.80 1.78\n",
"25% -2.05 1.02 2.35\n",
"50% -1.86 1.11 2.52\n",
"75% -1.63 1.18 2.64\n",
"max -0.65 1.40 2.92"
]
},
"execution_count": 74,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_output_new(qq,'behavior',I_b=n_I_train,B_b=n_B_train,M_b=n_M_train,batch_sz=300,batch_num=1\n",
" )[['EEB','EPB','EAB']].describe().round(decimals=2)"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" EEMA EPMA EAMA EEA EPA EAA EEB EPB EAB EEMO ... PM2 \\\n",
"0 -1.85 -1.32 0.21 0.72 1.40 1.62 -1.69 1.01 2.48 0.92 ... 0.33 \n",
"1 -1.83 -1.17 -0.98 1.70 2.53 2.61 -1.85 1.03 2.27 -1.67 ... -1.10 \n",
"2 2.05 1.46 -0.95 2.45 2.21 -0.12 -2.30 1.06 2.79 -1.19 ... 1.11 \n",
"3 -2.37 -1.53 -2.14 0.53 0.79 0.52 -1.44 1.07 2.37 -1.87 ... -1.25 \n",
"4 -1.59 -1.33 -1.59 2.43 1.61 -0.86 -2.31 0.99 2.43 -1.89 ... -0.75 \n",
".. ... ... ... ... ... ... ... ... ... ... ... ... \n",
"295 -0.54 0.23 0.17 -1.41 -2.90 -1.35 -1.50 0.87 2.33 -1.98 ... -0.60 \n",
"296 3.18 2.24 -0.68 -1.30 -1.31 0.38 -1.69 1.19 2.36 3.03 ... 3.22 \n",
"297 1.55 2.12 1.10 0.44 -0.08 1.16 -1.83 1.04 2.37 -2.27 ... 0.62 \n",
"298 -2.51 -0.26 1.01 1.58 1.55 1.45 -1.65 1.24 2.39 2.30 ... 1.78 \n",
"299 0.10 0.72 0.01 1.94 0.60 1.21 -2.00 1.24 2.93 2.19 ... 2.23 \n",
"\n",
" AM2 EO PO AO ModA Actor Behavior \\\n",
"0 -2.28 1.33 1.65 1.92 scared competitor fight \n",
"1 0.48 0.77 -1.02 0.76 uneasy athlete fight \n",
"2 1.96 1.25 0.18 0.48 poised academic fight \n",
"3 -2.42 -1.24 -0.19 2.04 low stockholder fight \n",
"4 1.94 2.39 1.66 -0.31 slack grandfather fight \n",
".. ... ... ... ... ... ... ... \n",
"295 0.97 0.76 0.14 -0.66 mischievous hostage fight \n",
"296 1.51 -1.75 -1.63 0.35 thoughtful poser fight \n",
"297 1.39 -1.88 -0.89 -0.77 masculine salesclerk fight \n",
"298 -0.64 -1.61 -0.70 -1.26 conceited brother fight \n",
"299 0.04 0.91 2.29 1.30 obstinant granddaughter fight \n",
"\n",
" ModO Object \n",
"0 reserved restaurant operator \n",
"1 impractical assembly line worker \n",
"2 bossy chap \n",
"3 cheerless smart aleck \n",
"4 manic tutor \n",
".. ... ... \n",
"295 terrified Hindu \n",
"296 motivated inmate \n",
"297 hostile pessimist \n",
"298 elegant miser \n",
"299 esteemed policeman \n",
"\n",
"[300 rows x 35 columns]"
]
},
"execution_count": 75,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_output_new(qq,'behavior',I_b=n_I_train,B_b=n_B_train,M_b=n_M_train,batch_sz=300,batch_num=1\n",
" ).round(decimals=2)"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[]"
]
},
"execution_count": 76,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lsts=['kill','negotiate','play','surf',]\n",
"df_f=get_output_new('kill','behavior')\n",
"[x for x in lsts if x in list(Behaviors.term) ]\n"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [],
"source": [
"df_f=get_output_new('doctor','identity')"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" EEMA EPMA EAMA EEA EPA EAA EEB EPB EAB EEMO EPMO EAMO \\\n",
"0 -2.28 -0.7 0.29 -0.84 -0.39 0.56 -1.41 1.62 0.42 1.32 0.86 1.35 \n",
"\n",
" EEO EPO EAO \n",
"0 0.57 -0.53 -0.47 "
]
},
"execution_count": 78,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"EPA_sents('upset chatbot subdue euphoric plumber')"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" EEMA EPMA EAMA EEA EPA EAA EEB EPB EAB \\\n",
"count 300.00 300.00 300.00 300.00 300.00 300.00 300.00 300.00 300.00 \n",
"mean 0.11 0.45 0.12 2.77 3.03 0.44 -0.28 1.20 0.35 \n",
"std 1.94 1.30 0.95 0.17 0.14 0.09 1.76 0.68 0.70 \n",
"min -3.34 -1.93 -2.35 1.76 2.42 0.20 -3.89 -0.30 -1.35 \n",
"25% -1.67 -0.68 -0.46 2.69 2.96 0.37 -1.61 0.74 -0.17 \n",
"50% -0.38 0.27 0.04 2.78 3.04 0.44 -0.77 1.17 0.36 \n",
"75% 2.11 1.53 0.70 2.87 3.13 0.49 1.26 1.69 0.87 \n",
"max 3.47 3.01 2.62 3.22 3.36 0.70 3.09 2.58 2.39 \n",
"\n",
" EEMO ... AA EB PB AB EM2 PM2 AM2 \\\n",
"count 300.00 ... 300.00 300.00 300.00 300.00 300.00 300.00 300.00 \n",
"mean -0.10 ... 10.07 -0.37 1.25 0.37 -0.08 0.37 -0.07 \n",
"std 1.93 ... 0.00 2.05 0.88 0.89 2.26 1.40 1.26 \n",
"min -3.26 ... 10.07 -4.18 -0.73 -2.50 -3.76 -1.93 -2.76 \n",
"25% -1.83 ... 10.07 -1.80 0.77 -0.32 -2.10 -0.79 -0.90 \n",
"50% -0.79 ... 10.07 -0.95 1.24 0.46 -0.80 0.05 -0.04 \n",
"75% 1.84 ... 10.07 1.49 1.96 0.98 2.35 1.70 0.66 \n",
"max 3.35 ... 10.07 3.71 3.63 2.47 3.40 3.20 2.65 \n",
"\n",
" EO PO AO \n",
"count 300.00 300.00 300.00 \n",
"mean 0.36 0.46 0.49 \n",
"std 1.67 1.26 0.90 \n",
"min -3.80 -2.57 -1.49 \n",
"25% -0.86 -0.35 -0.28 \n",
"50% 0.83 0.66 0.52 \n",
"75% 1.68 1.59 1.15 \n",
"max 2.73 2.81 2.32 \n",
"\n",
"[8 rows x 30 columns]"
]
},
"execution_count": 79,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df_f.describe().round(decimals=2)"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" EEA EPA EAA\n",
"count 300.000000 300.000000 300.000000\n",
"mean 0.488282 -1.367861 0.824750\n",
"std 0.117244 0.127288 0.102284\n",
"min 0.157316 -1.760292 0.488504\n",
"25% 0.412799 -1.448211 0.763780\n",
"50% 0.491181 -1.372683 0.825545\n",
"75% 0.565524 -1.282036 0.891511\n",
"max 0.800528 -0.966103 1.156149"
]
},
"execution_count": 80,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_output_new('27 year old','identity')[['EEA','EPA','EAA']].describe()"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" EEMA EPMA EAMA\n",
"count 300.000000 300.000000 300.000000\n",
"mean 2.897354 2.168899 0.860349\n",
"std 0.187449 0.138975 0.182203\n",
"min 2.403197 1.781457 0.046930\n",
"25% 2.764331 2.082838 0.738104\n",
"50% 2.909754 2.170886 0.881962\n",
"75% 3.023832 2.262232 0.991849\n",
"max 3.424148 2.558727 1.394213"
]
},
"execution_count": 81,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_output_new('active','modifier')[['EEMA','EPMA','EAMA']].describe()"
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" EEMA EPMA EAMA\n",
"count 300.000000 300.000000 300.000000\n",
"mean -3.045832 -0.668253 -1.153561\n",
"std 0.182821 0.122327 0.156787\n",
"min -3.584314 -1.047460 -1.518406\n",
"25% -3.178315 -0.739255 -1.269563\n",
"50% -3.055030 -0.655889 -1.164065\n",
"75% -2.946763 -0.592226 -1.057461\n",
"max -2.377005 -0.332306 -0.602842"
]
},
"execution_count": 82,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_output_new('abandoned','modifier')[['EEMA','EPMA','EAMA']].describe()"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" EEMA EPMA EAMA EEA EPA EAA \\\n",
"count 300.000000 300.000000 300.000000 300.000000 300.000000 300.000000 \n",
"mean 0.062780 0.592260 0.221773 0.346075 0.582873 0.598837 \n",
"std 2.350052 1.507310 1.291631 1.639276 1.388081 0.941282 \n",
"min -3.790748 -2.423410 -2.970218 -4.400205 -2.740134 -2.016754 \n",
"25% -2.034492 -0.688835 -0.666306 -0.907877 -0.452238 0.005147 \n",
"50% -0.722680 0.412652 0.283587 0.741714 0.633256 0.575360 \n",
"75% 2.610461 1.891361 1.217855 1.536643 1.621428 1.291292 \n",
"max 4.432653 3.920668 2.868834 3.754321 3.444081 3.178367 \n",
"\n",
" EEB EPB EAB EEMO ... AA \\\n",
"count 300.000000 300.000000 300.000000 300.000000 ... 300.000000 \n",
"mean 0.371312 1.447973 0.913698 -0.191019 ... 0.530200 \n",
"std 0.269116 0.167522 0.121970 2.254651 ... 0.957094 \n",
"min -0.646917 0.959772 0.463262 -3.457727 ... -2.290000 \n",
"25% 0.212077 1.336761 0.832745 -2.095801 ... -0.050000 \n",
"50% 0.378328 1.430377 0.910863 -0.993144 ... 0.520000 \n",
"75% 0.546459 1.563355 1.000382 2.175765 ... 1.190000 \n",
"max 1.262929 1.955977 1.287469 3.936452 ... 3.080000 \n",
"\n",
" EB PB AB EM2 PM2 \\\n",
"count 300.000000 300.000000 3.000000e+02 300.000000 300.000000 \n",
"mean 19.626388 9.704184 9.984769e+00 -0.134833 0.417400 \n",
"std 0.000000 0.000000 1.779325e-15 2.181123 1.532932 \n",
"min 19.626388 9.704184 9.984769e+00 -3.690000 -2.830000 \n",
"25% 19.626388 9.704184 9.984769e+00 -2.010000 -0.850000 \n",
"50% 19.626388 9.704184 9.984769e+00 -0.985000 0.235000 \n",
"75% 19.626388 9.704184 9.984769e+00 2.060000 1.995000 \n",
"max 19.626388 9.704184 9.984769e+00 3.560000 3.530000 \n",
"\n",
" AM2 EO PO AO \n",
"count 300.000000 300.000000 300.000000 300.000000 \n",
"mean 0.008433 0.488467 0.596233 0.413967 \n",
"std 1.316188 1.594445 1.280634 0.910204 \n",
"min -2.630000 -3.870000 -2.850000 -2.240000 \n",
"25% -0.932500 -0.392500 -0.290000 -0.140000 \n",
"50% -0.010000 0.835000 0.740000 0.400000 \n",
"75% 0.882500 1.640000 1.545000 1.020000 \n",
"max 2.670000 3.380000 3.700000 2.920000 \n",
"\n",
"[8 rows x 30 columns]"
]
},
"execution_count": 83,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_output_new('runs','behavior',batch_sz=300,I_b=n_I_train,B_b=n_B_train,M_b=n_M_train).describe()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [NOTE] If `component` is `'other'`, the function should return an empty result. This file has been revised accordingly."
]
},
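{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal, self-contained sketch of the intended filtering (hypothetical toy data; the `component` labels follow the convention used in `allidentbehmod.csv`):\n",
"\n",
"```python\n",
"import pandas as pd\n",
"\n",
"toy = pd.DataFrame({'term': ['doctor', 'run', 'happy', 'xyz'],\n",
"                    'component': ['identity', 'behavior', 'modifier', 'other']})\n",
"# Rows whose component is 'other' are excluded before EPA estimation,\n",
"# so an all-'other' input would yield an empty frame.\n",
"kept = toy.loc[toy.component != 'other']\n",
"print(kept.term.tolist())  # ['doctor', 'run', 'happy']\n",
"```"
]
},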
{
"cell_type": "code",
"execution_count": 100,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Process started at 2022-10-31 08:00:14.898664\n",
"Process ended at 2022-10-31 10:06:17.497273\n"
]
}
],
"source": [
"print('Process started at',datetime.now() )\n",
"df=pd.read_csv('allidentbehmod.csv')\n",
"def epa_file(dff):\n",
"    # Identities: estimate per-term EPA means and standard deviations from the EEA/EPA/EAA columns.\n",
"    df=dff.loc[dff.component=='identity',].copy()\n",
"    df['term']=df['term'].apply(lambda x: x.replace('_', ' '))\n",
"    df['E'],df['P'],df['A'],df['E_std'],df['P_std'],df['A_std']=np.vstack(df.apply(lambda x: get_output_new(x.term,x.component,batch_sz=300,I_b=n_I_train,B_b=n_B_train,M_b=n_M_train)[['EEA','EPA','EAA']].agg(['mean','std']).values.reshape(-1).round(decimals=2),axis=1).values).transpose()\n",
"    # Behaviors: same procedure, reading the EEB/EPB/EAB columns.\n",
"    df_b=dff.loc[dff.component=='behavior',].copy()\n",
"    df_b['term']=df_b['term'].apply(lambda x: x.replace('_', ' '))\n",
"    df_b['E'],df_b['P'],df_b['A'],df_b['E_std'],df_b['P_std'],df_b['A_std']=np.vstack(df_b.apply(lambda x: get_output_new(x.term,x.component,batch_sz=300,I_b=n_I_train,B_b=n_B_train,M_b=n_M_train)[['EEB','EPB','EAB']].agg(['mean','std']).values.reshape(-1).round(decimals=2),axis=1).values).transpose()\n",
"    # Modifiers: same procedure, reading the EEMA/EPMA/EAMA columns.\n",
"    df_m=dff.loc[dff.component=='modifier',].copy()\n",
"    df_m['term']=df_m['term'].apply(lambda x: x.replace('_', ' '))\n",
"    df_m['E'],df_m['P'],df_m['A'],df_m['E_std'],df_m['P_std'],df_m['A_std']=np.vstack(df_m.apply(lambda x: get_output_new(x.term,x.component,batch_sz=300,I_b=n_I_train,B_b=n_B_train,M_b=n_M_train)[['EEMA','EPMA','EAMA']].agg(['mean','std']).values.reshape(-1).round(decimals=2),axis=1).values).transpose()\n",
"\n",
"    return pd.concat([df,df_b,df_m],axis=0)\n",
"df_file=epa_file(df)\n",
"print('Process ended at',datetime.now() )\n",
"\n",
"df_file.to_csv('allidentbehmod_val_10_30_22.csv')"
]
},
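{
"cell_type": "markdown",
"metadata": {},
"source": [
"The column-assignment pattern in `epa_file` can be sketched in isolation (the toy `score` function below stands in for `get_output_new`; names and shapes are illustrative only):\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"def score(term):\n",
"    # Stand-in for get_output_new: 300 sampled EPA estimates per term.\n",
"    rng = np.random.default_rng(0)\n",
"    return pd.DataFrame(rng.normal(size=(300, 3)), columns=['EEA', 'EPA', 'EAA'])\n",
"\n",
"d = pd.DataFrame({'term': ['doctor', 'nurse']})\n",
"# agg(['mean','std']) yields a 2x3 frame; reshape(-1) flattens it to\n",
"# [E_mean, P_mean, A_mean, E_std, P_std, A_std]; vstack + transpose then\n",
"# turns the per-term rows into six column vectors.\n",
"d['E'], d['P'], d['A'], d['E_std'], d['P_std'], d['A_std'] = np.vstack(\n",
"    d.apply(lambda x: score(x.term)[['EEA', 'EPA', 'EAA']]\n",
"            .agg(['mean', 'std']).values.reshape(-1), axis=1).values).transpose()\n",
"```"
]
},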
{
"cell_type": "code",
"execution_count": 212,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Process started at 2022-10-31 16:05:10.060873\n",
"Process ended at 2022-10-31 18:14:08.643715\n"
]
}
],
"source": [
"print('Process started at',datetime.now() )\n",
"df=pd.read_csv('allidentbehmod.csv').dropna()\n",
"def epa_file2(dff):\n",
"    # Estimate aggregate EPA means and standard deviations for every non-'other' term.\n",
"    df=dff.loc[dff.component!='other',].copy()\n",
"    df['term']=df['term'].apply(lambda x: x.replace('_', ' '))\n",
"    df['E'],df['P'],df['A'],df['E_std'],df['P_std'],df['A_std']=np.vstack(df.apply(lambda x: get_output_agg(x.term,x.component,batch_sz=300,I_b=n_I_train,B_b=n_B_train,M_b=n_M_train)[['EE', 'EP', 'EA']].agg(['mean','std']).values.reshape(-1).round(decimals=2),axis=1).values).transpose()\n",
"\n",
"    return df\n",
"df_file2=epa_file2(df)\n",
"print('Process ended at',datetime.now() )\n",
"\n",
"df_file2.to_csv('allidentbehmod_v2_10_30_22.csv')"
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Process started at 2022-11-02 11:18:36.739996\n",
"Process ended at 2022-11-02 13:26:08.280003\n"
]
}
],
"source": [
"print('Process started at',datetime.now() )\n",
"df=pd.read_csv('allidentbehmod2.csv').dropna()\n",
"def epa_file2(dff):\n",
"    # Estimate aggregate EPA means and standard deviations for every non-'other' term.\n",
"    df=dff.loc[dff.component!='other',].copy()\n",
"    df['term']=df['term'].apply(lambda x: x.replace('_', ' '))\n",
"    df['E'],df['P'],df['A'],df['E_std'],df['P_std'],df['A_std']=np.vstack(df.apply(lambda x: get_output_agg(x.term,x.component,batch_sz=300,I_b=n_I_train,B_b=n_B_train,M_b=n_M_train)[['EE', 'EP', 'EA']].agg(['mean','std']).values.reshape(-1).round(decimals=2),axis=1).values).transpose()\n",
"\n",
"    return df\n",
"df_file2=epa_file2(df)\n",
"print('Process ended at',datetime.now() )\n",
"\n",
"# df_file2.to_csv('allidentbehmod_v2_11_02_22.csv')\n",
"\n",
"df_org=pd.read_csv('allidentbehmod_org.csv').dropna()\n",
"pd.merge(df_org.rename(columns={'term':'Original_term'}),\n",
" df_file2[['Unnamed: 0', 'term', 'E', 'P', 'A', 'E_std', 'P_std','A_std']],\n",
" right_on = ['Unnamed: 0'], left_on = ['Unnamed: 0'],\n",
" how='outer').rename(columns={'Unnamed: 0': 'index'})[['index', 'Original_term','term', 'component', 'n', 'E', 'P', 'A',\n",
" 'E_std', 'P_std', 'A_std']].to_csv('allidentbehmod_v2_11_02_22.csv')"
]
},
{
"cell_type": "code",
"execution_count": 98,
"metadata": {},
"outputs": [],
"source": [
"pd.merge(df_org.rename(columns={'term':'Original_term'}),\n",
" df_file2[['Unnamed: 0', 'term', 'E', 'P', 'A', 'E_std', 'P_std','A_std']],\n",
" right_on = ['Unnamed: 0'], left_on = ['Unnamed: 0'],\n",
" how='outer').rename(columns={'Unnamed: 0': 'index'})[['index', 'Original_term','term', 'component', 'n', 'E', 'P', 'A',\n",
" 'E_std', 'P_std', 'A_std']].to_csv('allidentbehmod_v2_11_02_22.csv')"
]
},
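{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `how='outer'` merge keeps rows from both sides, which is likely why the merged frame is slightly longer than `df_file2`: rows present only in `allidentbehmod_org.csv` survive with missing E/P/A values. A minimal sketch with toy frames (hypothetical values):\n",
"\n",
"```python\n",
"import pandas as pd\n",
"\n",
"left = pd.DataFrame({'index': [0, 1, 2], 'Original_term': ['a', 'b', 'c']})\n",
"right = pd.DataFrame({'index': [0, 1], 'E': [0.5, -0.2]})\n",
"merged = pd.merge(left, right, on='index', how='outer')\n",
"print(len(merged))  # 3 -- the unmatched row 'c' is kept, with NaN for E\n",
"```"
]
},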
{
"cell_type": "code",
"execution_count": 97,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"4510"
]
},
"execution_count": 97,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(pd.merge(df_org.rename(columns={'term':'Original_term'}),\n",
" df_file2[['Unnamed: 0', 'term', 'E', 'P', 'A', 'E_std', 'P_std','A_std']],\n",
" right_on = ['Unnamed: 0'], left_on = ['Unnamed: 0'],\n",
" how='outer').rename(columns={'Unnamed: 0': 'index'})[['index', 'Original_term','term', 'component', 'n', 'E', 'P', 'A',\n",
" 'E_std', 'P_std', 'A_std']])"
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"4503"
]
},
"execution_count": 92,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(df)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df_file2[['Unnamed: 0', 'term', 'component', 'n', 'E', 'P', 'A', 'E_std', 'P_std','A_std']]"
]
},
{
"cell_type": "code",
"execution_count": 89,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Process started at 2022-10-31 00:12:27.334780\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\mosta\\AppData\\Local\\Temp\\ipykernel_8232\\2045016687.py:6: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" df['term']=df.term.apply(lambda x: x.replace('_', ' '))\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Process ended at 2022-10-31 00:52:27.760690\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\mosta\\AppData\\Local\\Temp\\ipykernel_8232\\2045016687.py:7: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" df['E'],df['P'],df['A'],df['E_std'],df['P_std'],df['A_std']=np.vstack(df.apply(lambda x: get_output_new(x.term,x.component,batch_sz=300,I_b=n_I_train,B_b=n_B_train,M_b=n_M_train)[['EEA','EPA','EAA']].agg(['mean','std']).values.reshape(-1).round(decimals=2),axis=1).values).transpose()\n"
]
}
],
"source": [
"# print('Process started at',datetime.now() )\n",
"# df=pd.read_csv('allidentbehmod.csv')\n",
"# def epa_file(df):\n",
"# df=df.loc[df.component=='behavior',]\n",
" \n",
"# df['term']=df.term.apply(lambda x: x.replace('_', ' '))\n",
"# df['E'],df['P'],df['A'],df['E_std'],df['P_std'],df['A_std']=np.vstack(df.apply(lambda x: get_output_new(x.term,x.component,batch_sz=300,I_b=n_I_train,B_b=n_B_train,M_b=n_M_train)[['EEA','EPA','EAA']].agg(['mean','std']).values.reshape(-1).round(decimals=2),axis=1).values).transpose()\n",
"# return(df)\n",
"# df_file=epa_file(df) #started 17:42\n",
"# print('Process ended at',datetime.now() )\n",
"\n",
"# df_file.to_csv('allidentbehmod_val_10_30_22_beh.csv')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## TODO\n",
"\n",
"In addition to the minimum and maximum distances for EPA estimates, we could compute the distance to the centroid and use it as a selection criterion."
]
},
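{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to realize the centroid-distance idea above (a sketch; `pts` stands in for the 300 sampled EPA vectors of a term):\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"pts = rng.normal(size=(300, 3))  # sampled E, P, A values for one term\n",
"\n",
"centroid = pts.mean(axis=0)\n",
"# Mean distance of the samples to their centroid: a dispersion measure\n",
"# complementing the min/max distances mentioned in the note above.\n",
"centroid_dist = np.linalg.norm(pts - centroid, axis=1).mean()\n",
"```"
]
},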
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df_f.round(decimals=2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Experiment design for Inverse Impression Change Equations (IICE)\n",
"I can use any word to find inverse impression change equations for words with given EPA values. For example, if 'mother' is in the identity dictionary, the remaining terms in the MABMO grammar can be selected from outside the dictionary.\n",
"\n",
"We may want to see how neural networks with a BERT-like pipeline perform here. We may even use different approaches for feature selection (by features, I mean the words in the MABMO grammar). In other words, this differs from the previous problem in that the training set is potentially unlimited, and our goal is to find something equivalent to an experiment design for IICE. We would then use these words in a survey to estimate new impression change equations.\n",
"\n",
"I should look into the problems people have reported with the applicability of impression change equations (ICE)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df_f.loc[df_f.EEA>3,].round(decimals=2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df_f.loc[df_f.EEA<1.87,].round(decimals=2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"EPA_sents('happy doctor help wonderful mother')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"EPA_sents('sadist doctor kill worst abuser')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Should I use BERT output for inverse impression change? Or at least the tuned layers?"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df_f[['EEA','EPA','EAA']].describe().round(decimals=2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Compare with word2vec"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from gensim.models.keyedvectors import KeyedVectors\n",
"from pylab import *\n",
"# !curl -O \"https://raw.githubusercontent.com/vitorcurtis/RWinOut/master/RWinOut.py\"\n",
"# %load_ext RWinOut\n",
"%load_ext rpy2.ipython"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df_train=pd.concat([I_train,B_train,M_train],axis=0)\n",
"df_val_tst=pd.concat([I_v,B_v,M_v,I_test,B_test,M_test],axis=0)\n",
"df_train[['term2', 'E', 'P', 'A']].rename(columns={'term2': 'term'}, inplace=False).to_csv('df_train.csv')\n",
"df_val_tst[['term2', 'E', 'P', 'A']].rename(columns={'term2': 'term'}, inplace=False).to_csv('df_val_tst.csv')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%R \n",
"library(\"ggplot2\")\n",
"library('dplyr') #\n",
"library(tidyverse)\n",
"\n",
"df_train <- read.csv('df_train.csv', stringsAsFactors=FALSE)\n",
"df_val_tst <- read.csv('df_val_tst.csv', stringsAsFactors=FALSE)\n",
"# Combined_Surveyor_2015_behaviors=read.csv(\"http://affectcontroltheory.org///wp-content/uploads/2019/10/FullSurveyorInteract_Behaviors.csv\",stringsAsFactor=F)\n",
"# Combined_Surveyor_2015_all=rbind(Combined_Surveyor_2015_behaviors,Combined_Surveyor_2015_modifier,Combined_Surveyor_2015)\n",
"# s <-capture.output(summary(Combined_Surveyor_2015_all))\n",
"df<-read.csv(file=\"D:/ACT/word embeding/GoogleNews_Embedding.csv\", header=TRUE,row.names=1, sep=\",\")\n",
"df<-as.matrix(data.frame(df))\n",
"\n",
"nrm <- function(x) x/(sqrt(sum(x^2)))\n",
"df<-t(apply(df,1,nrm))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%R \n",
"\n",
"length(train_list)\n",
"# length(df_train$term)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%R \n",
"# ACT_data_test=Combined_Surveyor_2015_all\n",
"# set.seed(500)\n",
"train_list=df_train$term\n",
"train_list=train_list[train_list %in% row.names(df)]\n",
"\n",
"val_tst_list=df_val_tst$term\n",
"val_tst_list=val_tst_list[val_tst_list %in% row.names(df)]\n",
"\n",
"\n",
"# new_df=ACT_data_test[ACT_data_test$term %in% term_list,]\n",
"# smp_size <- floor(0.85 * nrow(new_df))\n",
"# set.seed(123)\n",
"# train_ind <- sample(seq_len(nrow(new_df)), size = smp_size)\n",
"train2 <- filter(df_train, term %in% train_list)#new_df[train_ind, ]\n",
"# test2 <- new_df[-train_ind, ]\n",
"test2 <- filter(df_val_tst, term %in% val_tst_list) \n",
"\n",
"## ***************************** Find the mapping **************\n",
"R=((solve((t(df[train2$term,]))%*%(df[train2$term,])))%*%(t(df[train2$term,])))%*%as.matrix(train2[,c('E','P','A')])\n",
"test2[,c('ER','PR','AR')]=df[test2$term,]%*%R\n",
"train2[,c('ER','PR','AR')]=df[train2$term,]%*%R\n",
"\n",
"## ***************************** Regression on the mapping **************\n",
"model_E <- step(glm(E ~(ER+PR+AR)^2, data = train2), trace=0) \n",
"test2$pred_E <- model_E %>% predict(test2)\n",
"train2$pred_E <- model_E %>% predict(train2)\n",
"model_P <- step(glm(P ~(ER+PR+AR)^2, data = train2), trace=0)\n",
"test2$pred_P <- model_P %>% predict(test2)\n",
"train2$pred_P <- model_P %>% predict(train2)\n",
"model <- step(glm(A ~(ER+PR+AR)^2, data = train2), trace=0)\n",
"test2$pred_A <- model %>% predict(test2)\n",
"train2$pred_A <- model %>% predict(train2)\n",
"\n",
"write.csv(train2,'R_train_out.csv', row.names = FALSE)\n",
"write.csv(test2,'R_tst_val_out.csv', row.names = FALSE)\n",
" \n",
"s <-capture.output(cor(test2[,c(\"E\",\"P\",'A','pred_E','pred_P','pred_A')])[4:6,1:3])\n",
"s"
]
},
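{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `%%R` cell above fits the embedding-to-EPA mapping by ordinary least squares. The next cell is a self-contained numpy sketch of the same idea on random toy data; the shapes and names are illustrative only, not the notebook's actual matrices."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"X = rng.normal(size=(50, 8))  # toy term embeddings\n",
"Y = rng.normal(size=(50, 3))  # toy E, P, A ratings\n",
"# least-squares mapping from embedding space to EPA space\n",
"R, *_ = np.linalg.lstsq(X, Y, rcond=None)\n",
"pred = X @ R  # mapped EPA estimates"
]
},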
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"R_train_out = pd.read_csv(\"R_train_out.csv\")[['term', 'E', 'P', 'A', 'pred_E', 'pred_P','pred_A']]\n",
"R_tst_val_out = pd.read_csv(\"R_tst_val_out.csv\")[['term', 'E', 'P', 'A','pred_E', 'pred_P','pred_A']]\n",
"R_tst_val_out.corr().iloc[1:,3:].round(decimals=2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"R_train_out.corr().iloc[:,3:].round(decimals=2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"w2v_test_i=R_tst_val_out[R_tst_val_out.term.isin(I_test.term)].rename(columns={\n",
" 'pred_E': 'EEI','pred_P': 'EPI','pred_A': 'EAI'}, inplace=False)\n",
"w2v_test_b=R_tst_val_out[R_tst_val_out.term.isin(B_test.term)].rename(columns={\n",
" 'pred_E': 'EEB','pred_P': 'EPB','pred_A': 'EAB'}, inplace=False)\n",
"w2v_test_m=R_tst_val_out[R_tst_val_out.term.isin(M_test.term)].rename(columns={\n",
" 'pred_E': 'EEM','pred_P': 'EPM','pred_A': 'EAM'}, inplace=False)\n",
"w2v_v_i=R_tst_val_out[R_tst_val_out.term.isin(I_v.term)].rename(columns={\n",
" 'pred_E': 'EEI','pred_P': 'EPI','pred_A': 'EAI'}, inplace=False)\n",
"w2v_v_b=R_tst_val_out[R_tst_val_out.term.isin(B_v.term)].rename(columns={\n",
" 'pred_E': 'EEB','pred_P': 'EPB','pred_A': 'EAB'}, inplace=False)\n",
"w2v_v_m=R_tst_val_out[R_tst_val_out.term.isin(M_v.term)].rename(columns={\n",
" 'pred_E': 'EEM','pred_P': 'EPM','pred_A': 'EAM'}, inplace=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def w2v_cor(dfi=w2v_v_i,dfb=w2v_v_b,dfm=w2v_v_m):\n",
" return(pd.concat([dfi.corr().iloc[:3,3:],dfb.corr().iloc[:3,3:],dfm.corr().iloc[:3,3:]],axis=1).round(decimals=2))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"np.absolute(w2v_v_i.E-w2v_v_i.EEI)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def MAE_w2v(dfi=w2v_v_i,dfb=w2v_v_b,dfm=w2v_v_m):\n",
" dfMAE=pd.concat([np.absolute(dfi.E-dfi.EEI),np.absolute(dfi.P-dfi.EPI),np.absolute(dfi.A-dfi.EAI), \n",
" np.absolute(dfb.E-dfb.EEB),np.absolute(dfb.P-dfb.EPB),np.absolute(dfb.A-dfb.EAB),\n",
" np.absolute(dfm.E-dfm.EEM),np.absolute(dfm.P-dfm.EPM),np.absolute(dfm.A-dfm.EAM)],axis=1).set_axis(\n",
" ['EEI','EPI', 'EAI', 'EEB', 'EPB', 'EAB','EEM', 'EPM', 'EAM'], axis=1, inplace=False) \n",
" return(dfMAE.describe().round(decimals=2))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def RMSE_w2v(dfi=w2v_v_i,dfb=w2v_v_b,dfm=w2v_v_m):\n",
" dfMAE=pd.DataFrame([np.sqrt(np.mean((dfi.E-dfi.EEI)**2,axis=0)),\n",
" np.sqrt(np.mean((dfi.P-dfi.EPI)**2,axis=0)),\n",
" np.sqrt(np.mean((dfi.A-dfi.EAI)**2,axis=0)), \n",
" np.sqrt(np.mean((dfb.E-dfb.EEB)**2,axis=0)),\n",
" np.sqrt(np.mean((dfb.P-dfb.EPB)**2,axis=0)),\n",
" np.sqrt(np.mean((dfb.A-dfb.EAB)**2,axis=0)),\n",
" np.sqrt(np.mean((dfm.E-dfm.EEM)**2,axis=0)),\n",
" np.sqrt(np.mean((dfm.P-dfm.EPM)**2,axis=0)),\n",
" np.sqrt(np.mean((dfm.A-dfm.EAM)**2,axis=0))]).set_axis(\n",
" ['EEI','EPI', 'EAI', 'EEB', 'EPB', 'EAB','EEM', 'EPM', 'EAM'], axis=0).set_axis(\n",
" ['RMSE'], axis=1) \n",
" return(dfMAE.round(decimals=2))"
]
},
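{
"cell_type": "markdown",
"metadata": {},
"source": [
"`MAE_w2v` and `RMSE_w2v` reduce to the two standard error metrics; the next cell sketches them on a single toy column pair (made-up values, mirroring the `E` vs. `EEI` columns above)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# toy observed (E) vs. predicted (EEI) Evaluation ratings\n",
"df = pd.DataFrame({'E': [1.0, -0.5, 2.0], 'EEI': [0.8, -0.2, 1.4]})\n",
"mae = np.mean(np.abs(df.E - df.EEI))  # mean absolute error\n",
"rmse = np.sqrt(np.mean((df.E - df.EEI) ** 2))  # root mean squared error"
]
},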
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"w2v_cor()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"w2v_cor(w2v_test_i,w2v_test_b,w2v_test_m)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"MAE_w2v()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"MAE_w2v(w2v_test_i,w2v_test_b,w2v_test_m)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"RMSE_w2v()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"RMSE_w2v(w2v_test_i,w2v_test_b,w2v_test_m)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Use the epa estimates as features for predicting emotions in a chatbot application"
]
},
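{
"cell_type": "markdown",
"metadata": {},
"source": [
"As one hedged illustration of that idea: treat each sentence's EPA triple as a three-dimensional feature vector and classify affect with a simple rule such as nearest centroid. The labels and values in the next cell are hypothetical, not drawn from any dataset in this notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# hypothetical EPA feature vectors and affect labels\n",
"X = np.array([[2.0, 1.0, 0.5], [1.8, 0.9, 0.6], [-2.0, -1.0, -0.5], [-1.8, -0.8, -0.4]])\n",
"y = np.array([1, 1, 0, 0])  # 1 = positive affect, 0 = negative affect\n",
"\n",
"# nearest-centroid classifier in EPA space\n",
"centroids = np.stack([X[y == c].mean(axis=0) for c in (0, 1)])\n",
"\n",
"def predict(epa):\n",
"    return int(np.argmin(np.linalg.norm(centroids - epa, axis=1)))"
]
},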
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[0.7587553858757019,\n",
" 0.948117733001709,\n",
" 0.6421530246734619,\n",
" 0.7107032537460327,\n",
" 0.34743067622184753,\n",
" 0.3779359757900238,\n",
" 1.3710752725601196,\n",
" -0.14437827467918396,\n",
" 0.23478737473487854,\n",
" 0.7699854969978333,\n",
" 0.8394991755485535,\n",
" 0.5351914167404175,\n",
" 0.04319348186254501,\n",
" -0.7777518630027771,\n",
" -0.14642126858234406]"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def EPA_sents_lst(sent):\n",
" return(bert_predict(bert_regressor.to(device), sent_ldr(sent)).cpu().tolist()[0])\n",
"\n",
"EPA_sents_lst('hey, how are you?')"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Process started at 2023-01-18 08:41:41.557185\n",
"Process ended at 2023-01-18 09:36:29.620465\n"
]
}
],
"source": [
"# def EPA_sents_lst(sent):\n",
"# q=sent_ldr(sent)\n",
"# predictions=bert_predict(bert_regressor.to(device), q)\n",
"# return((scaler_M.inverse_transform(predictions[:,0:3].cpu())).tolist()[0] + \\\n",
"# (scaler_I.inverse_transform(predictions[:,3:6].cpu())).tolist()[0] + \\\n",
"# (scaler_B.inverse_transform(predictions[:,6:9].cpu())).tolist()[0] + \\\n",
"# (scaler_M.inverse_transform(predictions[:,9:12].cpu())).tolist()[0] + \\\n",
"# (scaler_I.inverse_transform(predictions[:,12:15].cpu())).tolist()[0])\n",
"def EPA_sents_lst(sent):\n",
" return(bert_predict(bert_regressor.to(device), sent_ldr(sent)).cpu().tolist()[0])\n",
"print('Process started at',datetime.now() )\n",
"\n",
"df_test_c=pd.read_pickle('E:/Chatbot/Emotion/emowoz-public-main/baselines/BERT/df_test_multiwoz_carry.pkl')\n",
"df_val_c=pd.read_pickle('E:/Chatbot/Emotion/emowoz-public-main/baselines/BERT/df_val_multiwoz_carry.pkl')\n",
"df_train_c=pd.read_pickle('E:/Chatbot/Emotion/emowoz-public-main/baselines/BERT/df_train_multiwoz_carry.pkl')\n",
"\n",
"df_train_c['BERTNN_text']=df_train_c.text.apply(lambda x: EPA_sents_lst(x))\n",
"df_val_c['BERTNN_text']=df_val_c.text.apply(lambda x: EPA_sents_lst(x))\n",
"df_test_c['BERTNN_text']=df_test_c.text.apply(lambda x: EPA_sents_lst(x))\n",
"\n",
"df_train_c['BERTNN_reply']=df_train_c.reply.apply(lambda x: EPA_sents_lst(x))\n",
"df_val_c['BERTNN_reply']=df_val_c.reply.apply(lambda x: EPA_sents_lst(x))\n",
"df_test_c['BERTNN_reply']=df_test_c.reply.apply(lambda x: EPA_sents_lst(x))\n",
"\n",
"df_test_c.to_pickle('E:/Chatbot/Emotion/emowoz-public-main/baselines/BERT/df_test_multiwoz_carry.pkl')\n",
"df_val_c.to_pickle('E:/Chatbot/Emotion/emowoz-public-main/baselines/BERT/df_val_multiwoz_carry.pkl')\n",
"df_train_c.to_pickle('E:/Chatbot/Emotion/emowoz-public-main/baselines/BERT/df_train_multiwoz_carry.pkl')\n",
"\n",
"print('Process ended at',datetime.now() )\n"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Process started at 2023-01-18 22:05:12.234557\n",
"Process ended at 2023-01-18 22:44:32.764497\n"
]
}
],
"source": [
"print('Process started at',datetime.now() )\n",
"df_test_=pd.read_pickle('E:/Chatbot/ijcnlp_dailydialog/df_daily_test_multiwoz_carry.pkl')\n",
"df_val_=pd.read_pickle('E:/Chatbot/ijcnlp_dailydialog/df_daily_val_multiwoz_carry.pkl')\n",
"df_train_=pd.read_pickle('E:/Chatbot/ijcnlp_dailydialog/df_daily_train_multiwoz_carry.pkl')\n",
"\n",
"df_train_['BERTNN_text']=df_train_.text.apply(lambda x: EPA_sents_lst(x))\n",
"df_val_['BERTNN_text']=df_val_.text.apply(lambda x: EPA_sents_lst(x))\n",
"df_test_['BERTNN_text']=df_test_.text.apply(lambda x: EPA_sents_lst(x))\n",
"\n",
"df_test_.to_pickle('E:/Chatbot/ijcnlp_dailydialog/df_daily_test_multiwoz_carry.pkl')\n",
"df_val_.to_pickle('E:/Chatbot/ijcnlp_dailydialog/df_daily_val_multiwoz_carry.pkl')\n",
"df_train_.to_pickle('E:/Chatbot/ijcnlp_dailydialog/df_daily_train_multiwoz_carry.pkl')\n",
"\n",
"\n",
"print('Process ended at',datetime.now() )\n"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Process started at 2023-01-19 14:10:01.182188\n",
"Process ended at 2023-01-19 14:53:31.330811\n"
]
}
],
"source": [
"print('Process started at',datetime.now() )\n",
"df_test_=pd.read_pickle('E:/Chatbot/ijcnlp_dailydialog/DD_TS_test_carry.pkl')\n",
"df_val_=pd.read_pickle('E:/Chatbot/ijcnlp_dailydialog/DD_TS_val_carry.pkl')\n",
"df_train_=pd.read_pickle('E:/Chatbot/ijcnlp_dailydialog/DD_TS_train_carry.pkl')\n",
"\n",
"df_train_['BERTNN_text']=df_train_.text.apply(lambda x: EPA_sents_lst(x))\n",
"df_val_['BERTNN_text']=df_val_.text.apply(lambda x: EPA_sents_lst(x))\n",
"df_test_['BERTNN_text']=df_test_.text.apply(lambda x: EPA_sents_lst(x))\n",
"\n",
"df_test_.to_pickle('E:/Chatbot/ijcnlp_dailydialog/DD_TS_test_carry.pkl')\n",
"df_val_.to_pickle('E:/Chatbot/ijcnlp_dailydialog/DD_TS_val_carry.pkl')\n",
"df_train_.to_pickle('E:/Chatbot/ijcnlp_dailydialog/DD_TS_train_carry.pkl')\n",
"\n",
"\n",
"print('Process ended at',datetime.now() )\n"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"def EPA_sents_lst(sent):\n",
" q=sent_ldr(sent)\n",
" predictions=bert_predict(bert_regressor.to(device), q)\n",
" return((scaler_M.inverse_transform(predictions[:,0:3].cpu())).tolist()[0] + \\\n",
" (scaler_I.inverse_transform(predictions[:,3:6].cpu())).tolist()[0] + \\\n",
" (scaler_B.inverse_transform(predictions[:,6:9].cpu())).tolist()[0] + \\\n",
" (scaler_M.inverse_transform(predictions[:,9:12].cpu())).tolist()[0] + \\\n",
" (scaler_I.inverse_transform(predictions[:,12:15].cpu())).tolist()[0])\n",
"\n",
"\n",
"df_train_['BERTNN_text2']=df_train_.text.apply(lambda x: EPA_sents_lst(x))\n",
"df_val_['BERTNN_text2']=df_val_.text.apply(lambda x: EPA_sents_lst(x))\n",
"df_test_['BERTNN_text2']=df_test_.text.apply(lambda x: EPA_sents_lst(x))\n",
"\n",
"df_test_.to_pickle('E:/Chatbot/ijcnlp_dailydialog/df_daily_test_multiwoz_carry.pkl')\n",
"df_val_.to_pickle('E:/Chatbot/ijcnlp_dailydialog/df_daily_val_multiwoz_carry.pkl')\n",
"df_train_.to_pickle('E:/Chatbot/ijcnlp_dailydialog/df_daily_train_multiwoz_carry.pkl')"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" EE | \n",
" EP | \n",
" EA | \n",
" E | \n",
" P | \n",
" A | \n",
"
\n",
" \n",
" \n",
" \n",
" | EE | \n",
" 1.0 | \n",
" 0.553*** | \n",
" -0.073 | \n",
" 0.915*** | \n",
" 0.521*** | \n",
" -0.049 | \n",
"
\n",
" \n",
" | EP | \n",
" 0.553*** | \n",
" 1.0 | \n",
" 0.178** | \n",
" 0.507*** | \n",
" 0.841*** | \n",
" 0.155** | \n",
"
\n",
" \n",
" | EA | \n",
" -0.073 | \n",
" 0.178** | \n",
" 1.0 | \n",
" -0.056 | \n",
" 0.188** | \n",
" 0.739*** | \n",
"
\n",
" \n",
" | E | \n",
" 0.915*** | \n",
" 0.507*** | \n",
" -0.056 | \n",
" 1.0 | \n",
" 0.578*** | \n",
" -0.002 | \n",
"
\n",
" \n",
" | P | \n",
" 0.521*** | \n",
" 0.841*** | \n",
" 0.188** | \n",
" 0.578*** | \n",
" 1.0 | \n",
" 0.243*** | \n",
"
\n",
" \n",
" | A | \n",
" -0.049 | \n",
" 0.155** | \n",
" 0.739*** | \n",
" -0.002 | \n",
" 0.243*** | \n",
" 1.0 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" EE EP EA E P A\n",
"EE 1.0 0.553*** -0.073 0.915*** 0.521*** -0.049\n",
"EP 0.553*** 1.0 0.178** 0.507*** 0.841*** 0.155**\n",
"EA -0.073 0.178** 1.0 -0.056 0.188** 0.739***\n",
"E 0.915*** 0.507*** -0.056 1.0 0.578*** -0.002\n",
"P 0.521*** 0.841*** 0.188** 0.578*** 1.0 0.243***\n",
"A -0.049 0.155** 0.739*** -0.002 0.243*** 1.0"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pval_cor_spearman(pd.concat([df_beh_v3,df_mod_v2,df_ident_v2],axis=0))"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "BERT_multiple_values.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "py39latest",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
},
"vscode": {
"interpreter": {
"hash": "8abe35e408fc4788c839a6a78b4e65acd91a92940917dbaad05cb9d29446d025"
}
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"0784db0a7fad434ea5c58e56f3f9db48": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_a9a34097523c4bcf9e6b100dc50bc3bf",
"placeholder": "",
"style": "IPY_MODEL_a997f2a2c54f4e1a840254b0a4ebfc2b",
"value": "Downloading: 100%"
}
},
"07bf536d70f3420282feb7b6c2b7c502": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"081eda02be0b40edad3cbab832a79667": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"09adee9a7b3e42a3801db22b06be34a7": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"0f638f53620245d3b5d67f172a0d4abc": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_1e8f16fa716f466f8fd8772d765ad7a1",
"placeholder": "",
"style": "IPY_MODEL_c928829cd6394d50b991e569bfd5caf2",
"value": " 455k/455k [00:00<00:00, 5.83MB/s]"
}
},
"11075e61698a485896267f188e133958": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"16a8c19ffa4c4e55978cd6ee01c1f9f1": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_1e89ca8823414adcb55b57953bd2b25b",
"placeholder": "",
"style": "IPY_MODEL_77900d6e71924c998a5bd41d2097e4a0",
"value": "Downloading: 100%"
}
},
"1ae4f6f212d8403ebf22ba448ce6a39d": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_c3fbd6eef6014346a51db8a808867835",
"max": 28,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_8aefab5c69bb49e68aa1a7ce56c219cf",
"value": 28
}
},
"1e89ca8823414adcb55b57953bd2b25b": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"1e8f16fa716f466f8fd8772d765ad7a1": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"1ea6b565ffe945d5a4c527f5607e52fc": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"2cf178b819ca4a3db05cc9ad0d2490e4": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"2eb6df39faa741a8821554482cae1f88": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_07bf536d70f3420282feb7b6c2b7c502",
"placeholder": "",
"style": "IPY_MODEL_09adee9a7b3e42a3801db22b06be34a7",
"value": " 420M/420M [00:12<00:00, 38.7MB/s]"
}
},
"355008ec0b0e4fee9ad82cff6fce97ef": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_1ea6b565ffe945d5a4c527f5607e52fc",
"placeholder": "",
"style": "IPY_MODEL_8deb2d8214b54a6ab0f7fac14c78b0dc",
"value": " 570/570 [00:00<00:00, 12.3kB/s]"
}
},
"3e28697501b2497a94ca53174a5d647c": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"4031a2bedcad41e7af9e574c1a6a3b4d": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"4036232b409b467eaa5b9e4f2ddcced4": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"4c94071de1ba46479762d0576e0a5dac": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"560910e1c5454e9c9ca6085d67055c50": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"59909d8ca5af415aae0e2e7f64001d78": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_860f138d9a4c44729176915246acfe3c",
"placeholder": "",
"style": "IPY_MODEL_2cf178b819ca4a3db05cc9ad0d2490e4",
"value": " 28.0/28.0 [00:00<00:00, 504B/s]"
}
},
"5af534d8e04e4e769eecb9b71ce097b3": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_11075e61698a485896267f188e133958",
"max": 466062,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_4036232b409b467eaa5b9e4f2ddcced4",
"value": 466062
}
},
"5b4469b58a744da1a15ea28a918ad1aa": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_817e41eac37448fc82e1ab4a4eaa7f54",
"IPY_MODEL_eb2d14c36ee74587afd67d216dcf5ef8",
"IPY_MODEL_b95438470d39403491e43171994c2468"
],
"layout": "IPY_MODEL_081eda02be0b40edad3cbab832a79667"
}
},
"61c8d744d6be4250a9ecacd387114314": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_0784db0a7fad434ea5c58e56f3f9db48",
"IPY_MODEL_67b617180f634cd7a87825bd4b717f12",
"IPY_MODEL_355008ec0b0e4fee9ad82cff6fce97ef"
],
"layout": "IPY_MODEL_ed3def74372245c38b386b3b40ce4efc"
}
},
"67b617180f634cd7a87825bd4b717f12": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_4c94071de1ba46479762d0576e0a5dac",
"max": 570,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_dc05deaed5f04aea9270527e989e378d",
"value": 570
}
},
"75e5886c17654c76ac6ce8e6fe258ed5": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_dd498127d1a54e6da1e92698cea49426",
"max": 440473133,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_4031a2bedcad41e7af9e574c1a6a3b4d",
"value": 440473133
}
},
"77900d6e71924c998a5bd41d2097e4a0": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"817e41eac37448fc82e1ab4a4eaa7f54": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_3e28697501b2497a94ca53174a5d647c",
"placeholder": "",
"style": "IPY_MODEL_d1fc7c2fafd74effbf11b42c74ffb7b0",
"value": "Downloading: 100%"
}
},
"81b86cf2fbb443368ad2584d3e80ea7a": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"858ecd3c04e0426ead38774836b5eb25": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"860f138d9a4c44729176915246acfe3c": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"885e801a160c405d83d42f9dbea58ced": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"8aefab5c69bb49e68aa1a7ce56c219cf": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"8deb2d8214b54a6ab0f7fac14c78b0dc": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"942a7604d1504756b8b4bf1fe7cc93b6": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_a616ae222e66416d8becaa5f5eedfe4a",
"IPY_MODEL_75e5886c17654c76ac6ce8e6fe258ed5",
"IPY_MODEL_2eb6df39faa741a8821554482cae1f88"
],
"layout": "IPY_MODEL_560910e1c5454e9c9ca6085d67055c50"
}
},
"9df19fff43aa44d1b08816730ec83029": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"9faf62522de4493d9ee007edde7601de": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_16a8c19ffa4c4e55978cd6ee01c1f9f1",
"IPY_MODEL_1ae4f6f212d8403ebf22ba448ce6a39d",
"IPY_MODEL_59909d8ca5af415aae0e2e7f64001d78"
],
"layout": "IPY_MODEL_b8346621449d4a09a0db59845e1e292e"
}
},
"a082e16c77c7446086f61b4125867a82": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_bb25c0c52c724b0f9af44199f5843ab4",
"IPY_MODEL_5af534d8e04e4e769eecb9b71ce097b3",
"IPY_MODEL_0f638f53620245d3b5d67f172a0d4abc"
],
"layout": "IPY_MODEL_e8c08401834f4724b1abcee5ad17293d"
}
},
"a616ae222e66416d8becaa5f5eedfe4a": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_ca6befc2c114477dabcfe90808879956",
"placeholder": "",
"style": "IPY_MODEL_858ecd3c04e0426ead38774836b5eb25",
"value": "Downloading: 100%"
}
},
"a997f2a2c54f4e1a840254b0a4ebfc2b": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"a9a34097523c4bcf9e6b100dc50bc3bf": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"b8346621449d4a09a0db59845e1e292e": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"b95438470d39403491e43171994c2468": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_9df19fff43aa44d1b08816730ec83029",
"placeholder": "",
"style": "IPY_MODEL_885e801a160c405d83d42f9dbea58ced",
"value": " 226k/226k [00:00<00:00, 3.95MB/s]"
}
},
"bb25c0c52c724b0f9af44199f5843ab4": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_e8433bee864a4dcaa55f9a20345c79a5",
"placeholder": "",
"style": "IPY_MODEL_81b86cf2fbb443368ad2584d3e80ea7a",
"value": "Downloading: 100%"
}
},
"c3fbd6eef6014346a51db8a808867835": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"c928829cd6394d50b991e569bfd5caf2": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"ca6befc2c114477dabcfe90808879956": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"d1fc7c2fafd74effbf11b42c74ffb7b0": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"dc05deaed5f04aea9270527e989e378d": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"dc14d974b2734c6e997c415d885d1b79": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"dd498127d1a54e6da1e92698cea49426": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"e8433bee864a4dcaa55f9a20345c79a5": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"e8c08401834f4724b1abcee5ad17293d": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"eb2d14c36ee74587afd67d216dcf5ef8": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.5.0",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_f6953e66e3d445938d1650c850e6f45a",
"max": 231508,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_dc14d974b2734c6e997c415d885d1b79",
"value": 231508
}
},
"ed3def74372245c38b386b3b40ce4efc": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"f6953e66e3d445938d1650c850e6f45a": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.2.0",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
}
}
}
},
"nbformat": 4,
"nbformat_minor": 1
}