ia-nechaev commited on
Commit
bc78cab
·
verified ·
1 Parent(s): 0d0afb0

Update prediction_multilabel.py

Browse files
Files changed (1) hide show
  1. prediction_multilabel.py +4 -4
prediction_multilabel.py CHANGED
@@ -10,15 +10,15 @@ from sentence_transformers import util
10
  torch.manual_seed(1)
11
 
12
  # Load datasets
13
- df_inmemory = pd.read_csv('labeled.csv') # labeled text extracted from 230 CSR GRI reports, 150 International companies, 2017-2021 period
14
- df_paragraph = pd.read_csv('prediction_demo.csv') # paragraphs to predict the label, extracted from 1.2k CSR reports, 150 German PLC companies, 2010-2021 period, 645k paragraphs)
15
 
16
  # Load stored embeddings
17
- with open('embeddings_prediction.pkl', "rb") as f:
18
  stored_data = pickle.load(f)
19
  pred_embeddings = stored_data['parg_embeddings']
20
 
21
- with open('embeddings_labeled.pkl', "rb") as f:
22
  stored_data = pickle.load(f)
23
  embeddings = stored_data['sent_embeddings']
24
 
 
10
  torch.manual_seed(1)
11
 
12
  # Load datasets
13
+ df_inmemory = pd.read_csv('raw_data/labeled.csv') # labeled text extracted from 230 CSR GRI reports, 150 International companies, 2017-2021 period
14
+ df_paragraph = pd.read_csv('raw_data/prediction_demo.csv') # paragraphs to predict the label, extracted from 1.2k CSR reports, 150 German PLC companies, 2010-2021 period, 645k paragraphs)
15
 
16
  # Load stored embeddings
17
+ with open('embeddings/embeddings_prediction.pkl', "rb") as f:
18
  stored_data = pickle.load(f)
19
  pred_embeddings = stored_data['parg_embeddings']
20
 
21
+ with open('embeddings/embeddings_labeled.pkl', "rb") as f:
22
  stored_data = pickle.load(f)
23
  embeddings = stored_data['sent_embeddings']
24