Add files using upload-large-folder tool
- README.md +96 -3
- dir/Baseline4_advSample.csv +0 -0
- dir/DCS.py +19 -0
- dir/ECL_MST.py +132 -0
- dir/Feature_Collector.py +83 -0
- dir/GetFeatures_gen.py +0 -0
- dir/MatDB.py +18 -0
- dir/TestPoolMP.py +53 -0
- dir/TestPool_Unit.py +54 -0
- dir/TestPool_Unit_clique.py +57 -0
- dir/Train_bron.py +830 -0
- dir/Train_clique.py +769 -0
- dir/Train_n_Save_NNet.py +680 -0
- dir/__pycache__/TestPool_Unit_clique.cpython-36.pyc +0 -0
- dir/__pycache__/heap_n_clique.cpython-36.pyc +0 -0
- dir/bronclique.py +362 -0
- dir/bucket_by_conflicting_nodes_0.py +85 -0
- dir/bucket_by_conflicting_nodes_1.py +72 -0
- dir/bucket_by_conflicting_nodes_2.py +73 -0
- dir/bucket_by_conflicting_nodes_3.py +73 -0
- dir/bucket_by_conflicting_nodes_4.py +73 -0
- dir/bz2_counter.py +8 -0
- dir/cliq.csv +0 -0
- dir/datainspect.py +37 -0
- dir/dcs_skt_bzipper.py +196 -0
- dir/evaluate.py +81 -0
- dir/generate_dcs_and_skt_csv.py +66 -0
- dir/gt2.py +148 -0
- dir/heap_n_PrimMST.py +201 -0
- dir/heap_n_clique.py +325 -0
- dir/heldoutmatchtest.py +37 -0
- dir/lemmawise_labeller.py +77 -0
- dir/nnet.py +371 -0
- dir/pvb.p +0 -0
- dir/rom.txt +33 -0
- dir/rom2.txt +12 -0
- dir/romtoslp.py +35 -0
- dir/romtoslp.pyc +0 -0
- dir/sandhiRules.p +0 -0
- dir/sentences.py +319 -0
- dir/sh_TestPool_MP_clique.py +168 -0
- dir/test_clique.py +174 -0
- dir/unpack.py +39 -0
- dir/utilities.py +323 -0
- dir/verbs_vs_cngs_matrix_countonly.p +0 -0
- dir/weighted.py +81 -0
- dir/wordTypeCheckFunction.py +281 -0
- dir/word_definite.py +0 -0
- dir/word_definite[d_1500_BM2_v12].py +0 -0
- requirements.txt +68 -0
README.md
CHANGED
# Word Segmentation in Sanskrit Using Energy Based Models

This is a reconstruction of the paper [Free as in Free Word Order: An Energy Based Model for Word Segmentation and Morphological Tagging in Sanskrit](https://aclanthology.org/D18-1276/).

You can refer to the original repository [here](https://zenodo.org/records/1035413#.W35s8hjhUUs).

## Folder Structure
```
├── dir/
├── wordsegmentation/
    ├── skt_dcs_DS.bz2_4K_bigram_mir_10K/
    ├── skt_dcs_DS.bz2_4K_bigram_mir_heldout/
```

## Prerequisites
* Python 3
* scipy
* numpy
* csv
* pickle
* multiprocessing
* bz2

## Instructions for Training
1. Change your current directory to `dir`.

2. Run `Train_clique.py` with the following command:

```
python Train_clique.py
```

**TRAINING OUTPUTS** are already stored in `dir/outputs/train_t3896665073989`.

**NOTE**: To train on different input features such as BM2, BM3, BR2, BR3, PM2, PM3, PR2, or PR3, modify the `bz2_input_folder` value in the main function before starting the training.

| Feature Code | `bz2_input_folder` Path |
|--------------|------------------------------------------------------------------|
| BM2 | wordsegmentation/skt_dcs_DS.bz2_4K_bigram_mir_10K/ |
| BM3 | wordsegmentation/skt_dcs_DS.bz2_1L_bigram_mir_10K/ |
| BR2 | wordsegmentation/skt_dcs_DS.bz2_4K_bigram_rfe_10K/ |
| BR3 | wordsegmentation/skt_dcs_DS.bz2_1L_bigram_rfe_10K/ |
| PM2 | wordsegmentation/skt_dcs_DS.bz2_4K_pmi_mir_10K/ |
| PM3 | wordsegmentation/skt_dcs_DS.bz2_1L_pmi_mir_10K2/ |
| PR2 | wordsegmentation/skt_dcs_DS.bz2_4K_pmi_rfe_10K/ |
| PR3 | wordsegmentation/skt_dcs_DS.bz2_1L_pmi_rfe_10K/ |
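The table above can be read as a simple lookup from feature code to input folder. As a sketch only (the variable name `bz2_input_folder` comes from the note above, but the surrounding code of `Train_clique.py` is not shown here, so the exact assignment may differ):

```python
# Sketch only: feature codes mapped to the bz2 input folders from the table.
# Point bz2_input_folder in Train_clique.py's main function at the desired path.
FEATURE_FOLDERS = {
    'BM2': 'wordsegmentation/skt_dcs_DS.bz2_4K_bigram_mir_10K/',
    'BM3': 'wordsegmentation/skt_dcs_DS.bz2_1L_bigram_mir_10K/',
    'BR2': 'wordsegmentation/skt_dcs_DS.bz2_4K_bigram_rfe_10K/',
    'BR3': 'wordsegmentation/skt_dcs_DS.bz2_1L_bigram_rfe_10K/',
    'PM2': 'wordsegmentation/skt_dcs_DS.bz2_4K_pmi_mir_10K/',
    'PM3': 'wordsegmentation/skt_dcs_DS.bz2_1L_pmi_mir_10K2/',
    'PR2': 'wordsegmentation/skt_dcs_DS.bz2_4K_pmi_rfe_10K/',
    'PR3': 'wordsegmentation/skt_dcs_DS.bz2_1L_pmi_rfe_10K/',
}

bz2_input_folder = FEATURE_FOLDERS['BM2']
print(bz2_input_folder)
```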
## Instructions for Testing

After training, update the `modelList` dictionary in `test_clique.py` with the name of the neural network that was saved during training. When testing a feature, provide the name of the neural net that was trained for that same feature.

We only provide the trained model for the feature BM2, which was our best-performing feature. If the name of the neural net is not changed, testing will be performed with the pre-trained BM2 model provided in `outputs/train_t7978754709018`.

To test with a particular feature vector, pass the feature's tag when running the script:

```
python test_clique.py -t <tag>
```

For example:
```
python test_clique.py -t BM2
```

After testing finishes, run the following command to see the precision and recall values for both the word and word++ prediction tasks:
```
python evaluate.py <tag>
```
For example:
```
python evaluate.py BM2
```
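The `-t <tag>` flag above can be handled with a standard argument parser. This is a hypothetical sketch; the actual flag handling inside `test_clique.py` may differ:

```python
import argparse

# Hypothetical sketch of the -t flag; test_clique.py's real parsing may differ.
parser = argparse.ArgumentParser(description='Test with a particular feature tag')
parser.add_argument('-t', '--tag', default='BM2',
                    choices=['BM2', 'BM3', 'BR2', 'BR3', 'PM2', 'PM3', 'PR2', 'PR3'],
                    help='feature tag selecting which trained model to test')
args = parser.parse_args(['-t', 'BM2'])  # stand-in for the real command-line argv
print(args.tag)  # BM2
```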
## Citing
```bibtex
@inproceedings{krishna-etal-2018-free,
    title = "Free as in Free Word Order: An Energy Based Model for Word Segmentation and Morphological Tagging in {S}anskrit",
    author = "Krishna, Amrith and
      Santra, Bishal and
      Bandaru, Sasi Prasanth and
      Sahu, Gaurav and
      Sharma, Vishnu Dutt and
      Satuluri, Pavankumar and
      Goyal, Pawan",
    editor = "Riloff, Ellen and
      Chiang, David and
      Hockenmaier, Julia and
      Tsujii, Jun{'}ichi",
    booktitle = "Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing",
    month = oct # "-" # nov,
    year = "2018",
    address = "Brussels, Belgium",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/D18-1276/",
    doi = "10.18653/v1/D18-1276",
    pages = "2550--2561",
    abstract = "The configurational information in sentences of a free word order language such as Sanskrit is of limited use. Thus, the context of the entire sentence will be desirable even for basic processing tasks such as word segmentation. We propose a structured prediction framework that jointly solves the word segmentation and morphological tagging tasks in Sanskrit. We build an energy based model where we adopt approaches generally employed in graph based parsing techniques (McDonald et al., 2005a; Carreras, 2007). Our model outperforms the state of the art with an F-Score of 96.92 (percentage improvement of 7.06{\%}) while using less than one tenth of the task-specific training data. We find that the use of a graph based approach instead of a traditional lattice-based sequential labelling approach leads to a percentage gain of 12.6{\%} in F-Score for the segmentation task."
}
```
|
dir/Baseline4_advSample.csv
ADDED
The diff for this file is too large to render.
dir/DCS.py
ADDED
```python
import sys
import warnings
from romtoslp import *

class DCS:
    def __init__(self, sent_id, sentence):
        self.sent_id = sent_id
        self.sentence = sentence
        self.dcs_chunks = []
        self.lemmas = []
        self.cng = []

def SeeDCS(dcsObj):
    print('DCS ANALYZE')
    print('-'*15)
    print(dcsObj.sentence)
    print(dcsObj.lemmas)
    print("Lemmas:", [rom_slp(c) for arr in dcsObj.lemmas for c in arr])
    print(dcsObj.cng)
    print()
```
dir/ECL_MST.py
ADDED
```python
# coding: utf-8

# In[38]:

#!/usr/bin/env python3
import numpy as np
import sys
from collections import defaultdict, namedtuple
from copy import deepcopy

# Arc is the same as an edge: a directed edge from tail to head with a weight
Arc = namedtuple('Arc', ('tail', 'weight', 'head'))


def min_spanning_arborescence(arcs, source):
    good_arcs = []
    quotient_map = {arc.tail: arc.tail for arc in arcs}
    quotient_map[source] = source
    while True:
        min_arc_by_tail_rep = {}
        successor_rep = {}
        for arc in arcs:
            if arc.tail == source:
                continue
            tail_rep = quotient_map[arc.tail]
            head_rep = quotient_map[arc.head]
            if tail_rep == head_rep:
                continue
            if tail_rep not in min_arc_by_tail_rep or min_arc_by_tail_rep[tail_rep].weight > arc.weight:
                min_arc_by_tail_rep[tail_rep] = arc
                successor_rep[tail_rep] = head_rep
        cycle_reps = find_cycle(successor_rep, source)
        if cycle_reps is None:
            good_arcs.extend(min_arc_by_tail_rep.values())
            return spanning_arborescence(good_arcs, source)
        good_arcs.extend(min_arc_by_tail_rep[cycle_rep] for cycle_rep in cycle_reps)
        cycle_rep_set = set(cycle_reps)
        cycle_rep = cycle_rep_set.pop()
        quotient_map = {node: cycle_rep if node_rep in cycle_rep_set else node_rep
                        for node, node_rep in quotient_map.items()}


def find_cycle(successor, source):
    visited = {source}
    for node in successor:
        cycle = []
        while node not in visited:
            visited.add(node)
            cycle.append(node)
            node = successor[node]
        if node in cycle:
            return cycle[cycle.index(node):]
    return None


def spanning_arborescence(arcs, source):
    arcs_by_head = defaultdict(list)
    for arc in arcs:
        if arc.tail == source:
            continue
        arcs_by_head[arc.head].append(arc)
    solution_arc_by_tail = {}
    stack = arcs_by_head[source]
    while stack:
        stack = sorted(stack)
        arc = stack.pop(0)
        if arc.tail in solution_arc_by_tail:
            continue
        solution_arc_by_tail[arc.tail] = arc
        stack.extend(arcs_by_head[arc.tail])
    return solution_arc_by_tail


def MST_ECL(nodelist, WScalarMat, conflicts_Dict1, source):
    pre_nodes = []
    nodes = []
    WScalarMat1 = deepcopy(WScalarMat)

    mst_nodes = defaultdict(lambda: [])
    mst_nodes_bool = np.array([False]*len(nodelist))
    # Original used np.ndarray(WScalarMat.shape, np.bool)*False; np.bool is a
    # deprecated alias, so allocate a zeroed boolean array directly.
    mst_adj_graph = np.zeros(WScalarMat.shape, dtype=bool)

    while len(nodes) < len(nodelist):
        i = int(np.argmin(WScalarMat1)/len(nodelist))
        j = np.argmin(WScalarMat1) % len(nodelist)
        if i not in nodes and j not in nodes and i != j and i not in conflicts_Dict1[j]:
            pre_nodes.extend([i, j])
            nodes.extend([i, j])
            for x in conflicts_Dict1[i]:
                if x not in nodes:
                    nodes.append(x)
            for x in conflicts_Dict1[j]:
                if x not in nodes:
                    nodes.append(x)
        elif i not in nodes and j in pre_nodes and i != j and i not in conflicts_Dict1[j]:
            pre_nodes.append(i)
            nodes.append(i)
            for x in conflicts_Dict1[i]:
                if x not in nodes:
                    nodes.append(x)
        elif j not in nodes and i in pre_nodes and i != j and j not in conflicts_Dict1[i]:
            pre_nodes.append(j)
            nodes.append(j)
            for x in conflicts_Dict1[j]:
                if x not in nodes:
                    nodes.append(x)
        WScalarMat1[i][j] = sys.maxsize

    pre_nodes.sort()
    for i in pre_nodes:
        mst_nodes_bool[i] = True
        mst_nodes[nodelist[i].chunk_id].append(nodelist[i])
    mst_nodes = dict(mst_nodes)

    # list of arcs (edges)
    list_arcs = []
    for i in range(len(nodelist)):
        for j in range(len(nodelist)):
            if i in pre_nodes and j in pre_nodes and WScalarMat[i][j] != 0.0:
                list_arcs.append(Arc(j, WScalarMat[i][j], i))

    Resultant_Arcs = min_spanning_arborescence(list_arcs, pre_nodes[0])
    #print(Resultant_Arcs)
    for i in Resultant_Arcs.values():
        #print(i.head, i.tail)
        mst_adj_graph[i.head][i.tail] = True

    return (mst_nodes_bool, mst_nodes, mst_adj_graph)

# Example_Graph_1 = [Arc(1, 9, 0), Arc(2, 10, 0), Arc(3, 9, 0), Arc(2, 20, 1), Arc(3, 3, 1), Arc(1, 30, 2), Arc(3, 30, 2), Arc(2, 0, 3), Arc(1, 11, 3)]
# Example_Graph_2 = [Arc(1, 10, 0), Arc(2, 7, 0), Arc(1, 9, 2), Arc(4, 10, 1), Arc(3, 2, 2), Arc(3, 20, 4), Arc(4, 23, 2), Arc(5, 5, 2), Arc(3, 7, 5)]
# print(min_spanning_arborescence(Example_Graph_2, 0))
```
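The commented example graphs at the bottom of the file can be used to sanity-check the arborescence routine. The following self-contained sketch re-declares `Arc` and the three functions unchanged and runs them on `Example_Graph_1` with source node 0:

```python
from collections import defaultdict, namedtuple

Arc = namedtuple('Arc', ('tail', 'weight', 'head'))

def min_spanning_arborescence(arcs, source):
    good_arcs = []
    quotient_map = {arc.tail: arc.tail for arc in arcs}
    quotient_map[source] = source
    while True:
        min_arc_by_tail_rep = {}
        successor_rep = {}
        for arc in arcs:
            if arc.tail == source:
                continue
            tail_rep = quotient_map[arc.tail]
            head_rep = quotient_map[arc.head]
            if tail_rep == head_rep:
                continue
            if tail_rep not in min_arc_by_tail_rep or min_arc_by_tail_rep[tail_rep].weight > arc.weight:
                min_arc_by_tail_rep[tail_rep] = arc
                successor_rep[tail_rep] = head_rep
        cycle_reps = find_cycle(successor_rep, source)
        if cycle_reps is None:
            good_arcs.extend(min_arc_by_tail_rep.values())
            return spanning_arborescence(good_arcs, source)
        good_arcs.extend(min_arc_by_tail_rep[cycle_rep] for cycle_rep in cycle_reps)
        cycle_rep_set = set(cycle_reps)
        cycle_rep = cycle_rep_set.pop()
        quotient_map = {node: cycle_rep if node_rep in cycle_rep_set else node_rep
                        for node, node_rep in quotient_map.items()}

def find_cycle(successor, source):
    visited = {source}
    for node in successor:
        cycle = []
        while node not in visited:
            visited.add(node)
            cycle.append(node)
            node = successor[node]
        if node in cycle:
            return cycle[cycle.index(node):]
    return None

def spanning_arborescence(arcs, source):
    arcs_by_head = defaultdict(list)
    for arc in arcs:
        if arc.tail == source:
            continue
        arcs_by_head[arc.head].append(arc)
    solution_arc_by_tail = {}
    stack = arcs_by_head[source]
    while stack:
        stack = sorted(stack)
        arc = stack.pop(0)
        if arc.tail in solution_arc_by_tail:
            continue
        solution_arc_by_tail[arc.tail] = arc
        stack.extend(arcs_by_head[arc.tail])
    return solution_arc_by_tail

# Example_Graph_1 from the comments above: each Arc is (tail, weight, head),
# i.e. a directed edge from tail toward its candidate parent head.
Example_Graph_1 = [Arc(1, 9, 0), Arc(2, 10, 0), Arc(3, 9, 0), Arc(2, 20, 1),
                   Arc(3, 3, 1), Arc(1, 30, 2), Arc(3, 30, 2), Arc(2, 0, 3),
                   Arc(1, 11, 3)]

result = min_spanning_arborescence(Example_Graph_1, 0)
# Every non-source node gets exactly one chosen arc; total weight is 9 + 3 + 0 = 12.
print(result)
print(sum(arc.weight for arc in result.values()))  # 12
```

Per-node minima here give no cycle on the first pass (1 → 0, 2 → 3, 3 → 1), so the contraction loop exits immediately with the arborescence of weight 12.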
dir/Feature_Collector.py
ADDED
```python
import pickle  # needed for pickle.dump below (missing in the original)
import word_definite as WD
from MatDB import *
import matplotlib.pyplot as plt
from IPython.display import display

import numpy as np
import math
np.set_printoptions(suppress=False, precision=16)

matDB = MatDB()

WD.word_definite_extInit(matDB)

print(len(matDB.mat_tupCount_1D))

smallTupList = []
for tup1, val in matDB.mat_tupCount_1D.items():
    if val > 20:
        smallTupList.append(tup1)

l = len(smallTupList)
print(l)

batchCount = 2000
perm = np.random.permutation(l)[:batchCount]
# perm2 = np.random.permutation(l)[:batchCount]

fN = 4*443**2 + 9*443 + 9
print(fN)

all_pairs = {}

index = 0
for k in range(batchCount):
    tup1 = smallTupList[perm[k]]
    lem1 = tup1.split('_')
    cng1 = lem1[1]
    lem1 = lem1[0]

    node1 = WD.word_definite(None, lem1, cng1, 0, 0)
    for tup2, co_occurrence in matDB.mat_tup2tup_countonly[tup1].items():
        if co_occurrence > 4:
            lem2 = tup2.split('_')
            cng2 = lem2[1]
            lem2 = lem2[0]
            node2 = WD.word_definite(None, lem2, cng2, 0, 1)
            all_pairs[index] = (node1, node2)
            index += 1

with open('outputs/log_001.txt', 'a') as log_handle:
    log_handle.write('Will get feature vectors for {} pairs\n'.format(index))
total_examples = index

pairs_per_file = 500

def tryForVal(mat, key1, key2):
    # Defaulting lookup: missing keys count as zero co-occurrences
    try:
        v = mat[key1][key2]
    except KeyError:
        v = 0
    return v

for pairx in range(math.ceil(len(all_pairs)/pairs_per_file)):
    subset_pairs = range(pairx*pairs_per_file, min(len(all_pairs), (pairx + 1)*pairs_per_file))
    featureMatrix = np.zeros((fN, len(subset_pairs)))
    targetDict = {}
    index = 0
    current_pairs = {}
    for hi in subset_pairs:
        node1 = all_pairs[hi][0]
        node2 = all_pairs[hi][1]
        current_pairs[index] = '{}^{}'.format(node1.tup, node2.tup)
        featureMatrix[:, index, None] = WD.Get_Features(node1, node2)
        targetDict[index] = (tryForVal(matDB.mat_tup2tup_countonly, node1.tup, node2.tup),
                             tryForVal(matDB.mat_lem2lem_countonly, node1.lemma, node2.lemma),
                             tryForVal(matDB.mat_lem2tup_countonly, node1.lemma, node2.tup),
                             tryForVal(matDB.mat_tup2lem_countonly, node1.tup, node2.lemma))
        index += 1
        if index % min(math.ceil(pairs_per_file/2), 100) == 0:
            with open('outputs/log_001.txt', 'a') as log_handle:
                log_handle.write('Checkpoint S{}E{} of {}\n'.format(pairx, index, pairs_per_file))
    pickle.dump({'all_pairs': current_pairs, 'featureMatrix': featureMatrix, 'targetDict': targetDict},
                open('outputs/featureSet_{}samples_8L_{}.p'.format(pairs_per_file, pairx), 'wb'), protocol=4)
```
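The `tryForVal` helper above is simply a defaulting lookup over nested count dictionaries. A minimal standalone illustration of the same pattern (the keys here are illustrative, not from the real MatDB data):

```python
def try_for_val(mat, key1, key2):
    # Return mat[key1][key2], or 0 when either key is missing.
    try:
        return mat[key1][key2]
    except KeyError:
        return 0

# Toy nested co-occurrence counts (illustrative keys only).
counts = {'rAma_1': {'sItA_2': 5}}
print(try_for_val(counts, 'rAma_1', 'sItA_2'))  # 5
print(try_for_val(counts, 'rAma_1', 'vana_7'))  # 0
print(try_for_val(counts, 'gam_3', 'sItA_2'))   # 0
```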
dir/GetFeatures_gen.py
ADDED
The diff for this file is too large to render.
dir/MatDB.py
ADDED
```python
import pickle

class MatDB:
    def __init__(self):
        self.mat_lem2lem_countonly = pickle.load(open('../NewData/gauravs/mat_lem2lem_old_countonly.p', 'rb'), encoding=u'utf-8')
        self.mat_lem2cng_countonly = pickle.load(open('../NewData/gauravs/mat_lemma2cng_new_countonly.p', 'rb'), encoding=u'utf-8')
        self.mat_lem2tup_countonly = pickle.load(open('../NewData/gauravs/mat_lem2tup_old_countonly.p', 'rb'), encoding=u'utf-8')

        self.mat_cng2lem_countonly = pickle.load(open('../NewData/gauravs/mat_cng2lemma_new_countonly.p', 'rb'), encoding=u'utf-8')
        self.mat_cng2tup_countonly = pickle.load(open('../NewData/gauravs/mat_cng2tup_new_countonly.p', 'rb'), encoding=u'utf-8')
        self.mat_cng2cng_countonly = pickle.load(open('../NewData/gauravs/mat_cng2cng_new_countonly.p', 'rb'), encoding=u'utf-8')

        self.mat_tup2cng_countonly = pickle.load(open('../NewData/gauravs/mat_tup2cng_new_countonly.p', 'rb'), encoding=u'utf-8')
        self.mat_tup2lem_countonly = pickle.load(open('../NewData/gauravs/mat_tup2lem_old_countonly.p', 'rb'), encoding=u'utf-8')
        self.mat_tup2tup_countonly = pickle.load(open('../NewData/gauravs/mat_tup2tup_new_countonly.p', 'rb'), encoding=u'utf-8')

        self.mat_lemCount_1D = pickle.load(open('../NewData/gauravs/Temporary_1D/mat_lemCount_1D.p', 'rb'), encoding=u'utf-8')
        self.mat_cngCount_1D = pickle.load(open('../NewData/gauravs/Temporary_1D/mat_cngCount_1D.p', 'rb'), encoding=u'utf-8')
        self.mat_tupCount_1D = pickle.load(open('../NewData/gauravs/Temporary_1D/mat_tupCount_1D.p', 'rb'), encoding=u'utf-8')
```
dir/TestPoolMP.py
ADDED
```python
import multiprocessing as mp
import numpy as np  # needed for np.mean below (missing in the original)
import TestPool_Unit
from shutil import copyfile

def Evaluate(result_arr):
    print('Files Processed: ', len(result_arr))
    recalls = []
    recalls_of_word = []
    precisions = []
    precisions_of_words = []
    for entry in result_arr:
        (word_match, lemma_match, n_dcsWords, n_output_nodes) = entry
        recalls.append(lemma_match/n_dcsWords)
        recalls_of_word.append(word_match/n_dcsWords)

        precisions.append(lemma_match/n_output_nodes)
        precisions_of_words.append(word_match/n_output_nodes)
    print('Avg. Micro Recall of Lemmas: {}'.format(np.mean(np.array(recalls))))
    print('Avg. Micro Recall of Words: {}'.format(np.mean(np.array(recalls_of_word))))
    print('Avg. Micro Precision of Lemmas: {}'.format(np.mean(np.array(precisions))))
    print('Avg. Micro Precision of Words: {}'.format(np.mean(np.array(precisions_of_words))))

modelFile = 'outputs/train_nnet_t764815831413.p'
# Backup the model file
copyfile(modelFile, modelFile + '.bk')

modelFile = modelFile + '.bk'

# Create Queue, Result array
queue = mp.Queue()
result_arr = []

# Start 6 workers - 8 slows down the pc
proc_count = 10
procs = [None]*proc_count
for i in range(proc_count):
    vpid = i
    # NOTE: pooled_Test's fourth parameter is testfolder; 700 appears to be
    # passed in its place here, as in the original source.
    procs[i] = mp.Process(target=TestPool_Unit.pooled_Test, args=(modelFile, vpid, queue, 700))

# Start Processes
for i in range(proc_count):
    procs[i].start()

# Properly Join
for i in range(proc_count):
    procs[i].join()

# Fetch partial results
while not queue.empty():
    result_arr.append(queue.get())

# Evaluate results till now
Evaluate(result_arr)
```
dir/TestPool_Unit.py
ADDED
```python
from multiprocessing import Process
import multiprocessing as mp
import os, sys
from sentences import *
import numpy as np
from Train_n_Save_NNet import *

def pooled_Test(modelFile, vpid, queue, testfolder, filePerProcess=100, _dump=False, _outFile=None):
    n_chkpt = 100
    print('Child process with vpid:{}, pid:{} started.'.format(vpid, os.getpid()))
    trainer = Trainer()
    trainer.Load(modelFile)

    TestFiles = []
    for f in os.listdir(testfolder):
        if '.ds.bz2' in f:
            TestFiles.append(f)

    print('vpid:{}: Range is {} -> {} / {}'.format(vpid, vpid*filePerProcess, vpid*filePerProcess + filePerProcess, len(TestFiles)))
    if _dump:
        _outFile = '{}_proc{}.csv'.format(_outFile, vpid)
        with open(_outFile, 'w') as fh:
            print('File refreshed', _outFile)

    loaded_SKT = pickle.load(open('../Simultaneous_CompatSKT_ho.p', 'rb'))
    loaded_DCS = pickle.load(open('../Simultaneous_DCS_ho.p', 'rb'))

    #loader = pickle.load(open('../bz2Dataset_10K.p', 'rb'))
    #TestFiles = loader['TestFiles']
    #TrainFiles = loader['TrainFiles']

    for i in range(vpid*filePerProcess, vpid*filePerProcess + filePerProcess):
        #if i % n_chkpt == 0:
            #print('Checkpoint {}, vpid: {}'.format(i, vpid))
            #sys.stdout.flush()
        fn = TestFiles[i]
        fn = fn.replace('.ds.bz2', '.p2')

        dsbz2_name = testfolder + TestFiles[i]

        sentenceObj = loaded_SKT[fn]
        dcsObj = loaded_DCS[fn]
        results = None  # guard: results was unbound after EOFError in the original
        try:
            if _dump:
                results = trainer.Test(sentenceObj, dcsObj, dsbz2_name, _dump=True, _outFile=_outFile)
            else:
                results = trainer.Test(sentenceObj, dcsObj, dsbz2_name)
        except EOFError as e:
            print('BADFILE', dsbz2_name)

        if results is not None:
            queue.put(results)
    print('Child process with vpid:{}, pid:{} closed.'.format(vpid, os.getpid()))
```
dir/TestPool_Unit_clique.py
ADDED
```python
from multiprocessing import Process
import multiprocessing as mp
import os, sys
from sentences import *
import numpy as np
from Train_clique import *

def pooled_Test(modelFile, vpid, queue, testfolder, filePerProcess=100, _dump=False, _outFile=None):
    n_chkpt = 100
    print('Child process with vpid:{}, pid:{} started.'.format(vpid, os.getpid()))
    trainer = Trainer()
    trainer.Load(modelFile)

    TestFiles = []
    for f in os.listdir(testfolder):
        if '.ds.bz2' in f:
            TestFiles.append(f)
    # print('TestFIles')
    # print(TestFiles)
    print('vpid:{}: Range is {} -> {} / {}'.format(vpid, vpid*filePerProcess, vpid*filePerProcess + filePerProcess, len(TestFiles)))
    if _dump:
        _outFile = '{}_proc{}.csv'.format(_outFile, vpid)
        with open(_outFile, 'w') as fh:
            print('File refreshed', _outFile)

    loaded_SKT = pickle.load(open('Simultaneous_CompatSKT_ho.p', 'rb'))
    loaded_DCS = pickle.load(open('Simultaneous_DCS_ho.p', 'rb'))

    #loader = pickle.load(open('../bz2Dataset_10K.p', 'rb'))
    #TestFiles = loader['TestFiles']
    #TrainFiles = loader['TrainFiles']

    for i in range(vpid*filePerProcess, vpid*filePerProcess + filePerProcess):
        #if i % n_chkpt == 0:
            #print('Checkpoint {}, vpid: {}'.format(i, vpid))
            #sys.stdout.flush()
        fn = TestFiles[i]
        fn = fn.replace('.ds.bz2', '.p2')

        dsbz2_name = testfolder + TestFiles[i]

        sentenceObj = loaded_SKT[fn]
        # print(fn)
        # print(type(fn))
        dcsObj = loaded_DCS[fn]
        results = None  # guard: results was unbound after EOFError in the original
        try:
            if _dump:
                results = trainer.Test(sentenceObj, dcsObj, dsbz2_name, _dump=True, _outFile=_outFile)
            else:
                results = trainer.Test(sentenceObj, dcsObj, dsbz2_name)
        except EOFError as e:
            print('BADFILE', dsbz2_name)

        if results is not None:
            queue.put(results)
    print('Child process with vpid:{}, pid:{} closed.'.format(vpid, os.getpid()))
```
dir/Train_bron.py
ADDED
"""
IMPORTS
"""

## Built-in packages
import sys, os, time, bz2, zlib, pickle, math, json, csv
from collections import defaultdict
import numpy as np
np.set_printoptions(suppress=True)
from IPython.display import display

## Last summer
from romtoslp import *
from sentences import *
from DCS import *
import MatDB
from bronclique import *
from ECL_MST import *

import word_definite as WD
# from heap_n_PrimMST import *
from nnet import *

"""
################################################################################################
###########################  LOAD SENTENCE AND DCS OBJECT FILES  ###############################
################################################################################################
"""
# loaded_SKT = pickle.load(open('../Simultaneous_CompatSKT_10K.p', 'rb'))
# loaded_DCS = pickle.load(open('../Simultaneous_DCS_10K.p', 'rb'))

"""
################################################################################################
###########################  OPENS AND EXTRACTS DCS AND SKT DATA STRUCTURES  ###################
######################################  FROM BZ2 FILES  ########################################
"""
def open_dsbz2(filename):
    with bz2.BZ2File(filename, 'r') as f:
        loader = pickle.load(f)

    conflicts_Dict_correct = loader['conflicts_Dict_correct']
    nodelist_to_correct_mapping = loader['nodelist_to_correct_mapping']
    nodelist_correct = loader['nodelist_correct']
    featVMat_correct = loader['featVMat_correct']
    featVMat = loader['featVMat']
    conflicts_Dict = loader['conflicts_Dict']
    nodelist = loader['nodelist']

    return (nodelist_correct, conflicts_Dict_correct, featVMat_correct, nodelist_to_correct_mapping,\
        nodelist, conflicts_Dict, featVMat)

"""
################################################################################################
######################  CREATE SEVERAL DATA STRUCTURES FROM SENTENCE/DCS  ######################
###########################  NODELIST, ADJACENCY LIST, GRAPH, HEAP  ############################
################################################################################################
"""
def GetTrainingKit(sentenceObj, dcsObj):
    nodelist = GetNodes(sentenceObj)

    # Nodelist with only the correct_nodes
    nodelist2 = GetNodes(sentenceObj)
    nodelist2_to_correct_mapping = {}
    nodelist_correct = []
    search_key = 0
    first_key = 0
    for chunk_id in range(len(dcsObj.lemmas)):
        while nodelist2[first_key].chunk_id != chunk_id:
            first_key += 1
        for j in range(len(dcsObj.lemmas[chunk_id])):
            search_key = first_key
            while (nodelist2[search_key].lemma != rom_slp(dcsObj.lemmas[chunk_id][j])) or (nodelist2[search_key].cng != dcsObj.cng[chunk_id][j]):
                search_key += 1
                if search_key >= len(nodelist2) or nodelist2[search_key].chunk_id > chunk_id:
                    break
            # print((rom_slp(dcsObj.lemmas[chunk_id][j]), dcsObj.cng[chunk_id][j]))
            # print(nodelist[search_key])
            nodelist2_to_correct_mapping[len(nodelist_correct)] = search_key
            nodelist_correct.append(nodelist2[search_key])
    return (nodelist, nodelist_correct, nodelist2_to_correct_mapping)


def GetGraph(nodelist, neuralnet):
    conflicts_Dict = Get_Conflicts(nodelist)

    featVMat = Get_Feat_Vec_Matrix(nodelist, conflicts_Dict)

    if not neuralnet.outer_relu:
        (WScalarMat, SigmoidGateOutput) = Get_W_Scalar_Matrix_from_FeatVect_Matrix(featVMat, nodelist, conflicts_Dict, neuralnet)
        return (conflicts_Dict, featVMat, WScalarMat, SigmoidGateOutput)
    else:
        WScalarMat = Get_W_Scalar_Matrix_from_FeatVect_Matrix(featVMat, nodelist, conflicts_Dict, neuralnet)
        return (conflicts_Dict, featVMat, WScalarMat)

# NEW LOSS FUNCTION
# NOTE: superseded by the GetLoss redefined further below
def GetLoss(_mst_adj_graph, _mask_de_correct_edges, _negLogLikelies):
    _negLogLikelies = _negLogLikelies.copy()
    _negLogLikelies[~_mst_adj_graph] = 0
    _negLogLikelies[~_mask_de_correct_edges] *= -1  # BAKA!!! Check before you try to fix this again
    return np.sum(_negLogLikelies)



"""
################################################################################################
##############################  MAIN FUNCTION  #################################################
################################################################################################
"""
trainingStatus = defaultdict(lambda: bool(False))


"""
################################################################################################
##############################  TRAIN FUNCTION  ################################################
################################################################################################
"""

def train_generator(loaded_SKT, loaded_DCS, bz2_input_folder, n_trainset = -1, iterationPerBatch = 10, filePerBatch = 20, _debug = True, superEpochs = 1):
    # Train
    if n_trainset == -1:
        n_trainset = len(TrainFiles)
    totalBatchToTrain = math.ceil(n_trainset/filePerBatch)

    register_nnet(trainer.neuralnet, bz2_input_folder)
    print('Epochs: ' + str(superEpochs))
    print('Batches per epoch: ' + str(totalBatchToTrain))
    for _epoch in range(superEpochs):
        for iterout in range(totalBatchToTrain):
            # Add timer
            startT = time.time()

            # Change current batch
            if iterout % 50 == 0:
                trainer.Save(p_name.replace('.p', '_e{}_i{}.p'.format(_epoch, iterout)))
            else:
                trainer.Save(p_name)
            print('Epoch: {}, Batch: {}'.format(_epoch, iterout))
            files_for_batch = TrainFiles[iterout*filePerBatch:(iterout + 1)*filePerBatch]
            # print(files_for_batch)
            # trainer.Load('outputs/neuralnet_trained.p')
            try:
                # Run a few times on the same set of files
                for iterin in range(iterationPerBatch):
                    print('ITERATION IN', iterin)
                    for fn in files_for_batch:
                        stime = time.time()
                        trainFileName = fn.replace('.ds.bz2', '.p2')
                        sentenceObj = loaded_SKT[trainFileName]
                        # print(trainFileName)

                        dcsObj = loaded_DCS[trainFileName]
                        if trainingStatus[sentenceObj.sent_id]:
                            continue
                        # trainer.Save('outputs/saved_trainer.p')
                        try:
                            trainer.Train(sentenceObj, dcsObj, bz2_input_folder, _debug)
                        except (IndexError, KeyError) as e:
                            print(e)
                            print('\x1b[31mFailed: {} \x1b[0m'.format(sentenceObj.sent_id))
                        except EOFError as e:
                            print('\x1b[31mBADFILE: {} \x1b[0m'.format(sentenceObj.sent_id))
                        ftime = time.time()
                        print('Time taken for file ' + str(trainFileName) + ' is ' + str(ftime - stime) + ' seconds')
                        with open('bron.csv', 'a') as fh:
                            rd = csv.writer(fh)
                            rd.writerow([str(ftime - stime)])
                    sys.stdout.flush()  # Flush IO buffer
                finishT = time.time()

                print('Avg. time taken by 1 file (1 iteration): {:.3f}'.format((finishT - startT)/(iterationPerBatch*filePerBatch)))
            except KeyboardInterrupt:
                print('Training paused')
                trainer.Save(p_name)
                yield None
    trainer.Save(p_name)

def test(loaded_SKT, loaded_DCS, n_testSet = -1, _testFiles = None, n_checkpt = 100):
    total_lemma = 0
    correct_lemma = 0

    total_word = 0
    total_output_nodes = 0
    correct_word = 0
    file_counter = 0
    if _testFiles is None:
        _testFiles = TestFiles
    if n_testSet != -1:
        _testFiles = _testFiles[0:n_testSet]

    recalls = []
    recalls_of_word = []
    precisions = []
    precisions_of_words = []
    for fn in _testFiles:
        if file_counter % n_checkpt == 0:
            print(file_counter, ' Checkpoint... ')
            if file_counter > 0:
                print('Avg. Micro Recall of Lemmas: {}'.format(np.mean(np.array(recalls))))
                print('Avg. Micro Recall of Words: {}'.format(np.mean(np.array(recalls_of_word))))
                print('Avg. Micro Precision of Lemmas: {}'.format(np.mean(np.array(precisions))))
                print('Avg. Micro Precision of Words: {}'.format(np.mean(np.array(precisions_of_words))))
                sys.stdout.flush()  # Flush IO buffer

        file_counter += 1

        testFileName = fn.replace('.ds.bz2', '.p2')
        sentenceObj = loaded_SKT[testFileName]
        # print(testFileName)
        # print(type(testFileName))
        dcsObj = loaded_DCS[testFileName]

        try:
            (word_match, lemma_match, n_dcsWords, n_output_nodes) = trainer.Test(sentenceObj, dcsObj)

            recalls.append(lemma_match/n_dcsWords)
            recalls_of_word.append(word_match/n_dcsWords)

            precisions.append(lemma_match/n_output_nodes)
            precisions_of_words.append(word_match/n_output_nodes)

            total_lemma += n_dcsWords
            total_word += n_dcsWords

            total_output_nodes += n_output_nodes

            correct_lemma += lemma_match
            correct_word += word_match
        except (IndexError, KeyError) as e:
            print('Failed!')

    print('Avg. Micro Recall of Lemmas: {}'.format(np.mean(np.array(recalls))))
    print('Avg. Micro Recall of Words: {}'.format(np.mean(np.array(recalls_of_word))))
    print('Avg. Micro Precision of Lemmas: {}'.format(np.mean(np.array(precisions))))
    print('Avg. Micro Precision of Words: {}'.format(np.mean(np.array(precisions_of_words))))

    return (recalls, recalls_of_word, precisions, precisions_of_words)

# NEW FUNCTION
def GetLoss(_mst_adj_graph, _mask_de_correct_edges, _WScalarMat):
    _WScalarMat = _WScalarMat.copy()
    _WScalarMat[_mst_adj_graph&(~_mask_de_correct_edges)] *= -1  # BAKA!!! Check before you try to fix this again
    _WScalarMat[~_mst_adj_graph] = 0
    return np.sum(_WScalarMat)

"""
################################################################################################
#############################  TRAINER CLASS DEFINITION  #######################################
################################################################################################
"""
class Trainer:
    def __init__(self, modelFile = None):
        if modelFile is None:
            singleLayer = True
            self._edge_vector_dim = 1500
            if singleLayer:
                self.hidden_layer_size = 1200
                keep_prob = 0.6
                self.neuralnet = NN(self._edge_vector_dim, self.hidden_layer_size, outer_relu=True, keep_prob=keep_prob)
            else:
                # DeepR Network
                self.hidden_layer_size = 800
                self.hidden_layer_size2 = 800
                self.neuralnet = NN_2(self._edge_vector_dim, self.hidden_layer_size,\
                    hidden_layer_2_size = self.hidden_layer_size2, outer_relu=True)
            self.history = defaultdict(lambda: list())
        else:
            loader = pickle.load(open(modelFile, 'rb'))  # was `filename` (undefined) in the original

            # Read dimensions before constructing the network
            self.hidden_layer_size = loader['n']
            self._edge_vector_dim = loader['d']

            self.neuralnet = NN(self._edge_vector_dim, self.hidden_layer_size, outer_relu=True)

            self.neuralnet.U = loader['U']
            self.neuralnet.W = loader['W']
            self.neuralnet.B1 = loader['B1']
            self.neuralnet.B2 = loader['B2']

            self.history = defaultdict(lambda: list())

        # SET LEARNING RATES
        if self.neuralnet.version == 'h1':
            self.neuralnet.etaW = 3e-5
            self.neuralnet.etaB1 = 1e-5

            self.neuralnet.etaU = 1e-5
            self.neuralnet.etaB2 = 1e-5
        elif self.neuralnet.version == 'h2':
            self.neuralnet.etaW1 = 3e-4
            self.neuralnet.etaB1 = 1e-4

            self.neuralnet.etaW2 = 1e-4
            self.neuralnet.etaB2 = 1e-4

            self.neuralnet.etaU = 1e-4
            self.neuralnet.etaB3 = 1e-4

    def Reset(self):
        self.neuralnet = NN(self._edge_vector_dim, self.hidden_layer_size)
        self.history = defaultdict(lambda: list())

    def Save(self, filename):
        print('Weights Saved: ', filename)
        if self.neuralnet.version == 'h1':
            pickle.dump({
                'U': self.neuralnet.U,
                'W': self.neuralnet.W,
                'n': self.neuralnet.n,
                'd': self.neuralnet.d,
                'B1': self.neuralnet.B1,
                'B2': self.neuralnet.B2,
                'keep_prob': self.neuralnet.keep_prob,
                'version': self.neuralnet.version
            }, open(filename, 'wb'))
            return
        elif self.neuralnet.version == 'h2':
            pickle.dump({
                'U': self.neuralnet.U,
                'B3': self.neuralnet.B3,
                'W2': self.neuralnet.W2,
                'B2': self.neuralnet.B2,
                'W1': self.neuralnet.W1,
                'B1': self.neuralnet.B1,
                'h1': self.neuralnet.h1,
                'h2': self.neuralnet.h2,
                'd': self.neuralnet.d,
                'version': self.neuralnet.version
            }, open(filename, 'wb'))
            return

    def Load(self, filename):
        loader = pickle.load(open(filename, 'rb'))
        if 'version' not in loader:  # means 1 hidden layer
            self.neuralnet = NN(self._edge_vector_dim, self.hidden_layer_size, outer_relu=True)
            self.neuralnet.U = loader['U']
            self.neuralnet.W = loader['W']
            self.neuralnet.B1 = loader['B1']
            self.neuralnet.B2 = loader['B2']
            self.neuralnet.hidden_layer_size = loader['n']
            self.neuralnet._edge_vector_dim = loader['d']
            if 'keep_prob' in loader:
                self.neuralnet.keep_prob = loader['keep_prob']
                self.neuralnet.dropout_prob = 1 - loader['keep_prob']
                print('Keep Prob = {}, Dropout = {}'.format(self.neuralnet.keep_prob, self.neuralnet.dropout_prob))
        else:
            if loader['version'] == 'h1':
                self.neuralnet = NN(self._edge_vector_dim, self.hidden_layer_size, outer_relu=True)
                self.neuralnet.U = loader['U']
                self.neuralnet.W = loader['W']
                self.neuralnet.B1 = loader['B1']
                self.neuralnet.B2 = loader['B2']
                self.neuralnet.hidden_layer_size = loader['n']
                self.neuralnet._edge_vector_dim = loader['d']
                if 'keep_prob' in loader:
                    self.neuralnet.keep_prob = loader['keep_prob']
                    self.neuralnet.dropout_prob = 1 - loader['keep_prob']
                    print('Keep Prob = {}, Dropout = {}'.format(self.neuralnet.keep_prob, self.neuralnet.dropout_prob))
            elif loader['version'] == 'h2':
                self.neuralnet = NN_2(self._edge_vector_dim, self.hidden_layer_size, outer_relu=True)

                self.neuralnet.U = loader['U']
                self.neuralnet.B3 = loader['B3']

                self.neuralnet.W2 = loader['W2']
                self.neuralnet.B2 = loader['B2']

                self.neuralnet.W1 = loader['W1']
                self.neuralnet.B1 = loader['B1']

                self.neuralnet.h1 = loader['h1']
                self.neuralnet.h2 = loader['h2']
                self.neuralnet.d = loader['d']

    def CalculateLoss_n_Grads(self, WScalarMat, min_st_adj_worst, max_st_adj_gold, loss_type = 0, min_marginalized_energy = None):
        doBpp = True

        # Calculate the energies
        etg = np.sum(WScalarMat[max_st_adj_gold])
        etq = np.sum(WScalarMat[min_st_adj_worst])

        if loss_type == 0:
            # Variable hinge loss - CHECKED
            L = etg - min_marginalized_energy
            if L > 0:
                dLdOut = np.zeros_like(WScalarMat)
                dLdOut[max_st_adj_gold&(~min_st_adj_worst)] = 1
                dLdOut[(~max_st_adj_gold)&min_st_adj_worst] = -1
            else:
                doBpp = False
                return (L, None, doBpp)
        elif loss_type == 1:
            # Log loss
            a = etg - etq
            b = np.exp(a)
            L = np.log(1 + b)

            dLdOut = np.zeros_like(WScalarMat)
            dLdOut[max_st_adj_gold&(~min_st_adj_worst)] = 1
            dLdOut[(~max_st_adj_gold)&min_st_adj_worst] = -1

            dLdOut *= (b/(1 + b))
        elif loss_type == 2:
            # Square-exponential loss
            gamma = 1
            b = np.exp(-etq)

            L = etg**2 + gamma*b

            dLdOut = np.zeros_like(WScalarMat)
            dLdOut[max_st_adj_gold&(~min_st_adj_worst)] = 2*etg
            dLdOut[(~max_st_adj_gold)&min_st_adj_worst] = -gamma*b
        return (L, dLdOut, doBpp)
    def Test(self, sentenceObj, dcsObj, dsbz2_name, _dump = False, _outFile = None):
        if _dump:
            if _outFile is None:
                raise Exception('_outFile must be provided when _dump is True')
        if self.neuralnet.version == 'h1':
            self.neuralnet.ForTesting()

        # with open('gt_cngs.csv','a') as fh:
        #     for i in dcsObj.cng:
        #         for j in i:
        #             print(str(sentenceObj.sent_id)+":"+str(j))
        #             wr = csv.writer(fh)
        #             wr.writerow([sentenceObj.sent_id,j])
        # return

        neuralnet = self.neuralnet
        minScore = np.inf
        minMst = None

        # dsbz2_name = sentenceObj.sent_id + '.ds.bz2'
        (nodelist_correct, conflicts_Dict_correct, featVMat_correct, nodelist_to_correct_mapping,\
            nodelist, conflicts_Dict, featVMat) = open_dsbz2(dsbz2_name)

        # if len(nodelist) > 50:
        #     return None

        if not self.neuralnet.outer_relu:
            (WScalarMat, SigmoidGateOutput) = Get_W_Scalar_Matrix_from_FeatVect_Matrix(featVMat, nodelist, conflicts_Dict, neuralnet)
        else:
            WScalarMat = Get_W_Scalar_Matrix_from_FeatVect_Matrix(featVMat, nodelist, conflicts_Dict, neuralnet)

        # Get all MSTs, keep the minimum-score one
        for source in range(len(nodelist)):
            (mst_nodes, mst_adj_graph, _) = clique(nodelist, WScalarMat, conflicts_Dict, source)
            score = GetMSTWeight(mst_adj_graph, WScalarMat)
            if score < minScore:
                minScore = score
                minMst = mst_nodes
        dcsLemmas = [[rom_slp(l) for l in arr] for arr in dcsObj.lemmas]
        word_match = 0
        lemma_match = 0
        n_output_nodes = 0

        if _dump:
            predicted_lemmas = [sentenceObj.sent_id]
            predicted_cngs = [sentenceObj.sent_id]
            predicted_chunk_id = [sentenceObj.sent_id]
            predicted_pos = [sentenceObj.sent_id]
            predicted_id = [sentenceObj.sent_id]

        for chunk_id, wdSplit in minMst.items():
            for wd in wdSplit:
                if _dump:
                    predicted_lemmas.append(wd.lemma)
                    predicted_cngs.append(wd.cng)
                    predicted_chunk_id.append(wd.chunk_id)
                    predicted_pos.append(wd.pos)
                    predicted_id.append(wd.id)

                n_output_nodes += 1
                # Match lemma
                search_result = [i for i, j in enumerate(dcsLemmas[chunk_id]) if j == wd.lemma]
                if len(search_result) > 0:
                    lemma_match += 1
                    # Match CNG
                    for i in search_result:
                        if dcsObj.cng[chunk_id][i] == str(wd.cng):
                            word_match += 1
                            break
        dcsLemmas = [l for arr in dcsObj.lemmas for l in arr]

        if _dump:
            with open(_outFile, 'a') as fh:
                dcsv = csv.writer(fh)
                dcsv.writerow(predicted_lemmas)
                dcsv.writerow(predicted_cngs)
                dcsv.writerow(predicted_chunk_id)
                dcsv.writerow(predicted_pos)
                dcsv.writerow(predicted_id)
                dcsv.writerow([sentenceObj.sent_id, word_match, lemma_match, len(dcsLemmas), n_output_nodes])

        # print('\nFull Match: {}, Partial Match: {}, OutOf {}, NodeCount: {}, '.\
        #     format(word_match, lemma_match, len(dcsLemmas), len(nodelist)))
        return (word_match, lemma_match, len(dcsLemmas), n_output_nodes)

    def Train(self, sentenceObj, dcsObj, bz2_input_folder, _debug = True):
        self.neuralnet.ForTraining()
        self.neuralnet.new_dropout()  # renew dropout setting
        # Hyperparameter for hinge loss: m
        m_hinge_param = 14

        dsbz2_name = sentenceObj.sent_id + '.ds.bz2'
        (nodelist_correct, conflicts_Dict_correct, featVMat_correct, nodelist_to_correct_mapping,\
            nodelist, conflicts_Dict, featVMat) = open_dsbz2(bz2_input_folder + dsbz2_name)

        sub = 0
        for s in conflicts_Dict.keys():
            sub = sub + len(conflicts_Dict[s])

        with open('bron.csv', 'a') as fh:
            rd = csv.writer(fh)
            rd.writerow([str(dsbz2_name), str(len(nodelist)), str(sub)])

        print(dsbz2_name)
        print("NodeLength: " + str(len(nodelist)))

        '''
        if(len(nodelist)>40):
            with open('bron.csv','a') as fh:
                rd = csv.writer(fh)
                rd.writerow([str(dsbz2_name),'0'])
            return
        '''
        if len(nodelist) > 100:
            print("Nodelength : " + str(len(nodelist)))
            # Train for large graphs separately

        """ FORM MAXIMUM(ENERGY) SPANNING TREE OF THE GOLDEN GRAPH : WORST GOLD STRUCTURE """
        WScalarMat_correct = Get_W_Scalar_Matrix_from_FeatVect_Matrix(featVMat_correct, nodelist_correct,\
            conflicts_Dict_correct, self.neuralnet)
        source = 0
        """ Find the max spanning tree : negative Weight matrix passed """
        # (min_st_gold_ndict, min_st_adj_gold_small, _) =\
        #     clique(nodelist_correct, -WScalarMat_correct, conflicts_Dict_correct, source)

        R = set()
        P = set()
        X = set()
        for i in range(len(nodelist_correct)):
            P.add(i)

        bcliq = bron(R, P, X, nodelist_correct, conflicts_Dict_correct, 1)
        with open('bron.csv', 'a') as fh:
            rd = csv.writer(fh)
            rd.writerow([str(dsbz2_name), str(len(bcliq))])
        min_st_adj_gold_small = np.zeros(WScalarMat_correct.shape, dtype=bool)  # was np.ndarray(...)*False
        for i in bcliq[0]:
            for j in bcliq[0]:
                if i == j:
                    continue
                min_st_adj_gold_small[i][j] = True

        energy_gold_max_ST = np.sum(WScalarMat_correct[min_st_adj_gold_small])
        """ Convert correct spanning tree graph adj matrix to full matrix dimensions """
        """ Create full-size adjacency matrix for correct_mst_small """

        nodelen = len(nodelist)
        min_st_adj_gold = np.zeros((nodelen, nodelen), dtype=bool)  # T_STAR
        for i in range(min_st_adj_gold_small.shape[0]):
            for j in range(min_st_adj_gold_small.shape[1]):
                min_st_adj_gold[nodelist_to_correct_mapping[i], nodelist_to_correct_mapping[j]] =\
                    min_st_adj_gold_small[i, j]
        """ Delta(Margin) Function : MASK FOR WHICH NODES IN NODELIST BELONG TO DCS """
        gold_nodes_mask = np.array([False]*len(nodelist))
        gold_nodes_mask[list(nodelist_to_correct_mapping.values())] = True
        margin_f = lambda nodes_mask: np.sum(nodes_mask&(~gold_nodes_mask))**2

        """ FOR ALL POSSIBLE MST FROM THE COMPLETE GRAPH """
        WScalarMat = Get_W_Scalar_Matrix_from_FeatVect_Matrix(featVMat, nodelist, conflicts_Dict, self.neuralnet)
        """ For each node - Find MST with that source """
        min_STx = None  # Min energy spanning tree with worst margin w.r.t. the gold structure
        min_marginalized_energy = np.inf
        # Generate random set of nodes from which MSTs are to be considered
        n_nodes = len(nodelist)
        selection_prob = 0.4
        select_flag = np.random.rand(n_nodes) < selection_prob
        # Fix if all zeros
        if np.sum(select_flag) == 0:
            select_flag[np.random.randint(n_nodes)] = 1

        best_node_diff = np.Inf
        best_energy = np.inf

        R = set()
        P = set()
        X = set()
        for i in range(len(nodelist)):
            P.add(i)
        bcliq = bron(R, P, X, nodelist, conflicts_Dict, 1)
        print('Enumerated number of cliques: ' + str(len(bcliq)))
        for cliq in bcliq:  # renamed from `clique` to avoid shadowing the clique() function
            mst_adj_graph = np.zeros(WScalarMat.shape, dtype=bool)
            mst_nodes_bool = np.array([False]*len(nodelist))
            for nd in cliq:
                mst_nodes_bool[nd] = True
                for nd2 in cliq:
                    if nd == nd2:
                        continue
                    mst_adj_graph[nd][nd2] = True
            en_st = np.sum(WScalarMat[mst_adj_graph])
            # Pick up the node_diff with lowest energy
            delta_st = margin_f(mst_nodes_bool)

            if _debug:
                if best_energy > en_st:
                    best_node_diff = delta_st
                    best_energy = en_st
            # Minimum marginalized energy calculation
            marginalized_en = en_st - delta_st
            # Minimum marginalized spanning tree : Randomization applied
            # if marginalized_en < min_marginalized_energy and select_flag[source]:
            if marginalized_en < min_marginalized_energy:
                min_marginalized_energy = marginalized_en
                min_STx = mst_adj_graph
            # Energy diff should all be negative
            if _debug:
                print('Source: [{}], Node_Diff:{}, Max_Gold_En: {:.3f}, Energy: {:.3f}'.\
                    format(source, np.sum((~gold_nodes_mask)&mst_nodes_bool), energy_gold_max_ST, np.sum(WScalarMat[mst_adj_graph])))

        if _debug:
            print('Best Node diff: {} with EN: {}'.format(np.sqrt(best_node_diff), best_energy))
        """ Gradient Descent """
        # LOSS TYPES -> hinge(0), log-loss(1), square-exponential(2)
        Total_Loss, dLdOut, doBpp = self.CalculateLoss_n_Grads(WScalarMat, min_STx, min_st_adj_gold,\
            loss_type = 0, min_marginalized_energy = min_marginalized_energy)
        if doBpp:
            if _debug:
                print('{}. '.format(sentenceObj.sent_id), end = '')
            self.neuralnet.Back_Prop(dLdOut, len(nodelist), featVMat, _debug)
        else:
            trainingStatus[sentenceObj.sent_id] = True
        if _debug:
            print("\nFileKey: %s, Loss: %6.3f" % (sentenceObj.sent_id, Total_Loss))

| 733 |
+
TrainFiles = None
|
| 734 |
+
trainer = None
|
| 735 |
+
p_name = ''
|
| 736 |
+
odir = ''
|
| 737 |
+
def InitModule():
|
| 738 |
+
global trainer
|
| 739 |
+
trainer = Trainer()
|
| 740 |
+
|
| 741 |
+
def register_nnet(nnet, bz2_input_folder):
|
| 742 |
+
if not os.path.isdir(odir):
|
| 743 |
+
os.mkdir(odir)
|
| 744 |
+
if not os.path.isfile('outputs/nnet_LOGS.csv'):
|
| 745 |
+
with open('outputs/nnet_LOGS.csv', 'a') as fh:
|
| 746 |
+
csv_r = csv.writer(fh)
|
| 747 |
+
csv_r.writerow(['odir', 'p_name', 'hidden_layer_size', '_edge_vector_dim'])
|
| 748 |
+
with open('outputs/nnet_LOGS.csv', 'a') as fh:
|
| 749 |
+
csv_r = csv.writer(fh)
|
| 750 |
+
if nnet.version == 'h1':
|
| 751 |
+
csv_r.writerow([odir, p_name, nnet.n, nnet.d, bz2_input_folder])
|
| 752 |
+
elif nnet.version == 'h2':
|
| 753 |
+
csv_r.writerow([odir, p_name, nnet.h1, nnet.h2, nnet.d, bz2_input_folder])
|
| 754 |
+
|
| 755 |
+
"""
|
| 756 |
+
################################################################################################
|
| 757 |
+
################################################################################################
|
| 758 |
+
################################################################################################
|
| 759 |
+
"""
|
| 760 |
+
def main():
|
| 761 |
+
global TrainFiles, p_name, odir
|
| 762 |
+
|
| 763 |
+
"""
|
| 764 |
+
################################################################################################
|
| 765 |
+
############################## GET A FILENAME TO SAVE WEIGHTS ################################
|
| 766 |
+
################################################################################################
|
| 767 |
+
"""
|
| 768 |
+
st = str(int((time.time() * 1e6) % 1e13))
|
| 769 |
+
log_name = 'logs/train_nnet_t{}.out'.format(st)
|
| 770 |
+
odir = 'outputs/train_t{}'.format(st)
|
| 771 |
+
p_name = 'outputs/train_t{}/nnet.p'.format(st)
|
| 772 |
+
print('nEURAL nET wILL bE sAVED hERE: ', p_name)
|
| 773 |
+
|
| 774 |
+
# Create Training File List
|
| 775 |
+
excluded_files = []
|
| 776 |
+
with open('inputs/Baseline4_advSample.csv', 'r') as f_handle:
|
| 777 |
+
opener = csv.reader(f_handle)
|
| 778 |
+
for line in opener:
|
| 779 |
+
excluded_files.append(line[1].replace('.p', '.ds.bz2'))
|
| 780 |
+
|
| 781 |
+
# Load Simultaneous files
|
| 782 |
+
print('Loading Large Files')
|
| 783 |
+
loaded_SKT = pickle.load(open('../Simultaneous_CompatSKT_10K.p', 'rb'), encoding=u'utf-8')
|
| 784 |
+
loaded_DCS = pickle.load(open('../Simultaneous_DCS_10K.p', 'rb'), encoding=u'utf-8')
|
| 785 |
+
|
| 786 |
+
# loaded_SKT = pickle.load(open('../Simultaneous_CompatSKT.p', 'rb'), encoding=u'utf-8')
|
| 787 |
+
# loaded_DCS = pickle.load(open('../Simultaneous_DCS.p', 'rb'), encoding=u'utf-8')
|
| 788 |
+
|
| 789 |
+
bz2_input_folder = '../NewData/skt_dcs_DS.bz2_4K_bigram_mir_10K/' #bm2
|
| 790 |
+
# bz2_input_folder = '../NewData/skt_dcs_DS.bz2_1L_bigram_mir_10K/' #bm3
|
| 791 |
+
# bz2_input_folder = '../NewData/skt_dcs_DS.bz2_4K_bigram_rfe_10K/' #br2
|
| 792 |
+
# bz2_input_folder = '../NewData/skt_dcs_DS.bz2_1L_bigram_rfe_10K/' #br3
|
| 793 |
+
# bz2_input_folder = '../NewData/skt_dcs_DS.bz2_4K_pmi_mir_10K/' #pm2
|
| 794 |
+
# bz2_input_folder = '../NewData/skt_dcs_DS.bz2_1L_pmi_mir_10K2/' #pm3
|
| 795 |
+
# bz2_input_folder = '../NewData/skt_dcs_DS.bz2_4K_pmi_rfe_10K/' #pr2
|
| 796 |
+
# bz2_input_folder = '../NewData/skt_dcs_DS.bz2_1L_pmi_rfe_10K/' #pr3
# bz2_input_folder = '/home/rs/15CS91R05/vishnu/Data/skt_dcs_DS.bz2_compat_10k_check_again/'
|
| 797 |
+
all_files = []
|
| 798 |
+
skipped = 0
|
| 799 |
+
for f in os.listdir(bz2_input_folder):
|
| 800 |
+
if '.ds.bz2' in f:
|
| 801 |
+
if f in excluded_files:
|
| 802 |
+
skipped += 1
|
| 803 |
+
continue
|
| 804 |
+
if f.replace('.ds.bz2', '.p2') not in loaded_DCS:
|
| 805 |
+
print('Could not find', f)
|
| 806 |
+
continue
|
| 807 |
+
all_files.append(f)
|
| 808 |
+
|
| 809 |
+
print(skipped, 'files will not be used for training')
|
| 810 |
+
print('Size of training set:', len(all_files))
|
| 811 |
+
|
| 812 |
+
|
| 813 |
+
all_files = ['32517.p2']  # debug override: restrict training to a single file
|
| 814 |
+
TrainFiles = all_files
|
| 815 |
+
with open('bron.csv','w') as fh:
|
| 816 |
+
rd = csv.writer(fh)
|
| 817 |
+
rd.writerow(['FileName','Nodelength & NCliques'])
|
| 818 |
+
|
| 819 |
+
InitModule()
|
| 820 |
+
trainingStatus = defaultdict(lambda: bool(False))
|
| 821 |
+
# train = train_generator(loaded_SKT, loaded_DCS, bz2_input_folder, n_trainset = -1, filePerBatch = 10, iterationPerBatch = 5, _debug=False, superEpochs = 5)
|
| 822 |
+
train = train_generator(loaded_SKT, loaded_DCS, bz2_input_folder, n_trainset = -1, filePerBatch = 20, iterationPerBatch = 3, _debug=False, superEpochs = 2)
|
| 823 |
+
|
| 824 |
+
# Complete Training
|
| 825 |
+
train.__next__()
|
| 826 |
+
|
| 827 |
+
print('Training Complete')
|
| 828 |
+
|
| 829 |
+
if __name__ == '__main__':
|
| 830 |
+
main()
|
dir/Train_clique.py
ADDED
|
@@ -0,0 +1,769 @@
|
| 1 |
+
"""
|
| 2 |
+
IMPORTS
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
## bUILT-iN pACKAGES
|
| 6 |
+
import sys, os, time, bz2, zlib, pickle, math, json, csv
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
import numpy as np
|
| 9 |
+
np.set_printoptions(suppress=True)
|
| 10 |
+
from IPython.display import display
|
| 11 |
+
|
| 12 |
+
## lAST sUMMER
|
| 13 |
+
from romtoslp import *
|
| 14 |
+
from sentences import *
|
| 15 |
+
from DCS import *
|
| 16 |
+
import MatDB
|
| 17 |
+
from heap_n_clique import *
|
| 18 |
+
from ECL_MST import *
|
| 19 |
+
|
| 20 |
+
import word_definite as WD
|
| 21 |
+
# from heap_n_PrimMST import *
|
| 22 |
+
from nnet import *
|
| 23 |
+
|
| 24 |
+
"""
|
| 25 |
+
################################################################################################
|
| 26 |
+
########################### LOAD SENTENCE AND DCS OBJECT FILES ###############################
|
| 27 |
+
################################################################################################
|
| 28 |
+
"""
|
| 29 |
+
# loaded_SKT = pickle.load(open('../Simultaneous_CompatSKT_10K.p', 'rb'))
|
| 30 |
+
# loaded_DCS = pickle.load(open('../Simultaneous_DCS_10K.p', 'rb'))
|
| 31 |
+
|
| 32 |
+
"""
|
| 33 |
+
################################################################################################
|
| 34 |
+
########################### OPENS AND EXTRACTS DCS AND SKT DATA STRUCTURES ###################
|
| 35 |
+
###################################### FROM BZ2 FILES ########################################
|
| 36 |
+
"""
|
| 37 |
+
def open_dsbz2(filename):
|
| 38 |
+
with bz2.BZ2File(filename, 'r') as f:
|
| 39 |
+
loader = pickle.load(f)
|
| 40 |
+
|
| 41 |
+
conflicts_Dict_correct = loader['conflicts_Dict_correct']
|
| 42 |
+
nodelist_to_correct_mapping = loader['nodelist_to_correct_mapping']
|
| 43 |
+
nodelist_correct = loader['nodelist_correct']
|
| 44 |
+
featVMat_correct = loader['featVMat_correct']
|
| 45 |
+
featVMat = loader['featVMat']
|
| 46 |
+
conflicts_Dict = loader['conflicts_Dict']
|
| 47 |
+
nodelist = loader['nodelist']
|
| 48 |
+
|
| 49 |
+
return (nodelist_correct, conflicts_Dict_correct, featVMat_correct, nodelist_to_correct_mapping,\
|
| 50 |
+
nodelist, conflicts_Dict, featVMat)
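As a minimal sketch, the bz2+pickle serialization pattern that `open_dsbz2` expects can be reproduced as below; the dict keys mirror the loader keys, while the toy values and the temp path are placeholders:

```python
import bz2, pickle, os, tempfile

# Hypothetical payload with the same keys open_dsbz2 reads (toy values only).
payload = {
    'nodelist': ['node_a', 'node_b'],
    'conflicts_Dict': {0: [1]},
    'featVMat': [[None, [0.1]], [[0.2], None]],
}

path = os.path.join(tempfile.mkdtemp(), 'toy.ds.bz2')

# Write: pickle the dict into a bz2-compressed stream.
with bz2.BZ2File(path, 'w') as f:
    pickle.dump(payload, f)

# Read back: same pattern as open_dsbz2's loader.
with bz2.BZ2File(path, 'r') as f:
    loaded = pickle.load(f)

print(loaded['nodelist'])  # ['node_a', 'node_b']
```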
|
| 51 |
+
|
| 52 |
+
"""
|
| 53 |
+
################################################################################################
|
| 54 |
+
###################### CREATE SEVERAL DATA STRUCTURES FROM SENTENCE/DCS ######################
|
| 55 |
+
########################### NODELIST, ADJACENCY LIST, GRAPH, HEAP #############################
|
| 56 |
+
################################################################################################
|
| 57 |
+
"""
|
| 58 |
+
def GetTrainingKit(sentenceObj, dcsObj):
|
| 59 |
+
nodelist = GetNodes(sentenceObj)
|
| 60 |
+
|
| 61 |
+
# Nodelist with only the correct_nodes
|
| 62 |
+
nodelist2 = GetNodes(sentenceObj)
|
| 63 |
+
nodelist2_to_correct_mapping = {}
|
| 64 |
+
nodelist_correct = []
|
| 65 |
+
search_key = 0
|
| 66 |
+
first_key = 0
|
| 67 |
+
for chunk_id in range(len(dcsObj.lemmas)):
|
| 68 |
+
while nodelist2[first_key].chunk_id != chunk_id:
|
| 69 |
+
first_key += 1
|
| 70 |
+
for j in range(len(dcsObj.lemmas[chunk_id])):
|
| 71 |
+
search_key = first_key
|
| 72 |
+
while (nodelist2[search_key].lemma != rom_slp(dcsObj.lemmas[chunk_id][j])) or (nodelist2[search_key].cng != dcsObj.cng[chunk_id][j]):
|
| 73 |
+
search_key += 1
|
| 74 |
+
if search_key >= len(nodelist2) or nodelist2[search_key].chunk_id > chunk_id:
|
| 75 |
+
break
|
| 76 |
+
# print((rom_slp(dcsObj.lemmas[chunk_id][j]), dcsObj.cng[chunk_id][j]))
|
| 77 |
+
# print(nodelist[search_key])
|
| 78 |
+
nodelist2_to_correct_mapping[len(nodelist_correct)] = search_key
|
| 79 |
+
nodelist_correct.append(nodelist2[search_key])
|
| 80 |
+
return (nodelist, nodelist_correct, nodelist2_to_correct_mapping)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def GetGraph(nodelist, neuralnet):
|
| 84 |
+
if not neuralnet.outer_relu:
|
| 85 |
+
conflicts_Dict = Get_Conflicts(nodelist)
|
| 86 |
+
|
| 87 |
+
featVMat = Get_Feat_Vec_Matrix(nodelist, conflicts_Dict)
|
| 88 |
+
|
| 89 |
+
(WScalarMat, SigmoidGateOutput) = Get_W_Scalar_Matrix_from_FeatVect_Matrix(featVMat, nodelist, conflicts_Dict, neuralnet)
|
| 90 |
+
return (conflicts_Dict, featVMat, WScalarMat, SigmoidGateOutput)
|
| 91 |
+
else:
|
| 92 |
+
conflicts_Dict = Get_Conflicts(nodelist)
|
| 93 |
+
|
| 94 |
+
featVMat = Get_Feat_Vec_Matrix(nodelist, conflicts_Dict)
|
| 95 |
+
|
| 96 |
+
WScalarMat = Get_W_Scalar_Matrix_from_FeatVect_Matrix(featVMat, nodelist, conflicts_Dict, neuralnet)
|
| 97 |
+
return (conflicts_Dict, featVMat, WScalarMat)
|
| 98 |
+
|
| 99 |
+
# NEW LOSS FUNCTION
|
| 100 |
+
def GetLoss(_mst_adj_graph, _mask_de_correct_edges, _negLogLikelies):
|
| 101 |
+
_negLogLikelies = _negLogLikelies.copy()
|
| 102 |
+
_negLogLikelies[~_mst_adj_graph] = 0
|
| 103 |
+
_negLogLikelies[~_mask_de_correct_edges] *= -1 # BAKA!!! Check before you try to fix this again
|
| 104 |
+
return np.sum(_negLogLikelies)
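The masking logic in `GetLoss` can be checked on a tiny example; the arrays below are illustrative, not real feature data:

```python
import numpy as np

# Toy check of the GetLoss masking: only edges present in the spanning
# structure contribute, and edges outside the gold set flip sign.
negLog = np.array([[1.0, 2.0], [3.0, 4.0]])
in_mst = np.array([[True, True], [False, True]])       # _mst_adj_graph
is_correct = np.array([[True, False], [True, True]])   # _mask_de_correct_edges

contrib = negLog.copy()
contrib[~in_mst] = 0        # drop edges not in the spanning structure
contrib[~is_correct] *= -1  # wrong edges contribute with opposite sign
loss = np.sum(contrib)
print(loss)  # 1.0 - 2.0 + 0 + 4.0 = 3.0
```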
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
"""
|
| 109 |
+
################################################################################################
|
| 110 |
+
############################## MAIN FUNCTION #################################################
|
| 111 |
+
################################################################################################
|
| 112 |
+
"""
|
| 113 |
+
trainingStatus = defaultdict(lambda: bool(False))
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
"""
|
| 117 |
+
################################################################################################
|
| 118 |
+
############################## TRAIN FUNCTION ################################################
|
| 119 |
+
################################################################################################
|
| 120 |
+
"""
|
| 121 |
+
|
| 122 |
+
def train_generator(loaded_SKT, loaded_DCS, bz2_input_folder, n_trainset = -1, iterationPerBatch = 10, filePerBatch = 20, _debug = True, superEpochs = 1):
|
| 123 |
+
# Train
|
| 124 |
+
if n_trainset == -1:
|
| 125 |
+
n_trainset = len(TrainFiles)
|
| 126 |
+
totalBatchToTrain = math.ceil(n_trainset/filePerBatch)
|
| 127 |
+
else:
|
| 128 |
+
totalBatchToTrain = math.ceil(n_trainset/filePerBatch)
|
| 129 |
+
|
| 130 |
+
register_nnet(trainer.neuralnet, bz2_input_folder)
|
| 131 |
+
print('Epochs: ' + str(superEpochs))
|
| 132 |
+
print('Batches per epoch: ' + str(totalBatchToTrain))
|
| 133 |
+
for _epoch in range(superEpochs):
|
| 134 |
+
for iterout in range(totalBatchToTrain):
|
| 135 |
+
# Add timer
|
| 136 |
+
startT = time.time()
|
| 137 |
+
|
| 138 |
+
# Change current batch
|
| 139 |
+
if(iterout % 50 == 0):
|
| 140 |
+
trainer.Save(p_name.replace('.p', '_e{}_i{}.p'.format(_epoch, iterout)))
|
| 141 |
+
else:
|
| 142 |
+
trainer.Save(p_name)
|
| 143 |
+
print('Epoch: {}, Batch: {}'.format(_epoch, iterout))
|
| 144 |
+
files_for_batch = TrainFiles[iterout*filePerBatch:(iterout + 1)*filePerBatch]
|
| 145 |
+
# print(files_for_batch)
|
| 146 |
+
# trainer.Load('outputs/neuralnet_trained.p')
|
| 147 |
+
try:
|
| 148 |
+
# Run few times on same set of files
|
| 149 |
+
for iterin in range(iterationPerBatch):
|
| 150 |
+
print('ITERATION IN', iterin)
|
| 151 |
+
for fn in files_for_batch:
|
| 152 |
+
stime = time.time()
|
| 153 |
+
trainFileName = fn.replace('.ds.bz2', '.p2')
|
| 154 |
+
sentenceObj = loaded_SKT[trainFileName]
|
| 155 |
+
# print(trainFileName)
|
| 156 |
+
|
| 157 |
+
dcsObj = loaded_DCS[trainFileName]
|
| 158 |
+
if trainingStatus[sentenceObj.sent_id]:
|
| 159 |
+
continue
|
| 160 |
+
# trainer.Save('outputs/saved_trainer.p')
|
| 161 |
+
try:
|
| 162 |
+
trainer.Train(sentenceObj, dcsObj, bz2_input_folder, _debug)
|
| 163 |
+
except (IndexError, KeyError) as e:
|
| 164 |
+
print(e)
|
| 165 |
+
print('\x1b[31mFailed: {} \x1b[0m'.format(sentenceObj.sent_id))
|
| 166 |
+
except EOFError as e:
|
| 167 |
+
print('\x1b[31mBADFILE: {} \x1b[0m'.format(sentenceObj.sent_id))
|
| 168 |
+
ftime = time.time()
|
| 169 |
+
print('Time taken for file ' + str(trainFileName) + ' is ' + str(ftime - stime) + ' seconds')
|
| 170 |
+
with open('cliq.csv','a') as fh:
|
| 171 |
+
rd = csv.writer(fh)
|
| 172 |
+
rd.writerow([str(ftime-stime)])
|
| 173 |
+
sys.stdout.flush() # Flush IO buffer
|
| 174 |
+
finishT = time.time()
|
| 175 |
+
print('Avg. time taken by 1 file(1 iteration): {:.3f}'.format((finishT - startT)/(iterationPerBatch*filePerBatch)))
|
| 176 |
+
except KeyboardInterrupt:
|
| 177 |
+
print('Training paused')
|
| 178 |
+
trainer.Save(p_name)
|
| 179 |
+
yield None
|
| 180 |
+
trainer.Save(p_name)
|
| 181 |
+
|
| 182 |
+
def test(loaded_SKT, loaded_DCS, n_testSet = -1, _testFiles = None, n_checkpt = 100):
|
| 183 |
+
total_lemma = 0
|
| 184 |
+
correct_lemma = 0
|
| 185 |
+
|
| 186 |
+
total_word = 0
|
| 187 |
+
total_output_nodes = 0
|
| 188 |
+
correct_word = 0
|
| 189 |
+
file_counter = 0
|
| 190 |
+
if _testFiles is None:
|
| 191 |
+
if n_testSet == -1:
|
| 192 |
+
_testFiles = TestFiles
|
| 193 |
+
else:
|
| 194 |
+
_testFiles = TestFiles[0:n_testSet]
|
| 195 |
+
else:
|
| 196 |
+
if n_testSet == -1:
|
| 197 |
+
_testFiles = _testFiles
|
| 198 |
+
else:
|
| 199 |
+
_testFiles = _testFiles[0:n_testSet]
|
| 200 |
+
|
| 201 |
+
recalls = []
|
| 202 |
+
recalls_of_word = []
|
| 203 |
+
precisions = []
|
| 204 |
+
precisions_of_words = []
|
| 205 |
+
for fn in _testFiles:
|
| 206 |
+
if file_counter % n_checkpt == 0:
|
| 207 |
+
print(file_counter,' Checkpoint... ')
|
| 208 |
+
if file_counter > 0:
|
| 209 |
+
print('Avg. Micro Recall of Lemmas: {}'.format(np.mean(np.array(recalls))))
|
| 210 |
+
print('Avg. Micro Recall of Words: {}'.format(np.mean(np.array(recalls_of_word))))
|
| 211 |
+
print('Avg. Micro Precision of Lemmas: {}'.format(np.mean(np.array(precisions))))
|
| 212 |
+
print('Avg. Micro Precision of Words: {}'.format(np.mean(np.array(precisions_of_words))))
|
| 213 |
+
sys.stdout.flush() # Flush IO buffer
|
| 214 |
+
|
| 215 |
+
file_counter += 1
|
| 216 |
+
|
| 217 |
+
testFileName = fn.replace('.ds.bz2', '.p2')
|
| 218 |
+
sentenceObj = loaded_SKT[testFileName]
|
| 219 |
+
# print(testFileName)
|
| 220 |
+
# print(type(testFileName))
|
| 221 |
+
dcsObj = loaded_DCS[testFileName]
|
| 222 |
+
|
| 223 |
+
try:
|
| 224 |
+
(word_match, lemma_match, n_dcsWords, n_output_nodes) = trainer.Test(sentenceObj, dcsObj)
|
| 225 |
+
|
| 226 |
+
recalls.append(lemma_match/n_dcsWords)
|
| 227 |
+
recalls_of_word.append(word_match/n_dcsWords)
|
| 228 |
+
|
| 229 |
+
precisions.append(lemma_match/n_output_nodes)
|
| 230 |
+
precisions_of_words.append(word_match/n_output_nodes)
|
| 231 |
+
|
| 232 |
+
total_lemma += n_dcsWords
|
| 233 |
+
total_word += n_dcsWords
|
| 234 |
+
|
| 235 |
+
total_output_nodes += n_output_nodes
|
| 236 |
+
|
| 237 |
+
correct_lemma += lemma_match
|
| 238 |
+
correct_word += word_match
|
| 239 |
+
except (IndexError, KeyError) as e:
|
| 240 |
+
print('Failed!')
|
| 241 |
+
|
| 242 |
+
print('Avg. Micro Recall of Lemmas: {}'.format(np.mean(np.array(recalls))))
|
| 243 |
+
print('Avg. Micro Recall of Words: {}'.format(np.mean(np.array(recalls_of_word))))
|
| 244 |
+
print('Avg. Micro Precision of Lemmas: {}'.format(np.mean(np.array(precisions))))
|
| 245 |
+
print('Avg. Micro Precision of Words: {}'.format(np.mean(np.array(precisions_of_words))))
|
| 246 |
+
|
| 247 |
+
return (recalls, recalls_of_word, precisions, precisions_of_words)
|
| 248 |
+
|
| 249 |
+
# NEW FUNCTION (overrides the earlier GetLoss defined above in this module)
|
| 250 |
+
def GetLoss(_mst_adj_graph, _mask_de_correct_edges, _WScalarMat):
|
| 251 |
+
_WScalarMat = _WScalarMat.copy()
|
| 252 |
+
_WScalarMat[_mst_adj_graph&(~_mask_de_correct_edges)] *= -1 # BAKA!!! Check before you try to fix this again
|
| 253 |
+
_WScalarMat[~_mst_adj_graph] = 0
|
| 254 |
+
return np.sum(_WScalarMat)
|
| 255 |
+
|
| 256 |
+
"""
|
| 257 |
+
################################################################################################
|
| 258 |
+
############################# TRAINER CLASS DEFINITION ######################################
|
| 259 |
+
################################################################################################
|
| 260 |
+
"""
|
| 261 |
+
class Trainer:
|
| 262 |
+
def __init__(self, modelFile = None):
|
| 263 |
+
if modelFile is None:
|
| 264 |
+
singleLayer = True
|
| 265 |
+
self._edge_vector_dim = 1500
|
| 266 |
+
if singleLayer:
|
| 267 |
+
self.hidden_layer_size = 1200
|
| 268 |
+
keep_prob = 0.6
|
| 269 |
+
self.neuralnet = NN(self._edge_vector_dim, self.hidden_layer_size, outer_relu=True, keep_prob=keep_prob)
|
| 270 |
+
else:
|
| 271 |
+
# DeepR Network
|
| 272 |
+
self.hidden_layer_size = 800
|
| 273 |
+
self.hidden_layer_size2 = 800
|
| 274 |
+
self.neuralnet = NN_2(self._edge_vector_dim, self.hidden_layer_size,\
|
| 275 |
+
hidden_layer_2_size = self.hidden_layer_size2, outer_relu=True)
|
| 276 |
+
self.history = defaultdict(lambda: list())
|
| 277 |
+
else:
|
| 278 |
+
loader = pickle.load(open(modelFile, 'rb'))
|
| 279 |
+
|
| 280 |
+
self.neuralnet.n = loader['n']
|
| 281 |
+
self.neuralnet.d = loader['d']
|
| 282 |
+
|
| 283 |
+
self.neuralnet = NN(self._edge_vector_dim, self.hidden_layer_size, outer_relu=True)
|
| 284 |
+
|
| 285 |
+
self.neuralnet.U = loader['U']
|
| 286 |
+
self.neuralnet.W = loader['W']
|
| 287 |
+
self.neuralnet.B1 = loader['B1']
|
| 288 |
+
self.neuralnet.B2 = loader['B2']
|
| 289 |
+
|
| 290 |
+
self.history = defaultdict(lambda: list())
|
| 291 |
+
|
| 292 |
+
# SET LEARNING RATES
|
| 293 |
+
if self.neuralnet.version == 'h1':
|
| 294 |
+
self.neuralnet.etaW = 3e-5
|
| 295 |
+
self.neuralnet.etaB1 = 1e-5
|
| 296 |
+
|
| 297 |
+
self.neuralnet.etaU = 1e-5
|
| 298 |
+
self.neuralnet.etaB2 = 1e-5
|
| 299 |
+
elif self.neuralnet.version == 'h2':
|
| 300 |
+
self.neuralnet.etaW1 = 3e-4
|
| 301 |
+
self.neuralnet.etaB1 = 1e-4
|
| 302 |
+
|
| 303 |
+
self.neuralnet.etaW2 = 1e-4
|
| 304 |
+
self.neuralnet.etaB2 = 1e-4
|
| 305 |
+
|
| 306 |
+
self.neuralnet.etaU = 1e-4
|
| 307 |
+
self.neuralnet.etaB3 = 1e-4
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def Reset(self):
|
| 311 |
+
self.neuralnet = NN(self._edge_vector_dim, self.hidden_layer_size)
|
| 312 |
+
self.history = defaultdict(lambda: list())
|
| 313 |
+
|
| 314 |
+
def Save(self, filename):
|
| 315 |
+
print('Weights Saved: ', filename)
|
| 316 |
+
if self.neuralnet.version == 'h1':
|
| 317 |
+
pickle.dump({
|
| 318 |
+
'U': self.neuralnet.U,
|
| 319 |
+
'W': self.neuralnet.W,
|
| 320 |
+
'n': self.neuralnet.n,
|
| 321 |
+
'd': self.neuralnet.d,
|
| 322 |
+
'B1': self.neuralnet.B1,
|
| 323 |
+
'B2': self.neuralnet.B2,
|
| 324 |
+
'keep_prob': self.neuralnet.keep_prob,
|
| 325 |
+
'version': self.neuralnet.version
|
| 326 |
+
}, open(filename, 'wb'))
|
| 327 |
+
return
|
| 328 |
+
elif self.neuralnet.version == 'h2':
|
| 329 |
+
pickle.dump({
|
| 330 |
+
'U': self.neuralnet.U,
|
| 331 |
+
'B3': self.neuralnet.B3,
|
| 332 |
+
'W2': self.neuralnet.W2,
|
| 333 |
+
'B2': self.neuralnet.B2,
|
| 334 |
+
'W1': self.neuralnet.W1,
|
| 335 |
+
'B1': self.neuralnet.B1,
|
| 336 |
+
'h1': self.neuralnet.h1,
|
| 337 |
+
'h2': self.neuralnet.h2,
|
| 338 |
+
'd': self.neuralnet.d,
|
| 339 |
+
'version': self.neuralnet.version
|
| 340 |
+
}, open(filename, 'wb'))
|
| 341 |
+
return
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def Load(self, filename):
|
| 345 |
+
loader = pickle.load(open(filename, 'rb'))
|
| 346 |
+
if 'version' not in loader: # means 1 hidden layer
|
| 347 |
+
self.neuralnet = NN(self._edge_vector_dim, self.hidden_layer_size, outer_relu=True)
|
| 348 |
+
self.neuralnet.U = loader['U']
|
| 349 |
+
self.neuralnet.W = loader['W']
|
| 350 |
+
self.neuralnet.B1 = loader['B1']
|
| 351 |
+
self.neuralnet.B2 = loader['B2']
|
| 352 |
+
self.neuralnet.hidden_layer_size = loader['n']
|
| 353 |
+
self.neuralnet._edge_vector_dim = loader['d']
|
| 354 |
+
if 'keep_prob' in loader:
|
| 355 |
+
self.neuralnet.keep_prob = loader['keep_prob']
|
| 356 |
+
self.neuralnet.dropout_prob = 1 - loader['keep_prob']
|
| 357 |
+
print('Keep Prob = {}, Dropout = {}'.format(self.neuralnet.keep_prob, self.neuralnet.dropout_prob))
|
| 358 |
+
else:
|
| 359 |
+
if loader['version'] == 'h1':
|
| 360 |
+
self.neuralnet = NN(self._edge_vector_dim, self.hidden_layer_size, outer_relu=True)
|
| 361 |
+
self.neuralnet.U = loader['U']
|
| 362 |
+
self.neuralnet.W = loader['W']
|
| 363 |
+
self.neuralnet.B1 = loader['B1']
|
| 364 |
+
self.neuralnet.B2 = loader['B2']
|
| 365 |
+
self.neuralnet.hidden_layer_size = loader['n']
|
| 366 |
+
self.neuralnet._edge_vector_dim = loader['d']
|
| 367 |
+
if 'keep_prob' in loader:
|
| 368 |
+
self.neuralnet.keep_prob = loader['keep_prob']
|
| 369 |
+
self.neuralnet.dropout_prob = 1 - loader['keep_prob']
|
| 370 |
+
print('Keep Prob = {}, Dropout = {}'.format(self.neuralnet.keep_prob, self.neuralnet.dropout_prob))
|
| 371 |
+
elif loader['version'] == 'h2':
|
| 372 |
+
self.neuralnet = NN_2(self._edge_vector_dim, self.hidden_layer_size, outer_relu=True)
|
| 373 |
+
|
| 374 |
+
self.neuralnet.U = loader['U']
|
| 375 |
+
self.neuralnet.B3 = loader['B3']
|
| 376 |
+
|
| 377 |
+
self.neuralnet.W2 = loader['W2']
|
| 378 |
+
self.neuralnet.B2 = loader['B2']
|
| 379 |
+
|
| 380 |
+
self.neuralnet.W1 = loader['W1']
|
| 381 |
+
self.neuralnet.B1 = loader['B1']
|
| 382 |
+
|
| 383 |
+
self.neuralnet.h1 = loader['h1']
|
| 384 |
+
self.neuralnet.h2 = loader['h2']
|
| 385 |
+
self.neuralnet.d = loader['d']
|
| 386 |
+
|
| 387 |
+
def CalculateLoss_n_Grads(self, WScalarMat, min_st_adj_worst, max_st_adj_gold, loss_type = 0, min_marginalized_energy = None):
|
| 388 |
+
doBpp = True
|
| 389 |
+
|
| 390 |
+
# Calculate the energies
|
| 391 |
+
etg = np.sum(WScalarMat[max_st_adj_gold])
|
| 392 |
+
etq = np.sum(WScalarMat[min_st_adj_worst])
|
| 393 |
+
|
| 394 |
+
if loss_type == 0:
|
| 395 |
+
# Variable Hinge Loss - CHECKED
|
| 396 |
+
L = etg - min_marginalized_energy
|
| 397 |
+
if L > 0:
|
| 398 |
+
dLdOut = np.zeros_like(WScalarMat)
|
| 399 |
+
dLdOut[max_st_adj_gold&(~min_st_adj_worst)] = 1
|
| 400 |
+
dLdOut[(~max_st_adj_gold)&min_st_adj_worst] = -1
|
| 401 |
+
else:
|
| 402 |
+
doBpp = False
|
| 403 |
+
return (L, None, doBpp)
|
| 404 |
+
elif loss_type == 1:
|
| 405 |
+
# Log loss
|
| 406 |
+
a = etg - etq
|
| 407 |
+
b = np.exp(a)
|
| 408 |
+
L = np.log(1 + b)
|
| 409 |
+
|
| 410 |
+
dLdOut = np.zeros_like(WScalarMat)
|
| 411 |
+
dLdOut[max_st_adj_gold&(~min_st_adj_worst)] = 1
|
| 412 |
+
dLdOut[(~max_st_adj_gold)&min_st_adj_worst] = -1
|
| 413 |
+
|
| 414 |
+
dLdOut *= (b/(1 + b))
|
| 415 |
+
elif loss_type == 2:
|
| 416 |
+
# Square exponential loss
|
| 417 |
+
gamma = 1
|
| 418 |
+
b = np.exp(-etq)
|
| 419 |
+
|
| 420 |
+
L = etg**2 + gamma*b
|
| 421 |
+
|
| 422 |
+
dLdOut = np.zeros_like(WScalarMat)
|
| 423 |
+
dLdOut[max_st_adj_gold&(~min_st_adj_worst)] = 2*etg
|
| 424 |
+
dLdOut[(~max_st_adj_gold)&min_st_adj_worst] = -gamma*b
|
| 425 |
+
pass
|
| 426 |
+
return (L, dLdOut, doBpp)
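The sign convention of the hinge-style gradient (`loss_type == 0`) can be sketched on a toy pair of adjacency masks; the matrices and trees below are made up for illustration:

```python
import numpy as np

# Toy sketch: edges unique to the gold adjacency get gradient +1, edges
# unique to the competing (minimum-energy) adjacency get -1, and shared
# or absent edges get 0.
W = np.zeros((3, 3))
gold = np.array([[False, True, False],
                 [False, False, True],
                 [False, False, False]])   # hypothetical max_st_adj_gold
worst = np.array([[False, False, True],
                  [False, False, True],
                  [False, False, False]])  # hypothetical min_st_adj_worst

dLdOut = np.zeros_like(W)
dLdOut[gold & (~worst)] = 1    # present only in the gold adjacency
dLdOut[(~gold) & worst] = -1   # present only in the rival adjacency
print(dLdOut[0, 1], dLdOut[0, 2], dLdOut[1, 2])  # 1.0 -1.0 0.0 (shared edge)
```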
|
| 427 |
+
def Test(self, sentenceObj, dcsObj, dsbz2_name, _dump = False, _outFile = None):
|
| 428 |
+
if _dump:
|
| 429 |
+
if _outFile is None:
|
| 430 |
+
raise Exception('_outFile must be provided when _dump is True')
|
| 431 |
+
if self.neuralnet.version == 'h1':
|
| 432 |
+
self.neuralnet.ForTesting()
|
| 433 |
+
|
| 434 |
+
# with open('gt_cngs.csv','a') as fh:
|
| 435 |
+
# for i in dcsObj.cng:
|
| 436 |
+
# for j in i:
|
| 437 |
+
# print(str(sentenceObj.sent_id)+":"+str(j))
|
| 438 |
+
# wr = csv.writer(fh)
|
| 439 |
+
# wr.writerow([sentenceObj.sent_id,j])
|
| 440 |
+
|
| 441 |
+
# return
|
| 442 |
+
neuralnet = self.neuralnet
|
| 443 |
+
minScore = np.inf
|
| 444 |
+
minMst = None
|
| 445 |
+
|
| 446 |
+
# dsbz2_name = sentenceObj.sent_id + '.ds.bz2'
|
| 447 |
+
(nodelist_correct, conflicts_Dict_correct, featVMat_correct, nodelist_to_correct_mapping,\
|
| 448 |
+
nodelist, conflicts_Dict, featVMat) = open_dsbz2(dsbz2_name)
|
| 449 |
+
|
| 450 |
+
# if len(nodelist) > 50:
|
| 451 |
+
# return None
|
| 452 |
+
|
| 453 |
+
if not self.neuralnet.outer_relu:
|
| 454 |
+
(WScalarMat, SigmoidGateOutput) = Get_W_Scalar_Matrix_from_FeatVect_Matrix(featVMat, nodelist, conflicts_Dict, neuralnet)
|
| 455 |
+
else:
|
| 456 |
+
WScalarMat = Get_W_Scalar_Matrix_from_FeatVect_Matrix(featVMat, nodelist, conflicts_Dict, neuralnet)
|
| 457 |
+
|
| 458 |
+
# print('NeuralNet Time: ', time.time() - startT)
|
| 459 |
+
# startT = time.time()
|
| 460 |
+
|
| 461 |
+
# Get all MST
|
| 462 |
+
# print('before getting all cliques')
|
| 463 |
+
for source in range(len(nodelist)):
|
| 464 |
+
(mst_nodes, mst_adj_graph, _) = clique(nodelist, WScalarMat, conflicts_Dict, source)
|
| 465 |
+
# print('.', end = '')
|
| 466 |
+
score = GetMSTWeight(mst_adj_graph, WScalarMat)
|
| 467 |
+
if(score < minScore):
|
| 468 |
+
minScore = score
|
| 469 |
+
minMst = mst_nodes
|
| 470 |
+
# print('after getting all cliques')
|
| 471 |
+
dcsLemmas = [[rom_slp(l) for l in arr]for arr in dcsObj.lemmas]
|
| 472 |
+
word_match = 0
|
| 473 |
+
lemma_match = 0
|
| 474 |
+
n_output_nodes = 0
|
| 475 |
+
|
| 476 |
+
if _dump:
|
| 477 |
+
predicted_lemmas = [sentenceObj.sent_id]
|
| 478 |
+
predicted_cngs = [sentenceObj.sent_id]
|
| 479 |
+
predicted_chunk_id = [sentenceObj.sent_id]
|
| 480 |
+
predicted_pos = [sentenceObj.sent_id]
|
| 481 |
+
predicted_id = [sentenceObj.sent_id]
|
| 482 |
+
|
| 483 |
+
for chunk_id, wdSplit in minMst.items():
|
| 484 |
+
for wd in wdSplit:
|
| 485 |
+
if _dump:
|
| 486 |
+
predicted_lemmas.append(wd.lemma)
|
| 487 |
+
predicted_cngs.append(wd.cng)
|
| 488 |
+
predicted_chunk_id.append(wd.chunk_id)
|
| 489 |
+
predicted_pos.append(wd.pos)
|
| 490 |
+
predicted_id.append(wd.id)
|
| 491 |
+
|
| 492 |
+
n_output_nodes += 1
|
| 493 |
+
# Match lemma
|
| 494 |
+
search_result = [i for i, j in enumerate(dcsLemmas[chunk_id]) if j == wd.lemma]
|
| 495 |
+
if len(search_result) > 0:
|
| 496 |
+
lemma_match += 1
|
| 497 |
+
# Match CNG
|
| 498 |
+
for i in search_result:
|
| 499 |
+
if(dcsObj.cng[chunk_id][i] == str(wd.cng)):
|
| 500 |
+
word_match += 1
|
| 501 |
+
# print(wd.lemma, wd.cng)
|
| 502 |
+
break
|
| 503 |
+
dcsLemmas = [l for arr in dcsObj.lemmas for l in arr]
|
| 504 |
+
|
| 505 |
+
if _dump:
|
| 506 |
+
with open(_outFile, 'a') as fh:
|
| 507 |
+
dcsv = csv.writer(fh)
|
| 508 |
+
dcsv.writerow(predicted_lemmas)
|
| 509 |
+
dcsv.writerow(predicted_cngs)
|
| 510 |
+
dcsv.writerow(predicted_chunk_id)
|
| 511 |
+
dcsv.writerow(predicted_pos)
|
| 512 |
+
dcsv.writerow(predicted_id)
|
| 513 |
+
dcsv.writerow([sentenceObj.sent_id, word_match, lemma_match, len(dcsLemmas), n_output_nodes])
|
| 514 |
+
|
| 515 |
+
# print('All MST Time: ', time.time() - startT)
|
| 516 |
+
# print('Node Count: ', len(nodelist))
|
| 517 |
+
# print('\nFull Match: {}, Partial Match: {}, OutOf {}, NodeCount: {}, '.\
|
| 518 |
+
# format(word_match, lemma_match, len(dcsLemmas), len(nodelist)))
|
| 519 |
+
return (word_match, lemma_match, len(dcsLemmas), n_output_nodes)
|
| 520 |
+
|
| 521 |
+
    def Train(self, sentenceObj, dcsObj, bz2_input_folder, _debug=True):
        self.neuralnet.ForTraining()
        self.neuralnet.new_dropout()  # renew the dropout mask
        # Hyperparameter for hinge loss: m
        m_hinge_param = 14

        dsbz2_name = sentenceObj.sent_id + '.ds.bz2'
        (nodelist_correct, conflicts_Dict_correct, featVMat_correct, nodelist_to_correct_mapping,
            nodelist, conflicts_Dict, featVMat) = open_dsbz2(bz2_input_folder + dsbz2_name)
        sub = 0
        for s in conflicts_Dict.keys():
            sub = sub + len(conflicts_Dict[s])
        with open('cliq.csv', 'a') as fh:
            rd = csv.writer(fh)
            rd.writerow([str(dsbz2_name), str(len(nodelist)), str(sub)])

        print(dsbz2_name)
        print('NodeLength: ' + str(len(nodelist)))

        if len(nodelist) > 100:
            print('Nodelength: ' + str(len(nodelist)))
            # Large graphs could be trained separately

        """ FORM THE MAXIMUM-ENERGY SPANNING TREE OF THE GOLD GRAPH : WORST GOLD STRUCTURE """
        WScalarMat_correct = Get_W_Scalar_Matrix_from_FeatVect_Matrix(featVMat_correct, nodelist_correct,
            conflicts_Dict_correct, self.neuralnet)
        source = 0
        (min_st_gold_ndict, min_st_adj_gold_small, _) = \
            clique(nodelist_correct, WScalarMat_correct, conflicts_Dict_correct, source)
        energy_gold_max_ST = np.sum(WScalarMat_correct[min_st_adj_gold_small])

        """ Expand the gold spanning-tree adjacency matrix to the full graph's dimensions """
        nodelen = len(nodelist)
        # T_STAR; np.zeros replaces np.ndarray(...) * False, which left memory uninitialized
        min_st_adj_gold = np.zeros((nodelen, nodelen), dtype=bool)
        for i in range(min_st_adj_gold_small.shape[0]):
            for j in range(min_st_adj_gold_small.shape[1]):
                min_st_adj_gold[nodelist_to_correct_mapping[i], nodelist_to_correct_mapping[j]] = \
                    min_st_adj_gold_small[i, j]

        """ Delta (margin) function : mask of which nodes in nodelist belong to the DCS gold """
        gold_nodes_mask = np.array([False]*len(nodelist))
        gold_nodes_mask[list(nodelist_to_correct_mapping.values())] = True
        margin_f = lambda nodes_mask: np.sum(nodes_mask & (~gold_nodes_mask))**2

        """ FOR ALL POSSIBLE MSTs FROM THE COMPLETE GRAPH """
        WScalarMat = Get_W_Scalar_Matrix_from_FeatVect_Matrix(featVMat, nodelist, conflicts_Dict, self.neuralnet)
        """ For each node, find the MST with that node as source """
        min_STx = None  # Minimum-energy spanning tree with the worst margin w.r.t. the gold tree
        min_marginalized_energy = np.inf

        # Randomly select a subset of source nodes (currently unused by the loop below)
        n_nodes = len(nodelist)
        selection_prob = 0.4
        select_flag = np.random.rand(n_nodes) < selection_prob
        # Ensure at least one source is selected
        if np.sum(select_flag) == 0:
            select_flag[np.random.randint(n_nodes)] = 1

        best_node_diff = np.inf
        best_energy = np.inf

        cliqset = set()
        for source in range(len(nodelist)):
            (mst_nodes, mst_adj_graph, mst_nodes_bool) = clique(nodelist, WScalarMat, conflicts_Dict, source)  # T_X
            # Record this clique's node set as a bit-string
            cst = ''.join('1' if i else '0' for i in mst_nodes_bool)
            cliqset.add(cst)
            # Energy of this spanning tree
            en_st = np.sum(WScalarMat[mst_adj_graph])
            # Margin of this tree against the gold node set
            delta_st = margin_f(mst_nodes_bool)

            if _debug:
                # Track the node difference of the lowest-energy tree
                if best_energy > en_st:
                    best_node_diff = delta_st
                    best_energy = en_st

            # Margin-rescaled energy
            marginalized_en = en_st - delta_st
            if marginalized_en < min_marginalized_energy:
                min_marginalized_energy = marginalized_en
                min_STx = mst_adj_graph
            # The energy differences should all be negative
            if _debug:
                print('Source: [{}], Node_Diff: {}, Max_Gold_En: {:.3f}, Energy: {:.3f}'.format(
                    source, np.sum((~gold_nodes_mask) & mst_nodes_bool), energy_gold_max_ST, np.sum(WScalarMat[mst_adj_graph])))
        with open('cliq.csv', 'a') as fh:
            rd = csv.writer(fh)
            rd.writerow([str(dsbz2_name), len(cliqset)])
        if _debug:
            print('Best Node diff: {} with EN: {}'.format(np.sqrt(best_node_diff), best_energy))
        """ Gradient Descent """
        # LOSS TYPES -> hinge(0), log-loss(1), square-exponential(2)
        Total_Loss, dLdOut, doBpp = self.CalculateLoss_n_Grads(WScalarMat, min_STx, min_st_adj_gold,
            loss_type=0, min_marginalized_energy=min_marginalized_energy)
        if doBpp:
            if _debug:
                print('{}. '.format(sentenceObj.sent_id), end='')
            self.neuralnet.Back_Prop(dLdOut, len(nodelist), featVMat, _debug)
        else:
            trainingStatus[sentenceObj.sent_id] = True
        if _debug:
            print('\nFileKey: %s, Loss: %6.3f' % (sentenceObj.sent_id, Total_Loss))

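The squared node-set margin used by the training loop above can be exercised in isolation. This is a minimal sketch: the boolean masks below are invented for illustration, not taken from real segmentation data.

```python
import numpy as np

# gold_nodes_mask marks which candidate nodes belong to the gold analysis
gold_nodes_mask = np.array([True, True, False, False])

# margin = (number of predicted nodes outside the gold set) squared,
# mirroring margin_f in Train above
margin_f = lambda nodes_mask: np.sum(nodes_mask & (~gold_nodes_mask))**2

pred = np.array([True, False, True, True])  # two non-gold nodes predicted
print(margin_f(pred))  # 4
```

A prediction identical to the gold mask yields a margin of 0, so the margin-rescaled energy `en_st - delta_st` penalizes trees exactly in proportion to how many spurious nodes they include.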
TrainFiles = None
trainer = None
p_name = ''
odir = ''

def InitModule():
    global trainer
    trainer = Trainer()

def register_nnet(nnet, bz2_input_folder):
    if not os.path.isdir(odir):
        os.mkdir(odir)
    if not os.path.isfile('outputs/nnet_LOGS.csv'):
        with open('outputs/nnet_LOGS.csv', 'a') as fh:
            csv_r = csv.writer(fh)
            csv_r.writerow(['odir', 'p_name', 'hidden_layer_size', '_edge_vector_dim'])
    with open('outputs/nnet_LOGS.csv', 'a') as fh:
        csv_r = csv.writer(fh)
        if nnet.version == 'h1':
            csv_r.writerow([odir, p_name, nnet.n, nnet.d, bz2_input_folder])
        elif nnet.version == 'h2':
            csv_r.writerow([odir, p_name, nnet.h1, nnet.h2, nnet.d, bz2_input_folder])

"""
|
| 687 |
+
################################################################################################
|
| 688 |
+
################################################################################################
|
| 689 |
+
################################################################################################
|
| 690 |
+
"""
|
| 691 |
+
def main():
|
| 692 |
+
global TrainFiles, p_name, odir
|
| 693 |
+
|
| 694 |
+
"""
|
| 695 |
+
################################################################################################
|
| 696 |
+
############################## GET A FILENAME TO SAVE WEIGHTS ################################
|
| 697 |
+
################################################################################################
|
| 698 |
+
"""
|
| 699 |
+
st = str(int((time.time() * 1e6) % 1e13))
|
| 700 |
+
log_name = 'logs/train_nnet_t{}.out'.format(st)
|
| 701 |
+
odir = 'outputs/train_t{}'.format(st)
|
| 702 |
+
p_name = 'outputs/train_t{}/nnet.p'.format(st)
|
| 703 |
+
print('nEURAL nET wILL bE sAVED hERE: ', p_name)
|
| 704 |
+
|
| 705 |
+
|
| 706 |
+
if not os.path.isdir('inputs'):
|
| 707 |
+
os.mkdir('inputs')
|
| 708 |
+
|
| 709 |
+
if not os.path.isdir('outputs'):
|
| 710 |
+
os.mkdir('outputs')
|
| 711 |
+
|
| 712 |
+
if not os.path.isdir('NewData'):
|
| 713 |
+
os.mkdir('NewData')
|
| 714 |
+
# Create Training File List
|
| 715 |
+
excluded_files = []
|
| 716 |
+
with open('Baseline4_advSample.csv', 'r') as f_handle:
|
| 717 |
+
opener = csv.reader(f_handle)
|
| 718 |
+
for line in opener:
|
| 719 |
+
excluded_files.append(line[1].replace('.p', '.ds.bz2'))
|
| 720 |
+
|
| 721 |
+
# Load Simultaneous files
|
| 722 |
+
print('Loading Large Files')
|
| 723 |
+
loaded_SKT = pickle.load(open('Simultaneous_CompatSKT_10K.p', 'rb'), encoding=u'utf-8')
|
| 724 |
+
loaded_DCS = pickle.load(open('Simultaneous_DCS_10K.p', 'rb'), encoding=u'utf-8')
|
| 725 |
+
|
| 726 |
+
# loaded_SKT = pickle.load(open('../Simultaneous_CompatSKT.p', 'rb'), encoding=u'utf-8')
|
| 727 |
+
# loaded_DCS = pickle.load(open('../Simultaneous_DCS.p', 'rb'), encoding=u'utf-8')
|
| 728 |
+
|
| 729 |
+
bz2_input_folder = '../wordsegmentation/skt_dcs_DS.bz2_4K_bigram_mir_10K/' #bm2
|
| 730 |
+
# bz2_input_folder = '../wordsegmentation/skt_dcs_DS.bz2_1L_bigram_mir_10K/' #bm3
|
| 731 |
+
# bz2_input_folder = '../wordsegmentation/skt_dcs_DS.bz2_4K_bigram_rfe_10K/' #br2
|
| 732 |
+
# bz2_input_folder = '../wordsegmentation/skt_dcs_DS.bz2_1L_bigram_rfe_10K/' #br3
|
| 733 |
+
# bz2_input_folder = '../wordsegmentation/skt_dcs_DS.bz2_4K_pmi_mir_10K/' #pm2
|
| 734 |
+
# bz2_input_folder = '../wordsegmentation/skt_dcs_DS.bz2_1L_pmi_mir_10K2/' #pm3
|
| 735 |
+
# bz2_input_folder = '../wordsegmentation/skt_dcs_DS.bz2_4K_pmi_rfe_10K/' #pr2
|
| 736 |
+
# bz2_input_folder = '../wordsegmentation/skt_dcs_DS.bz2_1L_pmi_rfe_10K/' #pr3 # bz2_input_folder = '/home/rs/15CS91R05/vishnu/Data/skt_dcs_DS.bz2_compat_10k_check_again/'
|
| 737 |
+
all_files = []
|
| 738 |
+
skipped = 0
|
| 739 |
+
for f in os.listdir(bz2_input_folder):
|
| 740 |
+
if '.ds.bz2' in f:
|
| 741 |
+
if f in excluded_files:
|
| 742 |
+
skipped += 1
|
| 743 |
+
continue
|
| 744 |
+
if f.replace('.ds.bz2', '.p2') not in loaded_DCS:
|
| 745 |
+
print('Couldnt find ', f)
|
| 746 |
+
continue
|
| 747 |
+
all_files.append(f)
|
| 748 |
+
|
| 749 |
+
print(skipped, 'files will not be used for training')
|
| 750 |
+
print('Size of training set:', len(all_files))
|
| 751 |
+
|
| 752 |
+
|
| 753 |
+
# all_files=['32517.p2']
|
| 754 |
+
TrainFiles = all_files
|
| 755 |
+
with open('cliq.csv','w') as fh:
|
| 756 |
+
rd = csv.writer(fh)
|
| 757 |
+
rd.writerow(['FileName','Nodelength & NCliques'])
|
| 758 |
+
InitModule()
|
| 759 |
+
trainingStatus = defaultdict(lambda: bool(False))
|
| 760 |
+
# train = train_generator(loaded_SKT, loaded_DCS, bz2_input_folder, n_trainset = -1, filePerBatch = 10, iterationPerBatch = 5, _debug=False, superEpochs = 5)
|
| 761 |
+
train = train_generator(loaded_SKT, loaded_DCS, bz2_input_folder, n_trainset = -1, filePerBatch = 20, iterationPerBatch = 3, _debug=False, superEpochs = 2)
|
| 762 |
+
|
| 763 |
+
# Complete Training
|
| 764 |
+
train.__next__()
|
| 765 |
+
|
| 766 |
+
print('Training Complete')
|
| 767 |
+
|
| 768 |
+
if __name__ == '__main__':
|
| 769 |
+
main()
dir/Train_n_Save_NNet.py
ADDED
@@ -0,0 +1,680 @@
"""
IMPORTS
"""

## Built-in packages
import sys, os, time, bz2, zlib, pickle, math, json, csv
from collections import defaultdict
import numpy as np
np.set_printoptions(suppress=True)
from IPython.display import display

## Project modules (from last summer)
from romtoslp import *
from sentences import *
from DCS import *
import MatDB
from heap_n_PrimMST import *
from ECL_MST import *

import word_definite as WD
from nnet import *

"""
################################################################################################
########################### LOAD SENTENCE AND DCS OBJECT FILES ###############################
################################################################################################
"""
# loaded_SKT = pickle.load(open('../Simultaneous_CompatSKT_10K.p', 'rb'))
# loaded_DCS = pickle.load(open('../Simultaneous_DCS_10K.p', 'rb'))

"""
################################################################################################
########################### OPENS AND EXTRACTS DCS AND SKT DATA STRUCTURES ###################
###################################### FROM BZ2 FILES ########################################
"""
def open_dsbz2(filename):
    with bz2.BZ2File(filename, 'r') as f:
        loader = pickle.load(f)

    conflicts_Dict_correct = loader['conflicts_Dict_correct']
    nodelist_to_correct_mapping = loader['nodelist_to_correct_mapping']
    nodelist_correct = loader['nodelist_correct']
    featVMat_correct = loader['featVMat_correct']
    featVMat = loader['featVMat']
    conflicts_Dict = loader['conflicts_Dict']
    nodelist = loader['nodelist']

    return (nodelist_correct, conflicts_Dict_correct, featVMat_correct, nodelist_to_correct_mapping,
        nodelist, conflicts_Dict, featVMat)

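The `.ds.bz2` files read by `open_dsbz2` are simply bz2-compressed pickles of a dict with the seven keys above. The round trip can be shown with a dummy payload (a sketch: the values here are empty placeholders, whereas the real files hold node lists and feature matrices):

```python
import bz2, pickle, os, tempfile

# Minimal payload in the same layout open_dsbz2 expects
payload = {
    'conflicts_Dict_correct': {}, 'nodelist_to_correct_mapping': {},
    'nodelist_correct': [], 'featVMat_correct': None,
    'featVMat': None, 'conflicts_Dict': {}, 'nodelist': [],
}
path = os.path.join(tempfile.mkdtemp(), 'demo.ds.bz2')
with bz2.BZ2File(path, 'w') as f:
    pickle.dump(payload, f)

# Read it back the same way open_dsbz2 does
with bz2.BZ2File(path, 'r') as f:
    loader = pickle.load(f)
print(sorted(loader.keys()))
```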
"""
|
| 53 |
+
################################################################################################
|
| 54 |
+
###################### CREATE SEVERAL DATA STRUCTURES FROM SENTENCE/DCS ######################
|
| 55 |
+
########################### NODELIST, ADJACENCY LIST, GRAPH, HEAP #############################
|
| 56 |
+
################################################################################################
|
| 57 |
+
"""
|
| 58 |
+
def GetTrainingKit(sentenceObj, dcsObj):
    nodelist = GetNodes(sentenceObj)

    # Nodelist restricted to the correct (gold) nodes
    nodelist2 = GetNodes(sentenceObj)
    nodelist2_to_correct_mapping = {}
    nodelist_correct = []
    search_key = 0
    first_key = 0
    for chunk_id in range(len(dcsObj.lemmas)):
        while nodelist2[first_key].chunk_id != chunk_id:
            first_key += 1
        for j in range(len(dcsObj.lemmas[chunk_id])):
            search_key = first_key
            while (nodelist2[search_key].lemma != rom_slp(dcsObj.lemmas[chunk_id][j])) or (nodelist2[search_key].cng != dcsObj.cng[chunk_id][j]):
                search_key += 1
                if search_key >= len(nodelist2) or nodelist2[search_key].chunk_id > chunk_id:
                    break
            nodelist2_to_correct_mapping[len(nodelist_correct)] = search_key
            nodelist_correct.append(nodelist2[search_key])
    return (nodelist, nodelist_correct, nodelist2_to_correct_mapping)

def GetGraph(nodelist, neuralnet):
    conflicts_Dict = Get_Conflicts(nodelist)

    featVMat = Get_Feat_Vec_Matrix(nodelist, conflicts_Dict)

    if not neuralnet.outer_relu:
        (WScalarMat, SigmoidGateOutput) = Get_W_Scalar_Matrix_from_FeatVect_Matrix(featVMat, nodelist, conflicts_Dict, neuralnet)
        return (conflicts_Dict, featVMat, WScalarMat, SigmoidGateOutput)
    else:
        WScalarMat = Get_W_Scalar_Matrix_from_FeatVect_Matrix(featVMat, nodelist, conflicts_Dict, neuralnet)
        return (conflicts_Dict, featVMat, WScalarMat)

# NEW LOSS FUNCTION
def GetLoss(_mst_adj_graph, _mask_de_correct_edges, _negLogLikelies):
    _negLogLikelies = _negLogLikelies.copy()
    _negLogLikelies[~_mst_adj_graph] = 0
    _negLogLikelies[~_mask_de_correct_edges] *= -1  # Sign-flip for incorrect edges; check carefully before changing
    return np.sum(_negLogLikelies)

"""
################################################################################################
############################## MAIN FUNCTION #################################################
################################################################################################
"""
trainingStatus = defaultdict(lambda: bool(False))

"""
################################################################################################
############################## TRAIN FUNCTION ################################################
################################################################################################
"""

def train_generator(loaded_SKT, loaded_DCS, bz2_input_folder, n_trainset=-1, iterationPerBatch=10, filePerBatch=20, _debug=True, superEpochs=1):
    # Train
    if n_trainset == -1:
        n_trainset = len(TrainFiles)
    totalBatchToTrain = math.ceil(n_trainset/filePerBatch)

    register_nnet(trainer.neuralnet, bz2_input_folder)
    print('Epochs: ' + str(superEpochs))
    print('Batches per epoch: ' + str(totalBatchToTrain))
    for _epoch in range(superEpochs):
        for iterout in range(totalBatchToTrain):
            # Time each batch
            startT = time.time()

            # Checkpoint the weights periodically; otherwise overwrite the running save
            if iterout % 50 == 0:
                trainer.Save(p_name.replace('.p', '_e{}_i{}.p'.format(_epoch, iterout)))
            else:
                trainer.Save(p_name)
            print('Epoch: {}, Batch: {}'.format(_epoch, iterout))
            files_for_batch = TrainFiles[iterout*filePerBatch:(iterout + 1)*filePerBatch]
            try:
                # Run a few iterations on the same set of files
                for iterin in range(iterationPerBatch):
                    print('ITERATION IN', iterin)
                    for fn in files_for_batch:
                        trainFileName = fn.replace('.ds.bz2', '.p2')
                        sentenceObj = loaded_SKT[trainFileName]
                        dcsObj = loaded_DCS[trainFileName]
                        if trainingStatus[sentenceObj.sent_id]:
                            continue
                        try:
                            trainer.Train(sentenceObj, dcsObj, bz2_input_folder, _debug)
                        except (IndexError, KeyError) as e:
                            print('\x1b[31mFailed: {} \x1b[0m'.format(sentenceObj.sent_id))
                        except EOFError as e:
                            print('\x1b[31mBADFILE: {} \x1b[0m'.format(sentenceObj.sent_id))
                    sys.stdout.flush()  # Flush IO buffer
                finishT = time.time()
                print('Avg. time taken by 1 file (1 iteration): {:.3f}'.format((finishT - startT)/(iterationPerBatch*filePerBatch)))
            except KeyboardInterrupt:
                print('Training paused')
                trainer.Save(p_name)
                yield None
    trainer.Save(p_name)

def test(loaded_SKT, loaded_DCS, n_testSet=-1, _testFiles=None, n_checkpt=100):
    total_lemma = 0
    correct_lemma = 0

    total_word = 0
    total_output_nodes = 0
    correct_word = 0
    file_counter = 0
    if _testFiles is None:
        _testFiles = TestFiles
    if n_testSet != -1:
        _testFiles = _testFiles[0:n_testSet]

    recalls = []
    recalls_of_word = []
    precisions = []
    precisions_of_words = []
    for fn in _testFiles:
        if file_counter % n_checkpt == 0:
            print(file_counter, ' Checkpoint... ')
            if file_counter > 0:
                print('Avg. Micro Recall of Lemmas: {}'.format(np.mean(np.array(recalls))))
                print('Avg. Micro Recall of Words: {}'.format(np.mean(np.array(recalls_of_word))))
                print('Avg. Micro Precision of Lemmas: {}'.format(np.mean(np.array(precisions))))
                print('Avg. Micro Precision of Words: {}'.format(np.mean(np.array(precisions_of_words))))
                sys.stdout.flush()  # Flush IO buffer

        file_counter += 1

        testFileName = fn.replace('.ds.bz2', '.p2')
        sentenceObj = loaded_SKT[testFileName]
        dcsObj = loaded_DCS[testFileName]

        try:
            (word_match, lemma_match, n_dcsWords, n_output_nodes) = trainer.Test(sentenceObj, dcsObj)

            recalls.append(lemma_match/n_dcsWords)
            recalls_of_word.append(word_match/n_dcsWords)

            precisions.append(lemma_match/n_output_nodes)
            precisions_of_words.append(word_match/n_output_nodes)

            total_lemma += n_dcsWords
            total_word += n_dcsWords

            total_output_nodes += n_output_nodes

            correct_lemma += lemma_match
            correct_word += word_match
        except (IndexError, KeyError) as e:
            print('Failed!')

    print('Avg. Micro Recall of Lemmas: {}'.format(np.mean(np.array(recalls))))
    print('Avg. Micro Recall of Words: {}'.format(np.mean(np.array(recalls_of_word))))
    print('Avg. Micro Precision of Lemmas: {}'.format(np.mean(np.array(precisions))))
    print('Avg. Micro Precision of Words: {}'.format(np.mean(np.array(precisions_of_words))))

    return (recalls, recalls_of_word, precisions, precisions_of_words)

# NEW FUNCTION
def GetLoss(_mst_adj_graph, _mask_de_correct_edges, _WScalarMat):
    _WScalarMat = _WScalarMat.copy()
    _WScalarMat[_mst_adj_graph & (~_mask_de_correct_edges)] *= -1  # Sign-flip incorrect edges in the MST; check carefully before changing
    _WScalarMat[~_mst_adj_graph] = 0
    return np.sum(_WScalarMat)

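The signed-loss trick in `GetLoss` can be checked on a tiny hand-built example. This is a sketch: the 2x2 weight matrix and masks are made up for illustration.

```python
import numpy as np

def get_loss(mst_adj_graph, mask_correct_edges, w_scalar_mat):
    # Mirror of GetLoss above: keep only edges in the predicted tree,
    # flipping the sign of weights on edges that are not in the gold tree
    w = w_scalar_mat.copy()
    w[mst_adj_graph & (~mask_correct_edges)] *= -1
    w[~mst_adj_graph] = 0
    return np.sum(w)

mst = np.array([[False, True], [True, False]])       # predicted edges
correct = np.array([[False, True], [False, False]])  # gold edges
w = np.array([[0.0, 2.0], [3.0, 0.0]])               # edge weights
# Edge (0,1) is correct (+2.0); edge (1,0) is wrong (-3.0)
loss = get_loss(mst, correct, w)
print(loss)  # -1.0
```

Correct predicted edges contribute their weight positively and incorrect ones negatively, so a tree made entirely of gold edges attains the maximal sum.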
"""
|
| 246 |
+
################################################################################################
|
| 247 |
+
############################# TRAINER CLASS DEFINITION ######################################
|
| 248 |
+
################################################################################################
|
| 249 |
+
"""
|
| 250 |
+
class Trainer:
    def __init__(self, modelFile=None):
        if modelFile is None:
            singleLayer = True
            self._edge_vector_dim = 1500
            if singleLayer:
                self.hidden_layer_size = 1200
                keep_prob = 0.6
                self.neuralnet = NN(self._edge_vector_dim, self.hidden_layer_size, outer_relu=True, keep_prob=keep_prob)
            else:
                # Deeper network with two hidden layers
                self.hidden_layer_size = 800
                self.hidden_layer_size2 = 800
                self.neuralnet = NN_2(self._edge_vector_dim, self.hidden_layer_size,
                    hidden_layer_2_size=self.hidden_layer_size2, outer_relu=True)
            self.history = defaultdict(lambda: list())
        else:
            loader = pickle.load(open(modelFile, 'rb'))  # was open(filename, ...): an undefined name

            # Dimensions must be known before the network is constructed
            self.hidden_layer_size = loader['n']
            self._edge_vector_dim = loader['d']

            self.neuralnet = NN(self._edge_vector_dim, self.hidden_layer_size, outer_relu=True)

            self.neuralnet.U = loader['U']
            self.neuralnet.W = loader['W']
            self.neuralnet.B1 = loader['B1']
            self.neuralnet.B2 = loader['B2']

            self.history = defaultdict(lambda: list())

        # SET LEARNING RATES
        if self.neuralnet.version == 'h1':
            self.neuralnet.etaW = 3e-5
            self.neuralnet.etaB1 = 1e-5

            self.neuralnet.etaU = 1e-5
            self.neuralnet.etaB2 = 1e-5
        elif self.neuralnet.version == 'h2':
            self.neuralnet.etaW1 = 3e-4
            self.neuralnet.etaB1 = 1e-4

            self.neuralnet.etaW2 = 1e-4
            self.neuralnet.etaB2 = 1e-4

            self.neuralnet.etaU = 1e-4
            self.neuralnet.etaB3 = 1e-4

    def Reset(self):
        self.neuralnet = NN(self._edge_vector_dim, self.hidden_layer_size)
        self.history = defaultdict(lambda: list())

    def Save(self, filename):
        print('Weights Saved: ', filename)
        if self.neuralnet.version == 'h1':
            pickle.dump({
                'U': self.neuralnet.U,
                'W': self.neuralnet.W,
                'n': self.neuralnet.n,
                'd': self.neuralnet.d,
                'B1': self.neuralnet.B1,
                'B2': self.neuralnet.B2,
                'keep_prob': self.neuralnet.keep_prob,
                'version': self.neuralnet.version
            }, open(filename, 'wb'))
            return
        elif self.neuralnet.version == 'h2':
            pickle.dump({
                'U': self.neuralnet.U,
                'B3': self.neuralnet.B3,
                'W2': self.neuralnet.W2,
                'B2': self.neuralnet.B2,
                'W1': self.neuralnet.W1,
                'B1': self.neuralnet.B1,
                'h1': self.neuralnet.h1,
                'h2': self.neuralnet.h2,
                'd': self.neuralnet.d,
                'version': self.neuralnet.version
            }, open(filename, 'wb'))
            return

    def Load(self, filename):
        loader = pickle.load(open(filename, 'rb'))
        if 'version' not in loader:  # legacy format: one hidden layer
            self.neuralnet = NN(self._edge_vector_dim, self.hidden_layer_size, outer_relu=True)
            self.neuralnet.U = loader['U']
            self.neuralnet.W = loader['W']
            self.neuralnet.B1 = loader['B1']
            self.neuralnet.B2 = loader['B2']
            self.neuralnet.hidden_layer_size = loader['n']
            self.neuralnet._edge_vector_dim = loader['d']
            if 'keep_prob' in loader:
                self.neuralnet.keep_prob = loader['keep_prob']
                self.neuralnet.dropout_prob = 1 - loader['keep_prob']
|
| 346 |
+
print('Keep Prob = {}, Dropout = {}'.format(self.neuralnet.keep_prob, self.neuralnet.dropout_prob))
|
| 347 |
+
else:
|
| 348 |
+
if loader['version'] == 'h1':
|
| 349 |
+
self.neuralnet = NN(self._edge_vector_dim, self.hidden_layer_size, outer_relu=True)
|
| 350 |
+
self.neuralnet.U = loader['U']
|
| 351 |
+
self.neuralnet.W = loader['W']
|
| 352 |
+
self.neuralnet.B1 = loader['B1']
|
| 353 |
+
self.neuralnet.B2 = loader['B2']
|
| 354 |
+
self.neuralnet.hidden_layer_size = loader['n']
|
| 355 |
+
self.neuralnet._edge_vector_dim = loader['d']
|
| 356 |
+
if 'keep_prob' in loader:
|
| 357 |
+
self.neuralnet.keep_prob = loader['keep_prob']
|
| 358 |
+
self.neuralnet.dropout_prob = 1 - loader['keep_prob']
|
| 359 |
+
print('Keep Prob = {}, Dropout = {}'.format(self.neuralnet.keep_prob, self.neuralnet.dropout_prob))
|
| 360 |
+
elif loader['version'] == 'h2':
|
| 361 |
+
self.neuralnet = NN_2(self._edge_vector_dim, self.hidden_layer_size, outer_relu=True)
|
| 362 |
+
|
| 363 |
+
self.neuralnet.U = loader['U']
|
| 364 |
+
self.neuralnet.B3 = loader['B3']
|
| 365 |
+
|
| 366 |
+
self.neuralnet.W2 = loader['W2']
|
| 367 |
+
self.neuralnet.B2 = loader['B2']
|
| 368 |
+
|
| 369 |
+
self.neuralnet.W1 = loader['W1']
|
| 370 |
+
self.neuralnet.B1 = loader['B1']
|
| 371 |
+
|
| 372 |
+
self.neuralnet.h1 = loader['h1']
|
| 373 |
+
self.neuralnet.h2 = loader['h2']
|
| 374 |
+
self.neuralnet.d = loader['d']
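`Save` and `Load` above round-trip a plain dict of weight arrays through pickle. A minimal sketch of that checkpoint pattern, assuming a hypothetical `DummyNet` stand-in with the same field names as the 'h1' branch (not the module's actual NN class):

```python
import io
import pickle

import numpy as np

# Hypothetical stand-in for the NN class; only the fields that the
# 'h1' Save/Load branch touches are modeled here.
class DummyNet:
    def __init__(self, d, n):
        self.d, self.n = d, n
        self.W = np.random.randn(n, d)
        self.U = np.random.randn(1, n)
        self.B1 = np.zeros(n)
        self.B2 = np.zeros(1)
        self.version = 'h1'

def save_checkpoint(net, fh):
    # Same dict layout as the 'h1' branch of Trainer.Save
    pickle.dump({'U': net.U, 'W': net.W, 'n': net.n, 'd': net.d,
                 'B1': net.B1, 'B2': net.B2, 'version': net.version}, fh)

def load_checkpoint(fh):
    # Mirror of Trainer.Load: read dims first, then rebuild and fill weights
    loader = pickle.load(fh)
    net = DummyNet(loader['d'], loader['n'])
    net.U, net.W = loader['U'], loader['W']
    net.B1, net.B2 = loader['B1'], loader['B2']
    return net

buf = io.BytesIO()
net = DummyNet(d=4, n=3)
save_checkpoint(net, buf)
buf.seek(0)
restored = load_checkpoint(buf)
assert np.allclose(restored.W, net.W) and restored.version == 'h1'
```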
|
| 375 |
+
|
| 376 |
+
def CalculateLoss_n_Grads(self, WScalarMat, min_st_adj_worst, max_st_adj_gold, loss_type = 0, min_marginalized_energy = None):
|
| 377 |
+
doBpp = True
|
| 378 |
+
|
| 379 |
+
# Calculate the energies
|
| 380 |
+
etg = np.sum(WScalarMat[max_st_adj_gold])
|
| 381 |
+
etq = np.sum(WScalarMat[min_st_adj_worst])
|
| 382 |
+
|
| 383 |
+
if loss_type == 0:
|
| 384 |
+
# Variable Hinge Loss - CHECKED
|
| 385 |
+
L = etg - min_marginalized_energy
|
| 386 |
+
if L > 0:
|
| 387 |
+
dLdOut = np.zeros_like(WScalarMat)
|
| 388 |
+
dLdOut[max_st_adj_gold&(~min_st_adj_worst)] = 1
|
| 389 |
+
dLdOut[(~max_st_adj_gold)&min_st_adj_worst] = -1
|
| 390 |
+
else:
|
| 391 |
+
doBpp = False
|
| 392 |
+
return (L, None, doBpp)
|
| 393 |
+
elif loss_type == 1:
|
| 394 |
+
# Log loss
|
| 395 |
+
a = etg - etq
|
| 396 |
+
b = np.exp(a)
|
| 397 |
+
L = np.log(1 + b)
|
| 398 |
+
|
| 399 |
+
dLdOut = np.zeros_like(WScalarMat)
|
| 400 |
+
dLdOut[max_st_adj_gold&(~min_st_adj_worst)] = 1
|
| 401 |
+
dLdOut[(~max_st_adj_gold)&min_st_adj_worst] = -1
|
| 402 |
+
|
| 403 |
+
dLdOut *= (b/(1 + b))
|
| 404 |
+
elif loss_type == 2:
|
| 405 |
+
# Square exponential loss
|
| 406 |
+
gamma = 1
|
| 407 |
+
b = np.exp(-etq)
|
| 408 |
+
|
| 409 |
+
L = etg**2 + gamma*b
|
| 410 |
+
|
| 411 |
+
dLdOut = np.zeros_like(WScalarMat)
|
| 412 |
+
dLdOut[max_st_adj_gold&(~min_st_adj_worst)] = 2*etg
|
| 413 |
+
dLdOut[(~max_st_adj_gold)&min_st_adj_worst] = -gamma*b
|
| 414 |
+
pass
|
| 415 |
+
return (L, dLdOut, doBpp)
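Stripped of the adjacency-mask bookkeeping, the three loss variants in `CalculateLoss_n_Grads` reduce to scalar functions of the gold-tree energy `etg` and the competitor energy `etq`. A sketch of the scalar forms (function names here are illustrative, not from the module):

```python
import numpy as np

def hinge_loss(etg, min_marginalized_energy):
    # loss_type == 0: positive part of the margin-augmented energy gap
    return max(0.0, etg - min_marginalized_energy)

def log_loss(etg, etq):
    # loss_type == 1: log(1 + exp(etg - etq)), the softplus of the gap
    return np.log1p(np.exp(etg - etq))

def square_exp_loss(etg, etq, gamma=1.0):
    # loss_type == 2: squared gold energy plus exponential competitor term
    return etg**2 + gamma * np.exp(-etq)

assert hinge_loss(2.0, 3.0) == 0.0   # gold already below the margin: no update
assert hinge_loss(5.0, 3.0) == 2.0
assert abs(log_loss(0.0, 0.0) - np.log(2.0)) < 1e-12
```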
|
| 416 |
+
def Test(self, sentenceObj, dcsObj, dsbz2_name, _dump = False, _outFile = None):
|
| 417 |
+
if _dump:
|
| 418 |
+
if _outFile is None:
|
| 419 |
+
raise Exception('_outFile must be provided when _dump is True')
|
| 420 |
+
if self.neuralnet.version == 'h1':
|
| 421 |
+
self.neuralnet.ForTesting()
|
| 422 |
+
neuralnet = self.neuralnet
|
| 423 |
+
minScore = np.inf
|
| 424 |
+
minMst = None
|
| 425 |
+
|
| 426 |
+
# dsbz2_name = sentenceObj.sent_id + '.ds.bz2'
|
| 427 |
+
(nodelist_correct, conflicts_Dict_correct, featVMat_correct, nodelist_to_correct_mapping,\
|
| 428 |
+
nodelist, conflicts_Dict, featVMat) = open_dsbz2(dsbz2_name)
|
| 429 |
+
|
| 430 |
+
# if len(nodelist) > 50:
|
| 431 |
+
# return None
|
| 432 |
+
|
| 433 |
+
if not self.neuralnet.outer_relu:
|
| 434 |
+
(WScalarMat, SigmoidGateOutput) = Get_W_Scalar_Matrix_from_FeatVect_Matrix(featVMat, nodelist, conflicts_Dict, neuralnet)
|
| 435 |
+
else:
|
| 436 |
+
WScalarMat = Get_W_Scalar_Matrix_from_FeatVect_Matrix(featVMat, nodelist, conflicts_Dict, neuralnet)
|
| 437 |
+
|
| 438 |
+
# print('NeuralNet Time: ', time.time() - startT)
|
| 439 |
+
# startT = time.time()
|
| 440 |
+
|
| 441 |
+
# Get all MST
|
| 442 |
+
for source in range(len(nodelist)):
|
| 443 |
+
(mst_nodes, mst_adj_graph, _) = MST(nodelist, WScalarMat, conflicts_Dict, source)
|
| 444 |
+
# print('.', end = '')
|
| 445 |
+
score = GetMSTWeight(mst_adj_graph, WScalarMat)
|
| 446 |
+
if(score < minScore):
|
| 447 |
+
minScore = score
|
| 448 |
+
minMst = mst_nodes
|
| 449 |
+
dcsLemmas = [[rom_slp(l) for l in arr]for arr in dcsObj.lemmas]
|
| 450 |
+
word_match = 0
|
| 451 |
+
lemma_match = 0
|
| 452 |
+
n_output_nodes = 0
|
| 453 |
+
|
| 454 |
+
if _dump:
|
| 455 |
+
predicted_lemmas = [sentenceObj.sent_id]
|
| 456 |
+
predicted_cngs = [sentenceObj.sent_id]
|
| 457 |
+
predicted_chunk_id = [sentenceObj.sent_id]
|
| 458 |
+
predicted_pos = [sentenceObj.sent_id]
|
| 459 |
+
predicted_id = [sentenceObj.sent_id]
|
| 460 |
+
|
| 461 |
+
for chunk_id, wdSplit in minMst.items():
|
| 462 |
+
for wd in wdSplit:
|
| 463 |
+
if _dump:
|
| 464 |
+
predicted_lemmas.append(wd.lemma)
|
| 465 |
+
predicted_cngs.append(wd.cng)
|
| 466 |
+
predicted_chunk_id.append(wd.chunk_id)
|
| 467 |
+
predicted_pos.append(wd.pos)
|
| 468 |
+
predicted_id.append(wd.id)
|
| 469 |
+
|
| 470 |
+
n_output_nodes += 1
|
| 471 |
+
# Match lemma
|
| 472 |
+
search_result = [i for i, j in enumerate(dcsLemmas[chunk_id]) if j == wd.lemma]
|
| 473 |
+
if len(search_result) > 0:
|
| 474 |
+
lemma_match += 1
|
| 475 |
+
# Match CNG
|
| 476 |
+
for i in search_result:
|
| 477 |
+
if(dcsObj.cng[chunk_id][i] == str(wd.cng)):
|
| 478 |
+
word_match += 1
|
| 479 |
+
# print(wd.lemma, wd.cng)
|
| 480 |
+
break
|
| 481 |
+
dcsLemmas = [l for arr in dcsObj.lemmas for l in arr]
|
| 482 |
+
|
| 483 |
+
if _dump:
|
| 484 |
+
with open(_outFile, 'a') as fh:
|
| 485 |
+
dcsv = csv.writer(fh)
|
| 486 |
+
dcsv.writerow(predicted_lemmas)
|
| 487 |
+
dcsv.writerow(predicted_cngs)
|
| 488 |
+
dcsv.writerow(predicted_chunk_id)
|
| 489 |
+
dcsv.writerow(predicted_pos)
|
| 490 |
+
dcsv.writerow(predicted_id)
|
| 491 |
+
dcsv.writerow([sentenceObj.sent_id, word_match, lemma_match, len(dcsLemmas), n_output_nodes])
|
| 492 |
+
|
| 493 |
+
# print('All MST Time: ', time.time() - startT)
|
| 494 |
+
# print('Node Count: ', len(nodelist))
|
| 495 |
+
# print('\nFull Match: {}, Partial Match: {}, OutOf {}, NodeCount: {}, '.\
|
| 496 |
+
# format(word_match, lemma_match, len(dcsLemmas), len(nodelist)))
|
| 497 |
+
return (word_match, lemma_match, len(dcsLemmas), n_output_nodes)
|
| 498 |
+
|
| 499 |
+
def Train(self, sentenceObj, dcsObj, bz2_input_folder, _debug = True):
|
| 500 |
+
self.neuralnet.ForTraining()
|
| 501 |
+
self.neuralnet.new_dropout() # renew dropout setting
|
| 502 |
+
# Hyperparameter for hinge loss: m
|
| 503 |
+
m_hinge_param = 14
|
| 504 |
+
|
| 505 |
+
dsbz2_name = sentenceObj.sent_id + '.ds.bz2'
|
| 506 |
+
(nodelist_correct, conflicts_Dict_correct, featVMat_correct, nodelist_to_correct_mapping,\
|
| 507 |
+
nodelist, conflicts_Dict, featVMat) = open_dsbz2(bz2_input_folder + dsbz2_name)
|
| 508 |
+
# Train for large graphs separately
|
| 509 |
+
# if len(nodelist) < 40:
|
| 510 |
+
# return
|
| 511 |
+
|
| 512 |
+
""" FORM MAXIMUM(ENERGY) SPANNING TREE OF THE GOLDEN GRAPH : WORST GOLD STRUCTURE """
|
| 513 |
+
WScalarMat_correct = Get_W_Scalar_Matrix_from_FeatVect_Matrix(featVMat_correct, nodelist_correct,\
|
| 514 |
+
conflicts_Dict_correct, self.neuralnet)
|
| 515 |
+
source = 0
|
| 516 |
+
""" Find the max spanning tree : negative Weight matrix passed """
|
| 517 |
+
# (min_st_gold_ndict, min_st_adj_gold_small, _) =\
|
| 518 |
+
# MST(nodelist_correct, -WScalarMat_correct, conflicts_Dict_correct, source)
|
| 519 |
+
(min_st_gold_ndict, min_st_adj_gold_small, _) =\
|
| 520 |
+
MST(nodelist_correct, WScalarMat_correct, conflicts_Dict_correct, source)
|
| 521 |
+
energy_gold_max_ST = np.sum(WScalarMat_correct[min_st_adj_gold_small])
|
| 522 |
+
|
| 523 |
+
""" Convert correct spanning tree graph adj matrix to full marix dimensions """
|
| 524 |
+
""" Create full-size adjacency matrix for correct_mst_small """
|
| 525 |
+
nodelen = len(nodelist)
|
| 526 |
+
min_st_adj_gold = np.zeros((nodelen, nodelen), dtype=bool) # T_STAR
|
| 527 |
+
for i in range(min_st_adj_gold_small.shape[0]):
|
| 528 |
+
for j in range(min_st_adj_gold_small.shape[1]):
|
| 529 |
+
min_st_adj_gold[nodelist_to_correct_mapping[i], nodelist_to_correct_mapping[j]] =\
|
| 530 |
+
min_st_adj_gold_small[i, j]
|
| 531 |
+
|
| 532 |
+
""" Delta(Margin) Function : MASK FOR WHICH NODES IN NODELIST BELONG TO DCS """
|
| 533 |
+
gold_nodes_mask = np.array([False]*len(nodelist))
|
| 534 |
+
gold_nodes_mask[list(nodelist_to_correct_mapping.values())] = True
|
| 535 |
+
margin_f = lambda nodes_mask: np.sum(nodes_mask&(~gold_nodes_mask))**2
|
| 536 |
+
|
| 537 |
+
""" FOR ALL POSSIBLE MST FROM THE COMPLETE GRAPH """
|
| 538 |
+
WScalarMat = Get_W_Scalar_Matrix_from_FeatVect_Matrix(featVMat, nodelist, conflicts_Dict, self.neuralnet)
|
| 539 |
+
|
| 540 |
+
""" For each node - Find MST with that source"""
|
| 541 |
+
min_STx = None # Min Energy spanning tree with worst margin with gold_STx
|
| 542 |
+
min_marginalized_energy = np.inf
|
| 543 |
+
|
| 544 |
+
# Generate random set of nodes from which mSTs are to be considered
|
| 545 |
+
n_nodes = len(nodelist)
|
| 546 |
+
selection_prob = 0.4
|
| 547 |
+
select_flag = np.random.rand(n_nodes) < selection_prob
|
| 548 |
+
# Fix if all zeros
|
| 549 |
+
if np.sum(select_flag) == 0:
|
| 550 |
+
select_flag[np.random.randint(n_nodes)] = 1
|
| 551 |
+
|
| 552 |
+
best_node_diff = np.Inf
|
| 553 |
+
best_energy = np.Inf
|
| 554 |
+
for source in range(len(nodelist)):
|
| 555 |
+
(mst_nodes, mst_adj_graph, mst_nodes_bool) = MST(nodelist, WScalarMat, conflicts_Dict, source) # T_X
|
| 556 |
+
# Calculate energy of spanning tree
|
| 557 |
+
en_st = np.sum(WScalarMat[mst_adj_graph])
|
| 558 |
+
|
| 559 |
+
# Pick up the node_diff with lowest energy
|
| 560 |
+
delta_st = margin_f(mst_nodes_bool)
|
| 561 |
+
|
| 562 |
+
if _debug:
|
| 563 |
+
if best_energy > en_st:
|
| 564 |
+
best_node_diff = delta_st
|
| 565 |
+
best_energy = en_st
|
| 566 |
+
|
| 567 |
+
# Minimum marginalized energy calculation
|
| 568 |
+
marginalized_en = en_st - delta_st
|
| 569 |
+
# Minimum marginalized spanning tree : Randomization applied
|
| 570 |
+
# if marginalized_en < min_marginalized_energy and select_flag[source]:
|
| 571 |
+
if marginalized_en < min_marginalized_energy:
|
| 572 |
+
min_marginalized_energy = marginalized_en
|
| 573 |
+
min_STx = mst_adj_graph
|
| 574 |
+
# Energy diff should all be negative
|
| 575 |
+
if _debug:
|
| 576 |
+
print('Source: [{}], Node_Diff:{}, Max_Gold_En: {:.3f}, Energy: {:.3f}'.\
|
| 577 |
+
format(source, np.sum((~gold_nodes_mask)&mst_nodes_bool), energy_gold_max_ST, np.sum(WScalarMat[mst_adj_graph])))
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
if _debug:
|
| 581 |
+
print('Best Node diff: {} with EN: {}'.format(np.sqrt(best_node_diff), best_energy))
|
| 582 |
+
""" Gradient Descent """
|
| 583 |
+
# LOSS TYPES -> hinge(0), log-loss(1), square-exponential(2)
|
| 584 |
+
Total_Loss, dLdOut, doBpp = self.CalculateLoss_n_Grads(WScalarMat, min_STx, min_st_adj_gold,\
|
| 585 |
+
loss_type = 0, min_marginalized_energy = min_marginalized_energy)
|
| 586 |
+
if doBpp:
|
| 587 |
+
if _debug:
|
| 588 |
+
print('{}. '.format(sentenceObj.sent_id), end = '')
|
| 589 |
+
self.neuralnet.Back_Prop(dLdOut, len(nodelist), featVMat, _debug)
|
| 590 |
+
else:
|
| 591 |
+
trainingStatus[sentenceObj.sent_id] = True
|
| 592 |
+
if _debug:
|
| 593 |
+
print("\nFileKey: %s, Loss: %6.3f" % (sentenceObj.sent_id, Total_Loss))
|
| 594 |
+
|
| 595 |
+
TrainFiles = None
|
| 596 |
+
trainer = None
trainingStatus = defaultdict(lambda: False) # written by Trainer.Train when back-prop is skipped
|
| 597 |
+
p_name = ''
|
| 598 |
+
odir = ''
|
| 599 |
+
def InitModule():
|
| 600 |
+
global trainer
|
| 601 |
+
trainer = Trainer()
|
| 602 |
+
|
| 603 |
+
def register_nnet(nnet, bz2_input_folder):
|
| 604 |
+
if not os.path.isdir(odir):
|
| 605 |
+
os.mkdir(odir)
|
| 606 |
+
if not os.path.isfile('outputs/nnet_LOGS.csv'):
|
| 607 |
+
with open('outputs/nnet_LOGS.csv', 'a') as fh:
|
| 608 |
+
csv_r = csv.writer(fh)
|
| 609 |
+
csv_r.writerow(['odir', 'p_name', 'hidden_layer_size', '_edge_vector_dim'])
|
| 610 |
+
with open('outputs/nnet_LOGS.csv', 'a') as fh:
|
| 611 |
+
csv_r = csv.writer(fh)
|
| 612 |
+
if nnet.version == 'h1':
|
| 613 |
+
csv_r.writerow([odir, p_name, nnet.n, nnet.d, bz2_input_folder])
|
| 614 |
+
elif nnet.version == 'h2':
|
| 615 |
+
csv_r.writerow([odir, p_name, nnet.h1, nnet.h2, nnet.d, bz2_input_folder])
|
| 616 |
+
|
| 617 |
+
"""
|
| 618 |
+
################################################################################################
|
| 619 |
+
################################################################################################
|
| 620 |
+
################################################################################################
|
| 621 |
+
"""
|
| 622 |
+
def main():
|
| 623 |
+
global TrainFiles, p_name, odir
|
| 624 |
+
"""
|
| 625 |
+
################################################################################################
|
| 626 |
+
############################## GET A FILENAME TO SAVE WEIGHTS ################################
|
| 627 |
+
################################################################################################
|
| 628 |
+
"""
|
| 629 |
+
st = str(int((time.time() * 1e6) % 1e13))
|
| 630 |
+
log_name = 'logs/train_nnet_t{}.out'.format(st)
|
| 631 |
+
odir = 'outputs/train_t{}'.format(st)
|
| 632 |
+
p_name = 'outputs/train_t{}/nnet.p'.format(st)
|
| 633 |
+
print('Neural net will be saved here: ', p_name)
|
| 634 |
+
|
| 635 |
+
# Create Training File List
|
| 636 |
+
excluded_files = []
|
| 637 |
+
with open('inputs/Baseline4_advSample.csv', 'r') as f_handle:
|
| 638 |
+
opener = csv.reader(f_handle)
|
| 639 |
+
for line in opener:
|
| 640 |
+
excluded_files.append(line[1].replace('.p', '.ds.bz2'))
|
| 641 |
+
|
| 642 |
+
# Load Simultaneous files
|
| 643 |
+
print('Loading Large Files')
|
| 644 |
+
loaded_SKT = pickle.load(open('../Simultaneous_CompatSKT_10K.p', 'rb'), encoding=u'utf-8')
|
| 645 |
+
loaded_DCS = pickle.load(open('../Simultaneous_DCS_10K.p', 'rb'), encoding=u'utf-8')
|
| 646 |
+
|
| 647 |
+
# loaded_SKT = pickle.load(open('../Simultaneous_CompatSKT.p', 'rb'), encoding=u'utf-8')
|
| 648 |
+
# loaded_DCS = pickle.load(open('../Simultaneous_DCS.p', 'rb'), encoding=u'utf-8')
|
| 649 |
+
|
| 650 |
+
bz2_input_folder = '../NewData/skt_dcs_DS.bz2_1L_bigram_10K/'
|
| 651 |
+
# bz2_input_folder = '/home/rs/15CS91R05/vishnu/Data/skt_dcs_DS.bz2_compat_10k_check_again/'
|
| 652 |
+
all_files = []
|
| 653 |
+
skipped = 0
|
| 654 |
+
for f in os.listdir(bz2_input_folder):
|
| 655 |
+
if '.ds.bz2' in f:
|
| 656 |
+
if f in excluded_files:
|
| 657 |
+
skipped += 1
|
| 658 |
+
continue
|
| 659 |
+
if f.replace('.ds.bz2', '.p2') not in loaded_DCS:
|
| 660 |
+
print('Could not find ', f)
|
| 661 |
+
continue
|
| 662 |
+
all_files.append(f)
|
| 663 |
+
|
| 664 |
+
print(skipped, 'files will not be used for training')
|
| 665 |
+
print('Size of training set:', len(all_files))
|
| 666 |
+
|
| 667 |
+
TrainFiles = all_files
|
| 668 |
+
|
| 669 |
+
InitModule()
|
| 670 |
+
trainingStatus = defaultdict(lambda: bool(False))
|
| 671 |
+
# train = train_generator(loaded_SKT, loaded_DCS, bz2_input_folder, n_trainset = -1, filePerBatch = 10, iterationPerBatch = 5, _debug=False, superEpochs = 5)
|
| 672 |
+
train = train_generator(loaded_SKT, loaded_DCS, bz2_input_folder, n_trainset = -1, filePerBatch = 20, iterationPerBatch = 3, _debug=False, superEpochs = 2)
|
| 673 |
+
|
| 674 |
+
# Complete Training
|
| 675 |
+
train.__next__()
|
| 676 |
+
|
| 677 |
+
print('Training Complete')
|
| 678 |
+
|
| 679 |
+
if __name__ == '__main__':
|
| 680 |
+
main()
|
dir/__pycache__/TestPool_Unit_clique.cpython-36.pyc
ADDED
|
Binary file (1.45 kB).
|
|
|
dir/__pycache__/heap_n_clique.cpython-36.pyc
ADDED
|
Binary file (6.34 kB).
|
|
|
dir/bronclique.py
ADDED
|
@@ -0,0 +1,362 @@
|
| 1 |
+
import numpy as np
|
| 2 |
+
from word_definite import *
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
def Parent(i):
|
| 6 |
+
return max(0, math.floor((i - 1)/2))
|
| 7 |
+
|
| 8 |
+
def Left(i):
|
| 9 |
+
return 2*i + 1
|
| 10 |
+
|
| 11 |
+
def Right(i):
|
| 12 |
+
return 2*(i + 1)
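The `Parent`/`Left`/`Right` helpers above encode the standard 0-indexed array layout of a binary heap: children of node `i` live at `2i+1` and `2i+2`, and the parent at `⌊(i-1)/2⌋`. A quick self-check of that arithmetic (names lower-cased to keep this sketch independent of the module):

```python
# Standard 0-indexed binary-heap navigation (same formulas as above).
def parent(i):
    return max(0, (i - 1) // 2)

def left(i):
    return 2 * i + 1

def right(i):
    return 2 * (i + 1)

# parent() inverts both child moves for every node.
for i in range(1, 100):
    assert parent(left(i)) == i
    assert parent(right(i)) == i
assert parent(0) == 0  # the root is its own parent by this convention
```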
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
################################################################################################
|
| 17 |
+
######################## NOMINAL NODE CLASS REQUIRED FOR USING ################################
|
| 18 |
+
######################### WITH THE HEAP DATA STRUCTURE #######################################
|
| 19 |
+
################################################################################################
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
class Node:
|
| 23 |
+
def __init__(self, id, dist):
|
| 24 |
+
self.dist = dist
|
| 25 |
+
self.id = id
|
| 26 |
+
self.isConflicted = False
|
| 27 |
+
self.src = -1
|
| 28 |
+
|
| 29 |
+
"""
|
| 30 |
+
################################################################################################
|
| 31 |
+
############################ IMPLEMENTATION OF HEAP ##########################################
|
| 32 |
+
################################################################################################
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
class Heap:
|
| 36 |
+
# It's a minHeap
|
| 37 |
+
# Nodes are of type Word_definite
|
| 38 |
+
def __init__(self, nodeList):
|
| 39 |
+
self.nodeList = [n for n in nodeList]
|
| 40 |
+
self.len = len(nodeList)
|
| 41 |
+
self.idLocator = {}
|
| 42 |
+
for i in range(self.len):
|
| 43 |
+
self.idLocator[nodeList[i].id] = i
|
| 44 |
+
self.Build()
|
| 45 |
+
|
| 46 |
+
def Exchange(self, i, j):
|
| 47 |
+
t = self.nodeList[i]
|
| 48 |
+
self.nodeList[i] = self.nodeList[j]
|
| 49 |
+
self.nodeList[j] = t
|
| 50 |
+
self.idLocator[self.nodeList[i].id] = i
|
| 51 |
+
self.idLocator[self.nodeList[j].id] = j
|
| 52 |
+
|
| 53 |
+
def Decrease_Key(self, node, newDist, src):
|
| 54 |
+
if node.isConflicted:
|
| 55 |
+
return
|
| 56 |
+
i = self.idLocator[node.id]
|
| 57 |
+
if newDist > node.dist:
|
| 58 |
+
# relaxation not possible
|
| 59 |
+
return
|
| 60 |
+
else:
|
| 61 |
+
node.dist = newDist
|
| 62 |
+
node.src = src
|
| 63 |
+
parent = Parent(i)
|
| 64 |
+
while ((i > 0) and (self.nodeList[parent].dist > self.nodeList[i].dist)):
|
| 65 |
+
self.Exchange(i, parent)
|
| 66 |
+
i = parent
|
| 67 |
+
parent = Parent(i)
|
| 68 |
+
|
| 69 |
+
def Pop(self):
|
| 70 |
+
if(self.len == 0):
|
| 71 |
+
return None
|
| 72 |
+
if(self.nodeList[0].isConflicted):
|
| 73 |
+
# print("Pop has seen conflict!!!")
|
| 74 |
+
return None
|
| 75 |
+
|
| 76 |
+
# Remove the entry from the top of the heap
|
| 77 |
+
nMin = self.nodeList[0]
|
| 78 |
+
self.idLocator[self.nodeList[0].id] = -1
|
| 79 |
+
|
| 80 |
+
# Put the last node on top of heap and heapify
|
| 81 |
+
self.nodeList[0] = self.nodeList[self.len - 1]
|
| 82 |
+
self.idLocator[self.nodeList[0].id] = 0
|
| 83 |
+
self.len -= 1
|
| 84 |
+
self.Min_Heapify(0)
|
| 85 |
+
return nMin
|
| 86 |
+
|
| 87 |
+
def Min_Heapify(self, i):
|
| 88 |
+
nMin = self.nodeList[i]
|
| 89 |
+
li = Left(i)
|
| 90 |
+
if(li < self.len):
|
| 91 |
+
if(self.nodeList[li].dist < nMin.dist):
|
| 92 |
+
nMin = self.nodeList[li]
|
| 93 |
+
min_i = li
|
| 94 |
+
ri = Right(i)
|
| 95 |
+
if(ri < self.len):
|
| 96 |
+
if(self.nodeList[ri].dist < nMin.dist):
|
| 97 |
+
nMin = self.nodeList[ri]
|
| 98 |
+
min_i = ri
|
| 99 |
+
if(nMin.id != self.nodeList[i].id):
|
| 100 |
+
self.Exchange(i, min_i)
|
| 101 |
+
self.Min_Heapify(min_i)
|
| 102 |
+
|
| 103 |
+
def Delete(self, node):
|
| 104 |
+
i = self.idLocator[node.id]
|
| 105 |
+
self.nodeList[i].isConflicted = True
|
| 106 |
+
self.nodeList[i].dist = np.inf
|
| 107 |
+
self.Min_Heapify(i)
|
| 108 |
+
|
| 109 |
+
def Build(self):
|
| 110 |
+
self.len = len(self.nodeList)
|
| 111 |
+
for i in range(int(Parent(self.len - 1)) + 1):
|
| 112 |
+
self.Min_Heapify(i)
|
| 113 |
+
|
| 114 |
+
def Print(self):
|
| 115 |
+
i = 0
|
| 116 |
+
level = 1
|
| 117 |
+
ilimit = 0
|
| 118 |
+
while(i < self.len):
|
| 119 |
+
print('N(%d, %2.1f)' % (self.nodeList[i].id, self.nodeList[i].dist), end = ' ')
|
| 120 |
+
i += 1
|
| 121 |
+
if(i > ilimit):
|
| 122 |
+
print('\n')
|
| 123 |
+
level *= 2
|
| 124 |
+
ilimit += level
|
| 125 |
+
|
| 126 |
+
"""
|
| 127 |
+
################################################################################################
|
| 128 |
+
###################### IMPLEMENTATION OF PRIM'S ALGO FOR FINDING MST ##########################
|
| 129 |
+
############################# USES HEAP DEFINED ABOVE ########################################
|
| 130 |
+
################################################################################################
|
| 131 |
+
"""
|
| 132 |
+
def MST(nodelist, WScalarMat, conflicts_Dict, source):
|
| 133 |
+
# NOTE: this runs Prim's algorithm on a directed weight matrix, so the
# result is a spanning arborescence approximation rather than a true MST.
|
| 135 |
+
mst_adj_graph = np.zeros(WScalarMat.shape, dtype=bool)
|
| 136 |
+
# print(len(nodelist))
|
| 137 |
+
# Reset nodes and put ids
|
| 138 |
+
for id in range(len(nodelist)):
|
| 139 |
+
nodelist[id].id = id
|
| 140 |
+
nodelist[id].dist = np.inf
|
| 141 |
+
nodelist[id].isConflicted = False
|
| 142 |
+
nodelist[id].src = -1
|
| 143 |
+
|
| 144 |
+
# Initialize Graph and min-Heap
|
| 145 |
+
nodelist[source].dist = 0
|
| 146 |
+
for neighbour in range(len(nodelist)):
|
| 147 |
+
if neighbour != source:
|
| 148 |
+
nodelist[neighbour].dist = WScalarMat[source][neighbour]
|
| 149 |
+
nodelist[neighbour].src = source
|
| 150 |
+
h = Heap(nodelist)
|
| 151 |
+
|
| 152 |
+
mst_nodes = defaultdict(lambda: [])
|
| 153 |
+
mst_nodes_bool = np.array([False]*len(nodelist))
|
| 154 |
+
# Run MST only until first conflicting node is seen
|
| 155 |
+
# Conflicting node will have np.inf as dist
|
| 156 |
+
while True:
|
| 157 |
+
nextNode = h.Pop()
|
| 158 |
+
if nextNode is None:
|
| 159 |
+
break
|
| 160 |
+
# print("next-id:" + str(nextNode.id))
# print('picked by ' + str(nodelist[nextNode.id].dist))
|
| 163 |
+
# print(nextNode.src, nextNode.id, nextNode)
|
| 164 |
+
mst_nodes_bool[nextNode.id] = True
|
| 165 |
+
mst_nodes[nextNode.chunk_id].append(nextNode)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
if nextNode.src != -1:
|
| 169 |
+
mst_adj_graph[nextNode.src, nextNode.id] = True
|
| 170 |
+
# mst_adj_graph[nextNode.id, nextNode.src] = True
|
| 171 |
+
nid = nextNode.id
|
| 172 |
+
for conId in conflicts_Dict[nid]:
|
| 173 |
+
h.Delete(nodelist[conId])
|
| 174 |
+
for neighbour in range(len(nodelist)):
|
| 175 |
+
if neighbour != nextNode.id:
|
| 176 |
+
# print(WScalarMat[nextNode.id][neighbour])
# print(nodelist[neighbour].dist)
|
| 178 |
+
h.Decrease_Key(nodelist[neighbour], WScalarMat[nextNode.id][neighbour], nextNode.id)
|
| 179 |
+
|
| 180 |
+
# print(mst_nodes_bool)
# print('#'*30)
|
| 183 |
+
mst_nodes = dict(mst_nodes)
|
| 184 |
+
|
| 185 |
+
return (mst_nodes, mst_adj_graph, mst_nodes_bool)
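For contrast with the directed variant above, a textbook Prim's algorithm on a symmetric weight matrix can be sketched with `heapq` (conflict handling omitted; `prim_mst` is a hypothetical helper, not part of this module):

```python
import heapq

import numpy as np

def prim_mst(W):
    """Minimum spanning tree of a symmetric weight matrix W.
    Returns the list of tree edges and the total weight."""
    n = W.shape[0]
    in_tree = [False] * n
    edges, total = [], 0.0
    pq = [(0.0, -1, 0)]  # (weight, from, to) candidates, starting at node 0
    while pq:
        w, u, v = heapq.heappop(pq)
        if in_tree[v]:
            continue  # stale entry: v was already attached via a lighter edge
        in_tree[v] = True
        if u != -1:
            edges.append((u, v))
            total += w
        for nxt in range(n):
            if not in_tree[nxt]:
                heapq.heappush(pq, (W[v, nxt], v, nxt))
    return edges, total

W = np.array([[0., 1., 4.],
              [1., 0., 2.],
              [4., 2., 0.]])
edges, total = prim_mst(W)
assert total == 3.0  # tree uses edges (0,1) and (1,2)
```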
|
| 186 |
+
|
| 187 |
+
def clique(nodelist, WScalarMat, conflicts_Dict, source):
|
| 188 |
+
# NOTE: greedy clique construction over the conflict-free graph;
# despite the variable names, this does not build a spanning tree.
|
| 190 |
+
mst_adj_graph = np.zeros(WScalarMat.shape, dtype=bool)
|
| 191 |
+
# print(len(nodelist))
|
| 192 |
+
# Reset nodes and put ids
|
| 193 |
+
# print('node-ids')
|
| 194 |
+
for id in range(len(nodelist)):
|
| 195 |
+
# print(id)
|
| 196 |
+
nodelist[id].id = id
|
| 197 |
+
nodelist[id].dist = np.inf
|
| 198 |
+
nodelist[id].isConflicted = False
|
| 199 |
+
nodelist[id].src = -1
|
| 200 |
+
# print('*'*40)
|
| 201 |
+
# Initialize Graph and min-Heap
|
| 202 |
+
nodelist[source].dist = 0
|
| 203 |
+
|
| 204 |
+
nodeset=set()
|
| 205 |
+
for neighbour in range(len(nodelist)):
|
| 206 |
+
if neighbour != source:
|
| 207 |
+
nodelist[neighbour].dist = WScalarMat[source][neighbour]
|
| 208 |
+
nodelist[neighbour].src = source
|
| 209 |
+
# nodeset.add((nodelist[neighbour].dist,neighbour))
|
| 210 |
+
|
| 211 |
+
# nodeset = sorted(nodeset)
|
| 212 |
+
nodeset.add((0,source))
|
| 213 |
+
nodesadded=[]
|
| 214 |
+
nodesavailable = np.zeros(len(nodelist), dtype=int) # 0 if available, 1 if not available
|
| 215 |
+
|
| 216 |
+
mst_nodes = defaultdict(lambda: [])
|
| 217 |
+
mst_nodes_bool = np.array([False]*len(nodelist))
|
| 218 |
+
# Run MST only until first conflicting node is seen
|
| 219 |
+
# Conflicting node will have np.inf as dist
|
| 220 |
+
|
| 221 |
+
it=0
|
| 222 |
+
nextNode=-1
|
| 223 |
+
while True:
|
| 224 |
+
# print(nodeset)
|
| 225 |
+
it+=1
|
| 226 |
+
# print(it)
|
| 227 |
+
if(it>1000):
|
| 228 |
+
break
|
| 229 |
+
if(len(nodeset)==0):
|
| 230 |
+
break
|
| 231 |
+
# print('before nn assign: ')
|
| 232 |
+
# print(nextNode)
|
| 233 |
+
nextNode = next(iter(nodeset))
|
| 234 |
+
# print("after nn assign:")
|
| 235 |
+
# print(nextNode)
|
| 236 |
+
# print("Nextnode is :"+str(nextNode[1])+" Picked by :"+str(nextNode[0]))
|
| 237 |
+
nextNode=nodelist[nextNode[1]]
|
| 238 |
+
# print(type(nextNode))
|
| 239 |
+
# print(st_setr(nextNode.id)+"",)
|
| 240 |
+
|
| 241 |
+
# print(nextNode.id)
|
| 242 |
+
nodesavailable[nextNode.id]=1
|
| 243 |
+
# nodesavailable=1
|
| 244 |
+
if nextNode is None:
|
| 245 |
+
break
|
| 246 |
+
# print(nextNode.src, nextNode.id, nextNode)
|
| 247 |
+
mst_nodes_bool[nextNode.id] = True
|
| 248 |
+
mst_nodes[nextNode.chunk_id].append(nextNode)
|
| 249 |
+
|
| 250 |
+
nodeset = set()
|
| 251 |
+
|
| 252 |
+
if nextNode.src != -1:
|
| 253 |
+
mst_adj_graph[nextNode.src, nextNode.id] = True
|
| 254 |
+
# mst_adj_graph[nextNode.id, nextNode.src] = True
|
| 255 |
+
|
| 256 |
+
nid = nextNode.id
|
| 257 |
+
nodesadded.append(nid)
|
| 258 |
+
for conId in conflicts_Dict[nid]:
|
| 259 |
+
# h.Delete(nodelist[conId])
|
| 260 |
+
nodesavailable[conId]=1
|
| 261 |
+
# print('here')
|
| 262 |
+
|
| 263 |
+
for neighbour in range(len(nodelist)):
|
| 264 |
+
# print(type(nodesavailable))
|
| 265 |
+
# print(type(nodesavailable[0]))
|
| 266 |
+
if(nodesavailable[neighbour]==1):
|
| 267 |
+
continue
|
| 268 |
+
if neighbour != nextNode.id:
|
| 269 |
+
# h.Decrease_Key(nodelist[neighbour], WScalarMat[nextNode.id][neighbour], nextNode.id)
|
| 270 |
+
edgewt=0
|
| 271 |
+
# print(nodesadded)
|
| 272 |
+
for nodepresent in nodesadded:
|
| 273 |
+
edgewt+=WScalarMat[nodepresent][neighbour]
|
| 274 |
+
# print('adding '+str(neighbour))
|
| 275 |
+
nodeset.add((edgewt,neighbour))
|
| 276 |
+
# print(nodeset)
|
| 277 |
+
|
| 278 |
+
nodeset=sorted(nodeset)
|
| 279 |
+
# print(nodeset)
|
| 280 |
+
# print("#"*30)
|
| 281 |
+
# print(mst_nodes_bool)
|
| 282 |
+
# print('-'*20)
|
| 283 |
+
# print('#'*30)
|
| 284 |
+
mst_nodes = dict(mst_nodes)
|
| 285 |
+
if(it>1000):
|
| 286 |
+
print('Warning: clique construction hit the 1000-iteration limit')
|
| 287 |
+
for i in range(len(mst_nodes_bool)):
|
| 288 |
+
for j in range(len(mst_nodes_bool)):
|
| 289 |
+
if(i==j):
|
| 290 |
+
continue
|
| 291 |
+
if(mst_nodes_bool[i] and mst_nodes_bool[j]):
|
| 292 |
+
mst_adj_graph[i][j]=True
|
| 293 |
+
mst_adj_graph[j][i]=True
|
| 294 |
+
|
| 295 |
+
# print(mst_adj_graph)
|
| 296 |
+
# print("#")
|
| 297 |
+
return (mst_nodes, mst_adj_graph, mst_nodes_bool)
|
| 298 |
+
|
| 299 |
+
|
def bron(R, P, X, nodelist, conflicts_Dict, level):
    # Bron-Kerbosch-style recursion: R = current set, P = candidates, X = excluded.
    L = []
    if len(P) == 0 and len(X) == 0:
        L.append(R)
        return L
    Pit = P.copy()
    for v in Pit:
        R1 = R.copy()
        P1 = P.copy()
        X1 = X.copy()
        R1.add(v)
        # Remove v's conflicting neighbours (and v itself) from the candidate
        # and excluded sets before recursing
        for i in conflicts_Dict[v]:
            if i in P1:
                P1.remove(i)
            if i in X1:
                X1.remove(i)
        if v in P1:
            P1.remove(v)
        if v in X1:
            X1.remove(v)
        G = bron(R1, P1, X1, nodelist, conflicts_Dict, level + 1)
        if v in P:
            P.remove(v)
        X.add(v)
        for i in G:
            L.append(i)
    return L

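The `bron` routine above follows the Bron–Kerbosch recursion, but removes each chosen node's *conflicting* neighbours from P and X, so it enumerates maximal conflict-free node sets. A minimal self-contained sketch of that idea, using a toy conflict graph (all names here are hypothetical, not from the repo):

```python
def bron_kerbosch(R, P, X, conflicts):
    """Enumerate maximal sets with no internal conflicts.

    R: current set, P: candidate nodes, X: already-processed nodes.
    `conflicts` maps a node to the set of nodes it conflicts with.
    """
    if not P and not X:
        return [R]
    out = []
    for v in list(P):
        # Recurse with v added; drop v and its conflicting neighbours from P and X
        out += bron_kerbosch(R | {v},
                             P - conflicts[v] - {v},
                             X - conflicts[v] - {v},
                             conflicts)
        P = P - {v}
        X = X | {v}
    return out

# Toy conflict graph: nodes 0 and 1 conflict; node 2 conflicts with nobody.
conflicts = {0: {1}, 1: {0}, 2: set()}
maximal = bron_kerbosch(set(), {0, 1, 2}, set(), conflicts)
# The maximal conflict-free sets are {0, 2} and {1, 2}
```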
def RandomST_GoldOnly(nodelist, WScalarMat, conflicts_Dict, source):
    (mst_nodes, mst_adj_graph, mst_nodes_bool) = MST(nodelist, WScalarMat, conflicts_Dict, source)

    mst_adj_graph = np.zeros_like(mst_adj_graph)
    nodelen = len(nodelist)

    ## Random mst_adj_graph
    free_set = list(range(nodelen))
    full_set = list(range(nodelen))
    st_set = []
    start_node = np.random.randint(nodelen)
    st_set.append(start_node)
    free_set.remove(start_node)
    for x in range(nodelen - 1):
        a = st_set[np.random.randint(len(st_set))]
        b = free_set[np.random.randint(len(free_set))]
        if b not in st_set:
            st_set.append(b)
            free_set.remove(b)
            mst_adj_graph[a, b] = 1
            # mst_adj_graph[b, a] = 1  # Directed spanning tree

    return (mst_nodes, mst_adj_graph, mst_nodes_bool)

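The random spanning-tree construction above repeatedly attaches a random free node to a random node already in the tree, yielding a directed tree with exactly n − 1 edges. A self-contained sketch of the same pattern (function name and use of `numpy.random.default_rng` are illustrative choices, not the repo's API):

```python
import numpy as np

def random_spanning_tree(n, rng=None):
    """Return an n x n boolean adjacency matrix of a random directed spanning tree."""
    rng = rng or np.random.default_rng(0)
    adj = np.zeros((n, n), dtype=bool)
    tree = [int(rng.integers(n))]                 # start from a random root
    free = [i for i in range(n) if i != tree[0]]
    while free:
        a = tree[int(rng.integers(len(tree)))]       # random node already attached
        b = free.pop(int(rng.integers(len(free))))   # random node not yet attached
        tree.append(b)
        adj[a, b] = True                             # directed edge parent -> child
    return adj

adj = random_spanning_tree(5)
# A spanning tree over 5 nodes has exactly 4 edges
```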
def GetMSTWeight(mst_adj_graph, WScalarMat):
    # Boolean-mask indexing: total weight of the selected edges
    return np.sum(WScalarMat[mst_adj_graph])
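`GetMSTWeight` relies on NumPy boolean-mask indexing: indexing a weight matrix with a same-shaped boolean adjacency matrix selects exactly the weights of the chosen edges. A tiny illustration with made-up numbers:

```python
import numpy as np

# Weight matrix and a boolean adjacency mask selecting edges (0->1) and (1->2)
W = np.array([[0., 2., 3.],
              [0., 0., 5.],
              [0., 0., 0.]])
mask = np.array([[False, True, False],
                 [False, False, True],
                 [False, False, False]])

tree_weight = np.sum(W[mask])   # picks W[0, 1] and W[1, 2]
```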
dir/bucket_by_conflicting_nodes_0.py
ADDED
@@ -0,0 +1,85 @@
import pickle
import os
import bz2


def harmonic(P, R):
    """For calculation of F-score, since it is the harmonic mean of P and R."""
    return 2 * P * R / float(P + R)


# Test on a couple of files
base_path_csv = '/home/rs/15CS91R05/gaurav/myTryouts/init_results/prediction_csvs/'
base_path_bz2 = '/home/rs/15CS91R05/Bishal/NewData/skt_dcs_DS.bz2_1L_bigram_heldout_dev/'

pred_csvs = os.listdir(base_path_csv)

"""Task 5: See data from number of conflicts
Approach: Select a node from DCS, take the count of conflicting nodes using conflicts_Dict_correct
"""

# Open a bz2 file that contains both DCS & SKT info
def open_dsbz2(filename):
    with bz2.BZ2File(filename, 'r') as f:
        loader = pickle.load(f)

    conflicts_Dict_correct = loader['conflicts_Dict_correct']
    nodelist_to_correct_mapping = loader['nodelist_to_correct_mapping']
    nodelist_correct = loader['nodelist_correct']
    featVMat_correct = loader['featVMat_correct']
    featVMat = loader['featVMat']
    conflicts_Dict = loader['conflicts_Dict']
    nodelist = loader['nodelist']

    return (nodelist_correct, conflicts_Dict_correct, featVMat_correct, nodelist_to_correct_mapping,
            nodelist, conflicts_Dict, featVMat)


bucket_by_conflicting_nodes = {}
num_conflicting_nodes = set()
csv = open(base_path_csv + pred_csvs[0], 'r').readlines()

# Each prediction record spans 6 lines of the CSV
for line in range(0, len(csv), 6):
    head_line = csv[line].strip().split(',')
    fname = head_line[0]
    print("Bz2 File number", fname, line / 6)

    (nodelist_correct, conflicts_Dict_correct, featVMat_correct, nodelist_to_correct_mapping,
     nodelist, conflicts_Dict, featVMat) = open_dsbz2(base_path_bz2 + fname + '.ds.bz2')

    assert len(nodelist_correct) == len(conflicts_Dict_correct)
    for node in conflicts_Dict_correct:
        lemma = nodelist_correct[node].lemma
        conflicting_nodes_count = len(conflicts_Dict_correct[node])
        if conflicting_nodes_count not in num_conflicting_nodes:
            num_conflicting_nodes.add(conflicting_nodes_count)
            bucket_by_conflicting_nodes[conflicting_nodes_count] = {'lemmas': set(), 'precision': [0, 0], 'recall': [0, 0]}

        bucket_by_conflicting_nodes[conflicting_nodes_count]['lemmas'].add(lemma)

        data = csv[line + 5].strip().split(',')
        word_recall = float(data[1]) / float(data[3])
        lemma_recall = float(data[2]) / float(data[3])
        word_precision = float(data[1]) / float(data[4])
        lemma_precision = float(data[2]) / float(data[4])

        bucket_by_conflicting_nodes[conflicting_nodes_count]['recall'][0] += word_recall
        bucket_by_conflicting_nodes[conflicting_nodes_count]['recall'][1] += lemma_recall
        bucket_by_conflicting_nodes[conflicting_nodes_count]['precision'][0] += word_precision
        bucket_by_conflicting_nodes[conflicting_nodes_count]['precision'][1] += lemma_precision

# for conflicting_count in bucket_by_conflicting_nodes:
#     # Average P & R
#     bucket_by_conflicting_nodes[conflicting_count]['precision'][0] /= bucket_by_conflicting_nodes[conflicting_count]['num_lemmas']
#     bucket_by_conflicting_nodes[conflicting_count]['precision'][1] /= bucket_by_conflicting_nodes[conflicting_count]['num_lemmas']
#     bucket_by_conflicting_nodes[conflicting_count]['recall'][0] /= bucket_by_conflicting_nodes[conflicting_count]['num_lemmas']
#     bucket_by_conflicting_nodes[conflicting_count]['recall'][1] /= bucket_by_conflicting_nodes[conflicting_count]['num_lemmas']
#
#     # Find F-Score
#     wrd_fscore = harmonic(bucket_by_conflicting_nodes[conflicting_count]['precision'][0], bucket_by_conflicting_nodes[conflicting_count]['recall'][0])
#     lma_fscore = harmonic(bucket_by_conflicting_nodes[conflicting_count]['precision'][1], bucket_by_conflicting_nodes[conflicting_count]['recall'][1])
#     bucket_by_conflicting_nodes[conflicting_count]['fscore'] = [wrd_fscore, lma_fscore]

with open('final_task_gaurav/bucket_by_conflicting_nodes_0.p', 'wb') as f:
    pickle.dump(bucket_by_conflicting_nodes, f)
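The script accumulates per-bucket precision and recall sums, and the commented-out tail shows the intended final step: average them, then combine with `harmonic` into an F-score. A minimal sketch of that step with toy numbers (not project data; the `count` field is a stand-in for the per-bucket divisor):

```python
def harmonic(P, R):
    """F-score is the harmonic mean of precision P and recall R."""
    return 2 * P * R / float(P + R)

# Toy bucket: sums over 2 records, [word, lemma] pairs
bucket = {'precision': [1.8, 1.6], 'recall': [1.2, 1.4], 'count': 2}

word_p = bucket['precision'][0] / bucket['count']   # 0.9
word_r = bucket['recall'][0] / bucket['count']      # 0.6
word_f = harmonic(word_p, word_r)                   # 2*0.9*0.6/1.5 = 0.72
```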
dir/bucket_by_conflicting_nodes_1.py
ADDED
@@ -0,0 +1,72 @@
import pickle
import os
import bz2


def harmonic(P, R):
    """For calculation of F-score, since it is the harmonic mean of P and R."""
    return 2 * P * R / float(P + R)


# Test on a couple of files
base_path_csv = '/home/rs/15CS91R05/gaurav/myTryouts/init_results/prediction_csvs/'
base_path_bz2 = '/home/rs/15CS91R05/Bishal/NewData/skt_dcs_DS.bz2_1L_bigram_heldout_dev/'

pred_csvs = os.listdir(base_path_csv)

"""Task 5: See data from number of conflicts
Approach: Select a node from DCS, take the count of conflicting nodes using conflicts_Dict_correct
"""

# Open a bz2 file that contains both DCS & SKT info
def open_dsbz2(filename):
    with bz2.BZ2File(filename, 'r') as f:
        loader = pickle.load(f)

    conflicts_Dict_correct = loader['conflicts_Dict_correct']
    nodelist_to_correct_mapping = loader['nodelist_to_correct_mapping']
    nodelist_correct = loader['nodelist_correct']
    featVMat_correct = loader['featVMat_correct']
    featVMat = loader['featVMat']
    conflicts_Dict = loader['conflicts_Dict']
    nodelist = loader['nodelist']

    return (nodelist_correct, conflicts_Dict_correct, featVMat_correct, nodelist_to_correct_mapping,
            nodelist, conflicts_Dict, featVMat)


bucket_by_conflicting_nodes = {}
num_conflicting_nodes = set()
csv = open(base_path_csv + pred_csvs[1], 'r').readlines()

# Each prediction record spans 6 lines of the CSV
for line in range(0, len(csv), 6):
    head_line = csv[line].strip().split(',')
    fname = head_line[0]
    print("Bz2 File number", fname, line / 6)

    (nodelist_correct, conflicts_Dict_correct, featVMat_correct, nodelist_to_correct_mapping,
     nodelist, conflicts_Dict, featVMat) = open_dsbz2(base_path_bz2 + fname + '.ds.bz2')

    assert len(nodelist_correct) == len(conflicts_Dict_correct)
    for node in conflicts_Dict_correct:
        lemma = nodelist_correct[node].lemma
        conflicting_nodes_count = len(conflicts_Dict_correct[node])
        if conflicting_nodes_count not in num_conflicting_nodes:
            num_conflicting_nodes.add(conflicting_nodes_count)
            bucket_by_conflicting_nodes[conflicting_nodes_count] = {'lemmas': set(), 'precision': [0, 0], 'recall': [0, 0]}

        bucket_by_conflicting_nodes[conflicting_nodes_count]['lemmas'].add(lemma)

        data = csv[line + 5].strip().split(',')
        word_recall = float(data[1]) / float(data[3])
        lemma_recall = float(data[2]) / float(data[3])
        word_precision = float(data[1]) / float(data[4])
        lemma_precision = float(data[2]) / float(data[4])

        bucket_by_conflicting_nodes[conflicting_nodes_count]['recall'][0] += word_recall
        bucket_by_conflicting_nodes[conflicting_nodes_count]['recall'][1] += lemma_recall
        bucket_by_conflicting_nodes[conflicting_nodes_count]['precision'][0] += word_precision
        bucket_by_conflicting_nodes[conflicting_nodes_count]['precision'][1] += lemma_precision

with open('final_task_gaurav/bucket_by_conflicting_nodes_1.p', 'wb') as f:
    pickle.dump(bucket_by_conflicting_nodes, f)
dir/bucket_by_conflicting_nodes_2.py
ADDED
@@ -0,0 +1,73 @@
import pickle
import os
import bz2


def harmonic(P, R):
    """For calculation of F-score, since it is the harmonic mean of P and R."""
    return 2 * P * R / float(P + R)


# Test on a couple of files
base_path_csv = '/home/rs/15CS91R05/gaurav/myTryouts/init_results/prediction_csvs/'
base_path_bz2 = '/home/rs/15CS91R05/Bishal/NewData/skt_dcs_DS.bz2_1L_bigram_heldout_dev/'

pred_csvs = os.listdir(base_path_csv)

"""Task 5: See data from number of conflicts
Approach: Select a node from DCS, take the count of conflicting nodes using conflicts_Dict_correct
"""

# Open a bz2 file that contains both DCS & SKT info
def open_dsbz2(filename):
    with bz2.BZ2File(filename, 'r') as f:
        loader = pickle.load(f)

    conflicts_Dict_correct = loader['conflicts_Dict_correct']
    nodelist_to_correct_mapping = loader['nodelist_to_correct_mapping']
    nodelist_correct = loader['nodelist_correct']
    featVMat_correct = loader['featVMat_correct']
    featVMat = loader['featVMat']
    conflicts_Dict = loader['conflicts_Dict']
    nodelist = loader['nodelist']

    return (nodelist_correct, conflicts_Dict_correct, featVMat_correct, nodelist_to_correct_mapping,
            nodelist, conflicts_Dict, featVMat)


bucket_by_conflicting_nodes = {}
num_conflicting_nodes = set()
csv = open(base_path_csv + pred_csvs[2], 'r').readlines()

# Each prediction record spans 6 lines of the CSV
for line in range(0, len(csv), 6):
    head_line = csv[line].strip().split(',')
    fname = head_line[0]
    print("Bz2 File number", fname, line / 6)

    (nodelist_correct, conflicts_Dict_correct, featVMat_correct, nodelist_to_correct_mapping,
     nodelist, conflicts_Dict, featVMat) = open_dsbz2(base_path_bz2 + fname + '.ds.bz2')

    assert len(nodelist_correct) == len(conflicts_Dict_correct)
    for node in conflicts_Dict_correct:
        lemma = nodelist_correct[node].lemma
        conflicting_nodes_count = len(conflicts_Dict_correct[node])
        if conflicting_nodes_count not in num_conflicting_nodes:
            num_conflicting_nodes.add(conflicting_nodes_count)
            bucket_by_conflicting_nodes[conflicting_nodes_count] = {'lemmas': set(), 'precision': [0, 0], 'recall': [0, 0]}

        bucket_by_conflicting_nodes[conflicting_nodes_count]['lemmas'].add(lemma)

        data = csv[line + 5].strip().split(',')
        word_recall = float(data[1]) / float(data[3])
        lemma_recall = float(data[2]) / float(data[3])
        word_precision = float(data[1]) / float(data[4])
        lemma_precision = float(data[2]) / float(data[4])

        bucket_by_conflicting_nodes[conflicting_nodes_count]['recall'][0] += word_recall
        bucket_by_conflicting_nodes[conflicting_nodes_count]['recall'][1] += lemma_recall
        bucket_by_conflicting_nodes[conflicting_nodes_count]['precision'][0] += word_precision
        bucket_by_conflicting_nodes[conflicting_nodes_count]['precision'][1] += lemma_precision


with open('final_task_gaurav/bucket_by_conflicting_nodes_2.p', 'wb') as f:
    pickle.dump(bucket_by_conflicting_nodes, f)
dir/bucket_by_conflicting_nodes_3.py
ADDED
@@ -0,0 +1,73 @@
import pickle
import os
import bz2


def harmonic(P, R):
    """For calculation of F-score, since it is the harmonic mean of P and R."""
    return 2 * P * R / float(P + R)


# Test on a couple of files
base_path_csv = '/home/rs/15CS91R05/gaurav/myTryouts/init_results/prediction_csvs/'
base_path_bz2 = '/home/rs/15CS91R05/Bishal/NewData/skt_dcs_DS.bz2_1L_bigram_heldout_dev/'

pred_csvs = os.listdir(base_path_csv)

"""Task 5: See data from number of conflicts
Approach: Select a node from DCS, take the count of conflicting nodes using conflicts_Dict_correct
"""

# Open a bz2 file that contains both DCS & SKT info
def open_dsbz2(filename):
    with bz2.BZ2File(filename, 'r') as f:
        loader = pickle.load(f)

    conflicts_Dict_correct = loader['conflicts_Dict_correct']
    nodelist_to_correct_mapping = loader['nodelist_to_correct_mapping']
    nodelist_correct = loader['nodelist_correct']
    featVMat_correct = loader['featVMat_correct']
    featVMat = loader['featVMat']
    conflicts_Dict = loader['conflicts_Dict']
    nodelist = loader['nodelist']

    return (nodelist_correct, conflicts_Dict_correct, featVMat_correct, nodelist_to_correct_mapping,
            nodelist, conflicts_Dict, featVMat)


bucket_by_conflicting_nodes = {}
num_conflicting_nodes = set()
csv = open(base_path_csv + pred_csvs[3], 'r').readlines()

# Each prediction record spans 6 lines of the CSV
for line in range(0, len(csv), 6):
    head_line = csv[line].strip().split(',')
    fname = head_line[0]
    print("Bz2 File number", fname, line / 6)

    (nodelist_correct, conflicts_Dict_correct, featVMat_correct, nodelist_to_correct_mapping,
     nodelist, conflicts_Dict, featVMat) = open_dsbz2(base_path_bz2 + fname + '.ds.bz2')

    assert len(nodelist_correct) == len(conflicts_Dict_correct)
    for node in conflicts_Dict_correct:
        lemma = nodelist_correct[node].lemma
        conflicting_nodes_count = len(conflicts_Dict_correct[node])
        if conflicting_nodes_count not in num_conflicting_nodes:
            num_conflicting_nodes.add(conflicting_nodes_count)
            bucket_by_conflicting_nodes[conflicting_nodes_count] = {'lemmas': set(), 'precision': [0, 0], 'recall': [0, 0]}

        bucket_by_conflicting_nodes[conflicting_nodes_count]['lemmas'].add(lemma)

        data = csv[line + 5].strip().split(',')
        word_recall = float(data[1]) / float(data[3])
        lemma_recall = float(data[2]) / float(data[3])
        word_precision = float(data[1]) / float(data[4])
        lemma_precision = float(data[2]) / float(data[4])

        bucket_by_conflicting_nodes[conflicting_nodes_count]['recall'][0] += word_recall
        bucket_by_conflicting_nodes[conflicting_nodes_count]['recall'][1] += lemma_recall
        bucket_by_conflicting_nodes[conflicting_nodes_count]['precision'][0] += word_precision
        bucket_by_conflicting_nodes[conflicting_nodes_count]['precision'][1] += lemma_precision


with open('final_task_gaurav/bucket_by_conflicting_nodes_3.p', 'wb') as f:
    pickle.dump(bucket_by_conflicting_nodes, f)
dir/bucket_by_conflicting_nodes_4.py
ADDED
@@ -0,0 +1,73 @@
import pickle
import os
import bz2


def harmonic(P, R):
    """For calculation of F-score, since it is the harmonic mean of P and R."""
    return 2 * P * R / float(P + R)


# Test on a couple of files
base_path_csv = '/home/rs/15CS91R05/gaurav/myTryouts/init_results/prediction_csvs/'
base_path_bz2 = '/home/rs/15CS91R05/Bishal/NewData/skt_dcs_DS.bz2_1L_bigram_heldout_dev/'

pred_csvs = os.listdir(base_path_csv)

"""Task 5: See data from number of conflicts
Approach: Select a node from DCS, take the count of conflicting nodes using conflicts_Dict_correct
"""

# Open a bz2 file that contains both DCS & SKT info
def open_dsbz2(filename):
    with bz2.BZ2File(filename, 'r') as f:
        loader = pickle.load(f)

    conflicts_Dict_correct = loader['conflicts_Dict_correct']
    nodelist_to_correct_mapping = loader['nodelist_to_correct_mapping']
    nodelist_correct = loader['nodelist_correct']
    featVMat_correct = loader['featVMat_correct']
    featVMat = loader['featVMat']
    conflicts_Dict = loader['conflicts_Dict']
    nodelist = loader['nodelist']

    return (nodelist_correct, conflicts_Dict_correct, featVMat_correct, nodelist_to_correct_mapping,
            nodelist, conflicts_Dict, featVMat)


bucket_by_conflicting_nodes = {}
num_conflicting_nodes = set()
csv = open(base_path_csv + pred_csvs[4], 'r').readlines()

# Each prediction record spans 6 lines of the CSV
for line in range(0, len(csv), 6):
    head_line = csv[line].strip().split(',')
    fname = head_line[0]
    print("Bz2 File number", fname, line / 6)

    (nodelist_correct, conflicts_Dict_correct, featVMat_correct, nodelist_to_correct_mapping,
     nodelist, conflicts_Dict, featVMat) = open_dsbz2(base_path_bz2 + fname + '.ds.bz2')

    assert len(nodelist_correct) == len(conflicts_Dict_correct)
    for node in conflicts_Dict_correct:
        lemma = nodelist_correct[node].lemma
        conflicting_nodes_count = len(conflicts_Dict_correct[node])
        if conflicting_nodes_count not in num_conflicting_nodes:
            num_conflicting_nodes.add(conflicting_nodes_count)
            bucket_by_conflicting_nodes[conflicting_nodes_count] = {'lemmas': set(), 'precision': [0, 0], 'recall': [0, 0]}

        bucket_by_conflicting_nodes[conflicting_nodes_count]['lemmas'].add(lemma)

        data = csv[line + 5].strip().split(',')
        word_recall = float(data[1]) / float(data[3])
        lemma_recall = float(data[2]) / float(data[3])
        word_precision = float(data[1]) / float(data[4])
        lemma_precision = float(data[2]) / float(data[4])

        bucket_by_conflicting_nodes[conflicting_nodes_count]['recall'][0] += word_recall
        bucket_by_conflicting_nodes[conflicting_nodes_count]['recall'][1] += lemma_recall
        bucket_by_conflicting_nodes[conflicting_nodes_count]['precision'][0] += word_precision
        bucket_by_conflicting_nodes[conflicting_nodes_count]['precision'][1] += lemma_precision


with open('final_task_gaurav/bucket_by_conflicting_nodes_4.p', 'wb') as f:
    pickle.dump(bucket_by_conflicting_nodes, f)
dir/bz2_counter.py
ADDED
@@ -0,0 +1,8 @@
import os

folders = ['skt_dcs_DS.bz2_1L_bigram_mir_Large']

for folder in folders:
    path = os.path.join('../NewData', folder)
    c = len(os.listdir(path))
    print('Folder: {:35s} ------> File_Count: {}\n'.format(folder, c))
dir/cliq.csv
ADDED
The diff for this file is too large to render. See raw diff
dir/datainspect.py
ADDED
@@ -0,0 +1,37 @@
# from Train_n_Save_NNet import *
import bz2, pickle

def open_dsbz2(filename):
    with bz2.BZ2File(filename, 'r') as f:
        loader = pickle.load(f)

    conflicts_Dict_correct = loader['conflicts_Dict_correct']
    nodelist_to_correct_mapping = loader['nodelist_to_correct_mapping']
    nodelist_correct = loader['nodelist_correct']
    featVMat_correct = loader['featVMat_correct']
    featVMat = loader['featVMat']
    conflicts_Dict = loader['conflicts_Dict']
    nodelist = loader['nodelist']

    print("conflicts_Dict_correct:")
    print(conflicts_Dict_correct)
    print("nodelist_to_correct_mapping:")
    print(nodelist_to_correct_mapping)

    print("nodelist_correct")
    nc0 = nodelist_correct[0]
    print(type(nc0))
    print(nodelist_correct[0])

    print("featVMat_correct")
    print(featVMat_correct[0][1][0])

    print("featVMat")
    print(featVMat[0][0])

    print("conflicts_Dict")
    print(conflicts_Dict)

    print("nodelist")
    print(nodelist)
    # return (nodelist_correct, conflicts_Dict_correct, featVMat_correct, nodelist_to_correct_mapping, nodelist, conflicts_Dict, featVMat)

print(open_dsbz2("100004.ds.bz2"))
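Inspection scripts like this assume the `.ds.bz2` files are bz2-compressed pickles. A self-contained round-trip with a stand-in dictionary (not the real node lists) shows the write/read pattern used throughout the repo:

```python
import bz2
import os
import pickle
import tempfile

payload = {'nodelist': [1, 2, 3], 'conflicts_Dict': {0: [1]}}  # stand-in data

path = os.path.join(tempfile.mkdtemp(), 'demo.ds.bz2')
with bz2.BZ2File(path, 'w') as f:       # write: pickle into a bz2 stream
    pickle.dump(payload, f)

with bz2.BZ2File(path, 'r') as f:       # read back
    loaded = pickle.load(f)
```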
dir/dcs_skt_bzipper.py
ADDED
@@ -0,0 +1,196 @@
## Built-in packages
import sys, os, csv
import pickle
from collections import defaultdict
import json
import numpy as np
import math
np.set_printoptions(suppress=True)
from IPython.display import display

## Last summer
from romtoslp import *
from sentences import *
from DCS import *
import MatDB
import time
import bz2
import zlib

## Last year
# from word_definite import *
# from nnet import *
# from heap_n_PrimMST import *
# from word_definite import *

is10K = False

if is10K:
    loaded_SKT = pickle.load(open('../Simultaneous_CompatSKT_10K.p', 'rb'), encoding=u'utf-8')
    loaded_DCS = pickle.load(open('../Simultaneous_DCS_10K.p', 'rb'), encoding=u'utf-8')
    outFolder = '../NewData/skt_dcs_DS.bz2_10K/'
else:
    loaded_SKT = pickle.load(open('../Simultaneous_CompatSKT.p', 'rb'), encoding=u'utf-8')
    loaded_DCS = pickle.load(open('../Simultaneous_DCS.p', 'rb'), encoding=u'utf-8')
    outFolder = '../NewData/skt_dcs_DS.bz2/'

conversion_file_list = list(loaded_DCS.keys())
outFolder = '../NewData/skt_dcs_DS.bz2_1L_bigram_mir_Large/'

## SPECIAL - HELD-OUT DATASET - OVERWRITES
'''
outFolder = '../NewData/skt_dcs_DS.bz2_1L_bigram_rfe_heldout/'
baseline_filelist = []
with open('inputs/Baseline4_advSample.csv') as f:
    baseline_reader = csv.reader(f)
    for line in baseline_reader:
        baseline_filelist.append(line[1])

conversion_file_list = [f.replace('.p', '.p2') for f in baseline_filelist]
#'''
## SPECIAL CODE ENDS HERE


dataset_4k_1k = pickle.load(open('../SmallDataset_4K_1K.p', 'rb'))
TrainFiles = dataset_4k_1k['TrainFiles']
TestFiles = dataset_4k_1k['TestFiles']

dataset_6k_3k = pickle.load(open('../SmallDataset_6K_3K.p', 'rb'))
TrainFiles_2 = dataset_6k_3k['TrainFiles']
TestFiles_2 = dataset_6k_3k['TestFiles']

matDB = MatDB.MatDB()

# from MatDB import *
import word_definite as WD
from heap_n_PrimMST import *
from nnet import *

"""
################################################################################################
###################### CREATE SEVERAL DATA STRUCTURES FROM SENTENCE/DCS ######################
########################### NODELIST, ADJACENCY LIST, GRAPH, HEAP #############################
################################################################################################
"""
def GetTrainingKit(sentenceObj, dcsObj):
    nodelist = GetNodes(sentenceObj)

    # Nodelist with only the correct nodes
    nodelist2 = GetNodes(sentenceObj)
    nodelist2_to_correct_mapping = {}
    nodelist_correct = []
    search_key = 0
    first_key = 0
    for chunk_id in range(len(dcsObj.lemmas)):
        # Advance to the first lattice node of this chunk
        while nodelist2[first_key].chunk_id != chunk_id:
            first_key += 1
        for j in range(len(dcsObj.lemmas[chunk_id])):
            # Linear scan within the chunk for the node matching the gold (lemma, cng)
            search_key = first_key
            while (nodelist2[search_key].lemma != rom_slp(dcsObj.lemmas[chunk_id][j])) or (nodelist2[search_key].cng != dcsObj.cng[chunk_id][j]):
                search_key += 1
                if search_key >= len(nodelist2) or nodelist2[search_key].chunk_id > chunk_id:
                    break
            nodelist2_to_correct_mapping[len(nodelist_correct)] = search_key
            nodelist_correct.append(nodelist2[search_key])
    return (nodelist, nodelist_correct, nodelist2_to_correct_mapping)

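`GetTrainingKit` aligns each gold (lemma, cng) analysis from the DCS tree with its node in the SKT lattice by scanning forward within the matching chunk. The same search logic, stripped down to a list-of-tuples sketch (function and variable names here are hypothetical, and lemma matching replaces the (lemma, cng) pair for brevity):

```python
def align_gold(nodes, gold):
    """nodes: list of (chunk_id, lemma); gold: {chunk_id: [lemma, ...]}.
    Returns the matched nodes and, per gold entry, the index of its lattice node."""
    mapping = {}
    correct = []
    first_key = 0
    for chunk_id in sorted(gold):
        while nodes[first_key][0] != chunk_id:      # skip to this chunk's first node
            first_key += 1
        for lemma in gold[chunk_id]:
            key = first_key
            while nodes[key] != (chunk_id, lemma):  # linear scan inside the chunk
                key += 1
            mapping[len(correct)] = key
            correct.append(nodes[key])
    return correct, mapping

# Toy lattice: chunk 0 has two candidate analyses, chunk 1 has one
nodes = [(0, 'rAma'), (0, 'rAmA'), (1, 'gacCati')]
gold = {0: ['rAmA'], 1: ['gacCati']}
correct, mapping = align_gold(nodes, gold)
```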
def GetGraph(nodelist, neuralnet):
    if not neuralnet.outer_relu:
        conflicts_Dict = Get_Conflicts(nodelist)
        featVMat = Get_Feat_Vec_Matrix(nodelist, conflicts_Dict)
        (WScalarMat, SigmoidGateOutput) = Get_W_Scalar_Matrix_from_FeatVect_Matrix(featVMat, nodelist, conflicts_Dict, neuralnet)
        return (conflicts_Dict, featVMat, WScalarMat, SigmoidGateOutput)
    else:
        conflicts_Dict = Get_Conflicts(nodelist)
        featVMat = Get_Feat_Vec_Matrix(nodelist, conflicts_Dict)
        WScalarMat = Get_W_Scalar_Matrix_from_FeatVect_Matrix(featVMat, nodelist, conflicts_Dict, neuralnet)
        return (conflicts_Dict, featVMat, WScalarMat)

"""
################################################################################################
############################## GET A FILENAME TO SAVE WEIGHTS ################################
################################################################################################
"""
trainingStatus = defaultdict(lambda: bool(False))

class Trainer:
    def __init__(self):
        self.hidden_layer_size = 300
        self._edge_vector_dim = WD._edge_vector_dim
        # self._full_cnglist = list(WD.mat_cngCount_1D)
        self.neuralnet = NN(self._edge_vector_dim, self.hidden_layer_size, outer_relu=True)
        self.history = defaultdict(lambda: list())

    def SaveToMem(self, sentenceObj, dcsObj, _debug=True):

        """ Pre-process DCS and SKT to get all nodes etc. """
        try:
            (nodelist, nodelist_correct, nodelist_to_correct_mapping) = GetTrainingKit(sentenceObj, dcsObj)
        except IndexError as e:
            # print('\x1b[31mError with {} \x1b[0m'.format(sentenceObj.sent_id))
            # print(e)
            return

        """ Feature vector matrix for the gold (correct) nodes """
        conflicts_Dict_correct = Get_Conflicts(nodelist_correct)
        featVMat_correct = Get_Feat_Vec_Matrix(nodelist_correct, conflicts_Dict_correct)

        """ Feature vector matrix for the full SKT lattice """
        conflicts_Dict = Get_Conflicts(nodelist)
        featVMat = Get_Feat_Vec_Matrix(nodelist, conflicts_Dict)

        with bz2.BZ2File(outFolder + sentenceObj.sent_id + '.ds.bz2', 'w') as f:
            pickle.dump({
                'nodelist': nodelist,
                'nodelist_correct': nodelist_correct,
                'nodelist_to_correct_mapping': nodelist_to_correct_mapping,
                'conflicts_Dict_correct': conflicts_Dict_correct,
                'featVMat_correct': featVMat_correct,
                'conflicts_Dict': conflicts_Dict,
                'featVMat': featVMat
            }, f)

trainer = None
def InitModule(_matDB):
|
| 163 |
+
global WD, trainer
|
| 164 |
+
_edge_vec_dim = 1500
|
| 165 |
+
WD.word_definite_extInit(_matDB, _edge_vec_dim)
|
| 166 |
+
trainer = Trainer()
|
| 167 |
+
InitModule(matDB)
|
| 168 |
+
trainingStatus = defaultdict(lambda: bool(False))
|
| 169 |
+
# trainer.Load('outputs/train_nnet_t427031523027.p')
|
| 170 |
+
|
| 171 |
+
"""
|
| 172 |
+
################################################################################################
|
| 173 |
+
############################## TRAIN FUNCTION ################################################
|
| 174 |
+
################################################################################################
|
| 175 |
+
"""
|
| 176 |
+
|
| 177 |
+
def save_all_bz2(loaded_SKT, loaded_DCS, n_checkpt = 100):
|
| 178 |
+
file_counter = 0
|
| 179 |
+
print('{} files to process'.format(len(conversion_file_list)))
|
| 180 |
+
for fn in conversion_file_list:
|
| 181 |
+
if file_counter % n_checkpt == 0:
|
| 182 |
+
print(file_counter,' Checkpoint... ')
|
| 183 |
+
sys.stdout.flush() # Flush IO buffer
|
| 184 |
+
if os.path.isfile(outFolder + fn.replace('.p2', '.ds.bz2')):
|
| 185 |
+
print('Skipping: ', fn)
|
| 186 |
+
continue
|
| 187 |
+
try:
|
| 188 |
+
_ = trainer.SaveToMem(loaded_SKT[fn], loaded_DCS[fn])
|
| 189 |
+
except (IndexError, KeyError) as e:
|
| 190 |
+
pass
|
| 191 |
+
file_counter += 1
|
| 192 |
+
|
| 193 |
+
if not os.path.isdir(outFolder):
|
| 194 |
+
print('Creating directory: ', outFolder)
|
| 195 |
+
os.mkdir(outFolder)
|
| 196 |
+
save_all_bz2(loaded_SKT, loaded_DCS, n_checkpt=100)
|
dir/evaluate.py
ADDED
@@ -0,0 +1,81 @@
# coding: utf-8

# In[1]:

import sys
import pandas
from collections import defaultdict


# In[7]:
fils = {
    'BM3': ["BM3_NLoss_proc0.csv", "BM3_NLoss_proc2.csv", "BM3_NLoss_proc1.csv", "BM3_NLoss_proc3.csv"],
    'BM2': ["BM2_NLoss_proc0.csv", "BM2_NLoss_proc2.csv", "BM2_NLoss_proc1.csv", "BM2_NLoss_proc3.csv"],
    'BR2': ["BR2_NLoss_proc0.csv", "BR2_NLoss_proc2.csv", "BR2_NLoss_proc1.csv", "BR2_NLoss_proc3.csv"],
    'BR3': ["BR3_NLoss_proc0.csv", "BR3_NLoss_proc2.csv", "BR3_NLoss_proc1.csv", "BR3_NLoss_proc3.csv"],
    'PM2': ["PM2_NLoss_proc0.csv", "PM2_NLoss_proc2.csv", "PM2_NLoss_proc1.csv", "PM2_NLoss_proc3.csv"],
    'PM3': ["PM3_NLoss_proc0.csv", "PM3_NLoss_proc2.csv", "PM3_NLoss_proc1.csv", "PM3_NLoss_proc3.csv"],
    'PR2': ["PR2_NLoss_proc0.csv", "PR2_NLoss_proc2.csv", "PR2_NLoss_proc1.csv", "PR2_NLoss_proc3.csv"],
    'PR3': ["PR3_NLoss_proc0.csv", "PR3_NLoss_proc2.csv", "PR3_NLoss_proc1.csv", "PR3_NLoss_proc3.csv"]
}


def predLoss(tag):
    gt = defaultdict(dict)

    for item in fils[tag]:
        fil = open('outputs/' + str(item)).read().splitlines()
        for i, line in enumerate(fil):
            # Each record spans 6 consecutive lines of the CSV.
            if i % 6 == 0:
                setCol = line.split(',')
                gt[setCol[0]]['predLemma'] = setCol[1:]
            if i % 6 == 1:
                gt[setCol[0]]['predCNG'] = line.split(',')[1:]
                if len(gt[setCol[0]]['predLemma']) != len(gt[setCol[0]]['predCNG']):
                    print(gt[setCol[0]])
            if i % 6 == 2:
                gt[setCol[0]]['chunkID'] = line.split(',')[1:]
                if len(gt[setCol[0]]['predLemma']) != len(gt[setCol[0]]['chunkID']):
                    print(gt[setCol[0]])
            if i % 6 == 3:
                gt[setCol[0]]['chunkIDCNG'] = line.split(',')[1:]
                if len(gt[setCol[0]]['predLemma']) != len(gt[setCol[0]]['chunkIDCNG']):
                    print(gt[setCol[0]])
            if i % 6 == 4:
                gt[setCol[0]]['idInNodeID'] = line.split(',')[1:]
                if len(gt[setCol[0]]['predLemma']) != len(gt[setCol[0]]['idInNodeID']):
                    print(gt[setCol[0]])
            if i % 6 == 5:
                gt[setCol[0]]['params'] = line.split(',')[1:]

            if line.split(',')[0] != setCol[0]:
                print(i, setCol, line)
                print('breakin')
                break
    return gt


def pdframe(gt):
    params = defaultdict(dict)
    for item in gt.keys():
        tatkal = gt[item]['params']
        params[item]['corrWords'], params[item]['corrLemma'] = int(tatkal[0]), int(tatkal[1])
        params[item]['dcsSize'], params[item]['predictions'] = int(tatkal[2]), int(tatkal[3])
        params[item]['word++Precision'] = params[item]['corrWords'] * 1.0 / params[item]['predictions']
        params[item]['word++Recall'] = params[item]['corrWords'] * 1.0 / params[item]['dcsSize']
        params[item]['wordPrecision'] = params[item]['corrLemma'] * 1.0 / params[item]['predictions']
        params[item]['wordRecall'] = params[item]['corrLemma'] * 1.0 / params[item]['dcsSize']

    initRes = pandas.DataFrame.from_dict(params, orient='index')
    return initRes


# In[8]:

if len(sys.argv) < 2:
    print("Provide an argument for the feature to be evaluated")
else:
    BM2gt = predLoss(str(sys.argv[1]))
    BM2pd = pdframe(BM2gt)
    print(BM2pd.mean())
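The `pdframe` step above folds per-sentence counts into precision/recall columns and then macro-averages them with `DataFrame.mean()`. A minimal self-contained sketch of the same aggregation, using hypothetical counts in place of the parsed `params` rows (column names match the script; the numbers are illustrative only):

```python
import pandas as pd

# Hypothetical per-sentence counts in the shape evaluate.py builds:
# corrLemma = correctly predicted lemmas, predictions = total predicted,
# dcsSize = gold-standard (DCS) word count.
params = {
    's1': {'corrLemma': 4, 'predictions': 5, 'dcsSize': 4},
    's2': {'corrLemma': 3, 'predictions': 6, 'dcsSize': 5},
}
df = pd.DataFrame.from_dict(params, orient='index')
df['wordPrecision'] = df['corrLemma'] / df['predictions']
df['wordRecall'] = df['corrLemma'] / df['dcsSize']
# Macro-average over sentences, as evaluate.py reports via BM2pd.mean():
print(df[['wordPrecision', 'wordRecall']].mean())
```

Note this is a macro average (mean of per-sentence ratios), not a micro average over pooled counts; the two differ when sentence lengths vary.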
dir/generate_dcs_and_skt_csv.py
ADDED
@@ -0,0 +1,66 @@
base_path_bz2 = '/home/rs/15CS91R05/Bishal/NewData/skt_dcs_DS.bz2_1L_bigram_heldout_dev/'

import operator
import bz2
import os
import pickle


# Function to open bz2 files (that contain both DCS & SKT info)
def open_dsbz2(filename):
    with bz2.BZ2File(filename, 'r') as f:
        loader = pickle.load(f)

    conflicts_Dict_correct = loader['conflicts_Dict_correct']
    nodelist_to_correct_mapping = loader['nodelist_to_correct_mapping']
    nodelist_correct = loader['nodelist_correct']
    featVMat_correct = loader['featVMat_correct']
    featVMat = loader['featVMat']
    conflicts_Dict = loader['conflicts_Dict']
    nodelist = loader['nodelist']

    return (nodelist_correct, conflicts_Dict_correct, featVMat_correct, nodelist_to_correct_mapping,
            nodelist, conflicts_Dict, featVMat)


bz2_files = os.listdir(base_path_bz2)
dcs_heldout_csv = ''
skt_heldout_csv = ''
count = 0
for ds in bz2_files:

    (nodelist_correct, conflicts_Dict_correct, featVMat_correct, nodelist_to_correct_mapping,
     nodelist, conflicts_Dict, featVMat) = open_dsbz2(base_path_bz2 + ds)

    fname = ds.replace('.ds.bz2', '')
    print(count, fname)
    count += 1

    lemmas = ''
    cngs = ''

    lemmas_dcs = ''
    cngs_dcs = ''

    for node in nodelist:
        lemmas += ',' + node.lemma
        cngs += ',' + node.cng
    lemmas += '\n'
    cngs += '\n'

    for node in nodelist_correct:
        lemmas_dcs += ',' + node.lemma
        cngs_dcs += ',' + node.cng
    lemmas_dcs += '\n'
    cngs_dcs += '\n'

    entry = fname + lemmas + fname + cngs
    entry_dcs = fname + lemmas_dcs + fname + cngs_dcs

    dcs_heldout_csv += entry_dcs
    skt_heldout_csv += entry

with open("final_task_gaurav/dcs_heldout.csv", "w") as f:
    f.write(dcs_heldout_csv)
with open("final_task_gaurav/skt_heldout.csv", "w") as f:
    f.write(skt_heldout_csv)
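The script above assembles each CSV record by string concatenation (`fname + ',' + lemma + ...`), which breaks if a lemma ever contains a comma or quote. A hedged alternative sketch of the same two-rows-per-sentence layout using the standard library's `csv.writer`, which handles quoting; the node tuples here are hypothetical stand-ins for `(node.lemma, node.cng)`:

```python
import csv
import io

# Hypothetical stand-ins for nodelist entries: (lemma, cng) pairs.
nodes = [('rAma', '71'), ('gam', '-10')]
fname = '100001'  # hypothetical sentence id

buf = io.StringIO()
w = csv.writer(buf)
# One row of lemmas and one row of CNGs per sentence, mirroring the script.
w.writerow([fname] + [lemma for lemma, _ in nodes])
w.writerow([fname] + [cng for _, cng in nodes])
print(buf.getvalue())
```

With a real file, `buf` would be replaced by `open(path, 'w', newline='')`.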
dir/gt2.py
ADDED
@@ -0,0 +1,148 @@
import pandas as pd
import numpy as np
import csv, pickle, json, bz2
from romtoslp import *

loaded_DCS = pickle.load(open('../Simultaneous_DCS_ho.p', 'rb'))
folder = '../NewData/skt_dcs_DS.bz2_4K_bigram_mir_heldout/'


def open_dsbz2(filename):
    with bz2.BZ2File(filename, 'r') as f:
        loader = pickle.load(f)

    conflicts_Dict_correct = loader['conflicts_Dict_correct']
    nodelist_to_correct_mapping = loader['nodelist_to_correct_mapping']
    nodelist_correct = loader['nodelist_correct']
    featVMat_correct = loader['featVMat_correct']
    featVMat = loader['featVMat']
    conflicts_Dict = loader['conflicts_Dict']
    nodelist = loader['nodelist']

    return (nodelist_correct, conflicts_Dict_correct, featVMat_correct, nodelist_to_correct_mapping,
            nodelist, conflicts_Dict, featVMat)


# Snippet for forming the ground-truth csv file
with open('groundtruth2.csv', 'w') as fh:
    rd = csv.writer(fh)
    rd.writerow(['File', 'Lemma', 'CNG', 'lemmaCorr', 'lemmaCNGcorr', 'predCNG', 'Conflicts'])
count = 0
for ii in range(4):
    with open("BM2_NLoss_proc" + str(ii) + ".csv", 'r') as fh:
        rd = csv.reader(fh)
        while True:
            try:
                print(count)
                count += 1
                x = next(rd)  # predicted lemmas
                sentid = x[0]
                dcsobj = loaded_DCS[str(sentid) + '.p2']
                # print(dcsobj.cng)
                # print(dcsobj.lemmas)
                # print(dcsobj.dcs_chunks)
                nodelist_correct, conflicts_Dict_correct, featVMat_correct, nodelist_to_correct_mapping, \
                    nodelist, conflicts_Dict, featVMat = open_dsbz2(folder + str(sentid) + '.ds.bz2')
                dll = 0
                for i in dcsobj.lemmas:
                    dll += len(i)
                if dll != len(nodelist_correct):
                    print('here')
                    print(dcsobj.lemmas)
                    print(nodelist_correct)
                gtlemmas = []
                for outerlist in dcsobj.lemmas:
                    for element in outerlist:
                        gtlemmas.append(rom_slp(element))
                pdlemmas = x[1:]

                x = next(rd)  # predicted cngs
                gtcngs = []
                i = 0
                for outerlist in dcsobj.cng:
                    for element in outerlist:
                        gtcngs.append((element, len(conflicts_Dict_correct[i])))
                        i += 1
                pdcngs = x[1:]
                for i in range(4):
                    x = next(rd)  # skip the remaining 4 lines of the 6-line record
                pdldict = dict()
                gtldict = dict()
                for i in range(len(gtlemmas)):
                    if gtlemmas[i] in gtldict:
                        gtldict[gtlemmas[i]].append(gtcngs[i])
                    else:
                        gtldict[gtlemmas[i]] = [gtcngs[i]]

                for i in range(len(pdlemmas)):
                    if pdlemmas[i] in pdldict:
                        pdldict[pdlemmas[i]].append(pdcngs[i])
                    else:
                        pdldict[pdlemmas[i]] = [pdcngs[i]]

                lemmaround2 = []
                cnground2 = []
                for gtl in gtldict.keys():
                    for gtlcng in gtldict[gtl]:
                        lemmacorr = 0
                        lemmaCNGcorr = 0
                        predictedcng = 'nil'
                        confcount = gtlcng[1]
                        gtlcng = gtlcng[0]
                        if gtl in pdldict.keys():
                            if len(pdldict[gtl]) > 0:
                                if gtlcng in pdldict[gtl]:
                                    lemmacorr = 1
                                    predictedcng = gtlcng
                                    lemmaCNGcorr = 1
                                    pdldict[gtl].remove(gtlcng)
                                    with open('groundtruth2.csv', 'a') as fh2:
                                        rwd = csv.writer(fh2)
                                        row = [sentid, gtl, gtlcng, lemmacorr, lemmaCNGcorr, gtlcng, confcount]
                                        rwd.writerow(row)
                                else:
                                    lemmaround2.append(gtl)
                                    cnground2.append((gtlcng, confcount))
                            else:
                                with open('groundtruth2.csv', 'a') as fh2:
                                    rwd = csv.writer(fh2)
                                    row = [sentid, gtl, gtlcng, lemmacorr, lemmaCNGcorr, predictedcng, confcount]
                                    rwd.writerow(row)
                        else:
                            with open('groundtruth2.csv', 'a') as fh2:
                                rwd = csv.writer(fh2)
                                row = [sentid, gtl, gtlcng, lemmacorr, lemmaCNGcorr, predictedcng, confcount]
                                rwd.writerow(row)
                # Now all elements with lemmaCNGcorr == 1 are out of the way;
                # reiterate for the lemmas which didn't have a matching cng but had a lemma earlier.
                for i in range(len(lemmaround2)):
                    gtl = lemmaround2[i]
                    gtlcng = cnground2[i]
                    confcount = gtlcng[1]
                    gtlcng = gtlcng[0]
                    lemmacorr = 0
                    lemmaCNGcorr = 0
                    predictedcng = 'nil'
                    if gtl in pdldict.keys():
                        if len(pdldict[gtl]) > 0:
                            lemmacorr = 1
                            predictedcng = pdldict[gtl][0]
                            pdldict[gtl].remove(pdldict[gtl][0])
                    with open('groundtruth2.csv', 'a') as fh2:
                        rwd = csv.writer(fh2)
                        row = [sentid, gtl, gtlcng, lemmacorr, lemmaCNGcorr, predictedcng, confcount]
                        rwd.writerow(row)
                # print('done here')
            except StopIteration:
                break  # end of this CSV file; without this the loop never terminates
            except Exception as e:
                print(e)
                print('been there')
                continue
dir/heap_n_PrimMST.py
ADDED
@@ -0,0 +1,201 @@
import numpy as np
from word_definite import *
import math

def Parent(i):
    return max(0, math.floor((i - 1) / 2))

def Left(i):
    return 2 * i + 1

def Right(i):
    return 2 * (i + 1)


"""
################################################################################################
########################  NOMINAL NODE CLASS REQUIRED FOR USING ################################
#########################  WITH THE HEAP DATA STRUCTURE  #######################################
################################################################################################
"""

class Node:
    def __init__(self, id, dist):
        self.dist = dist
        self.id = id
        self.isConflicted = False
        self.src = -1

"""
################################################################################################
############################  IMPLEMENTATION OF HEAP  ##########################################
################################################################################################
"""

class Heap:
    # It's a min-heap
    # Nodes are of type Word_definite
    def __init__(self, nodeList):
        self.nodeList = [n for n in nodeList]
        self.len = len(nodeList)
        self.idLocator = {}
        for i in range(self.len):
            self.idLocator[nodeList[i].id] = i
        self.Build()

    def Exchange(self, i, j):
        t = self.nodeList[i]
        self.nodeList[i] = self.nodeList[j]
        self.nodeList[j] = t
        self.idLocator[self.nodeList[i].id] = i
        self.idLocator[self.nodeList[j].id] = j

    def Decrease_Key(self, node, newDist, src):
        if node.isConflicted:
            return
        i = self.idLocator[node.id]
        if newDist > node.dist:
            # relaxation not possible
            return
        else:
            node.dist = newDist
            node.src = src
            parent = Parent(i)
            while (i > 0) and (self.nodeList[parent].dist > self.nodeList[i].dist):
                self.Exchange(i, parent)
                i = parent
                parent = Parent(i)

    def Pop(self):
        if self.len == 0:
            return None
        if self.nodeList[0].isConflicted:
            # print("Pop has seen conflict!!!")
            return None

        # Remove the entry from the top of the heap
        nMin = self.nodeList[0]
        self.idLocator[self.nodeList[0].id] = -1

        # Put the last node on top of the heap and heapify
        self.nodeList[0] = self.nodeList[self.len - 1]
        self.idLocator[self.nodeList[0].id] = 0
        self.len -= 1
        self.Min_Heapify(0)
        return nMin

    def Min_Heapify(self, i):
        nMin = self.nodeList[i]
        min_i = i  # initialize explicitly so min_i is always bound
        li = Left(i)
        if li < self.len:
            if self.nodeList[li].dist < nMin.dist:
                nMin = self.nodeList[li]
                min_i = li
        ri = Right(i)
        if ri < self.len:
            if self.nodeList[ri].dist < nMin.dist:
                nMin = self.nodeList[ri]
                min_i = ri
        if nMin.id != self.nodeList[i].id:
            self.Exchange(i, min_i)
            self.Min_Heapify(min_i)

    def Delete(self, node):
        i = self.idLocator[node.id]
        self.nodeList[i].isConflicted = True
        self.nodeList[i].dist = np.inf
        self.Min_Heapify(i)

    def Build(self):
        self.len = len(self.nodeList)
        for i in range(int(Parent(self.len - 1)) + 1):
            self.Min_Heapify(i)

    def Print(self):
        i = 0
        level = 1
        ilimit = 0
        while i < self.len:
            print('N(%d, %2.1f)' % (self.nodeList[i].id, self.nodeList[i].dist), end=' ')
            i += 1
            if i > ilimit:
                print('\n')
                level *= 2
                ilimit += level

"""
################################################################################################
######################  IMPLEMENTATION OF PRIM'S ALGO FOR FINDING MST ##########################
#############################  USES HEAP DEFINED ABOVE  ########################################
################################################################################################
"""
def MST(nodelist, WScalarMat, conflicts_Dict, source):
    # WTF Dude!!! This function should not be used... It is running Prim's on a directed graph!!!
    # Doesn't return an MST
    mst_adj_graph = np.zeros(WScalarMat.shape, dtype=bool)  # was np.ndarray(shape, np.bool)*False; np.bool is removed in modern NumPy
    # Reset nodes and put ids
    for id in range(len(nodelist)):
        nodelist[id].id = id
        nodelist[id].dist = np.inf
        nodelist[id].isConflicted = False
        nodelist[id].src = -1

    # Initialize graph and min-heap
    nodelist[source].dist = 0
    for neighbour in range(len(nodelist)):
        if neighbour != source:
            nodelist[neighbour].dist = WScalarMat[source][neighbour]
            nodelist[neighbour].src = source
    h = Heap(nodelist)

    mst_nodes = defaultdict(lambda: [])
    mst_nodes_bool = np.array([False] * len(nodelist))
    # Run MST only until the first conflicting node is seen
    # (a conflicting node will have np.inf as dist)
    while True:
        nextNode = h.Pop()
        if nextNode is None:
            break
        # print(nextNode.src, nextNode.id, nextNode)
        mst_nodes_bool[nextNode.id] = True
        mst_nodes[nextNode.chunk_id].append(nextNode)
        if nextNode.src != -1:
            mst_adj_graph[nextNode.src, nextNode.id] = True
            # mst_adj_graph[nextNode.id, nextNode.src] = True
        nid = nextNode.id
        for conId in conflicts_Dict[nid]:
            h.Delete(nodelist[conId])
        for neighbour in range(len(nodelist)):
            if neighbour != nextNode.id:
                h.Decrease_Key(nodelist[neighbour], WScalarMat[nextNode.id][neighbour], nextNode.id)
    mst_nodes = dict(mst_nodes)

    return (mst_nodes, mst_adj_graph, mst_nodes_bool)

def RandomST_GoldOnly(nodelist, WScalarMat, conflicts_Dict, source):
    (mst_nodes, mst_adj_graph, mst_nodes_bool) = MST(nodelist, WScalarMat, conflicts_Dict, source)

    mst_adj_graph = np.zeros_like(mst_adj_graph)
    nodelen = len(nodelist)

    ## Random spanning tree
    free_set = list(range(nodelen))
    full_set = list(range(nodelen))
    st_set = []
    start_node = np.random.randint(nodelen)
    st_set.append(start_node)
    free_set.remove(start_node)
    for x in range(nodelen - 1):
        a = st_set[np.random.randint(len(st_set))]
        b = free_set[np.random.randint(len(free_set))]
        if b not in st_set:
            st_set.append(b)
            free_set.remove(b)
            mst_adj_graph[a, b] = 1
            # mst_adj_graph[b, a] = 1  # Directed spanning tree

    return (mst_nodes, mst_adj_graph, mst_nodes_bool)


def GetMSTWeight(mst_adj_graph, WScalarMat):
    return np.sum(WScalarMat[mst_adj_graph])
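The `Heap` + `MST` pair above implements Prim's algorithm with an indexed min-heap supporting `Decrease_Key` and `Delete`. The same traversal can be sketched self-containedly with the standard library's `heapq`, using lazy deletion in place of `Decrease_Key`; this treats the weight matrix as an undirected graph (unlike the directed variant flagged in the `MST` comment) and omits the conflict-node handling:

```python
import heapq

def prim_total_weight(W, source=0):
    """Total weight of a Prim spanning tree over a dense weight matrix W.

    Sketch only: W is a plain list-of-lists of symmetric edge weights;
    stale heap entries are skipped instead of being decreased in place.
    """
    n = len(W)
    in_tree = [False] * n
    total = 0.0
    heap = [(0.0, source)]
    while heap:
        d, u = heapq.heappop(heap)
        if in_tree[u]:
            continue  # stale entry (lazy deletion)
        in_tree[u] = True
        total += d
        for v in range(n):
            if not in_tree[v]:
                heapq.heappush(heap, (float(W[u][v]), v))
    return total

W = [[0., 1., 4.],
     [1., 0., 2.],
     [4., 2., 0.]]
print(prim_total_weight(W))  # edges 0-1 (1) and 1-2 (2): 3.0
```

The custom heap in this file exists precisely because `heapq` offers no `Decrease_Key`/`Delete`; lazy deletion trades that for extra stale entries on the heap.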
dir/heap_n_clique.py
ADDED
@@ -0,0 +1,325 @@
import numpy as np
from word_definite import *
import math

def Parent(i):
    return max(0, math.floor((i - 1) / 2))

def Left(i):
    return 2 * i + 1

def Right(i):
    return 2 * (i + 1)


"""
################################################################################################
########################  NOMINAL NODE CLASS REQUIRED FOR USING ################################
#########################  WITH THE HEAP DATA STRUCTURE  #######################################
################################################################################################
"""

class Node:
    def __init__(self, id, dist):
        self.dist = dist
        self.id = id
        self.isConflicted = False
        self.src = -1

"""
################################################################################################
############################  IMPLEMENTATION OF HEAP  ##########################################
################################################################################################
"""

class Heap:
    # It's a min-heap
    # Nodes are of type Word_definite
    def __init__(self, nodeList):
        self.nodeList = [n for n in nodeList]
        self.len = len(nodeList)
        self.idLocator = {}
        for i in range(self.len):
            self.idLocator[nodeList[i].id] = i
        self.Build()

    def Exchange(self, i, j):
        t = self.nodeList[i]
        self.nodeList[i] = self.nodeList[j]
        self.nodeList[j] = t
        self.idLocator[self.nodeList[i].id] = i
        self.idLocator[self.nodeList[j].id] = j

    def Decrease_Key(self, node, newDist, src):
        if node.isConflicted:
            return
        i = self.idLocator[node.id]
        if newDist > node.dist:
            # relaxation not possible
            return
        else:
            node.dist = newDist
            node.src = src
            parent = Parent(i)
            while (i > 0) and (self.nodeList[parent].dist > self.nodeList[i].dist):
                self.Exchange(i, parent)
                i = parent
                parent = Parent(i)

    def Pop(self):
        if self.len == 0:
            return None
        if self.nodeList[0].isConflicted:
            # print("Pop has seen conflict!!!")
            return None

        # Remove the entry from the top of the heap
        nMin = self.nodeList[0]
        self.idLocator[self.nodeList[0].id] = -1

        # Put the last node on top of the heap and heapify
        self.nodeList[0] = self.nodeList[self.len - 1]
        self.idLocator[self.nodeList[0].id] = 0
        self.len -= 1
        self.Min_Heapify(0)
        return nMin

    def Min_Heapify(self, i):
        nMin = self.nodeList[i]
        min_i = i  # initialize explicitly so min_i is always bound
        li = Left(i)
        if li < self.len:
            if self.nodeList[li].dist < nMin.dist:
                nMin = self.nodeList[li]
                min_i = li
        ri = Right(i)
        if ri < self.len:
            if self.nodeList[ri].dist < nMin.dist:
                nMin = self.nodeList[ri]
                min_i = ri
        if nMin.id != self.nodeList[i].id:
            self.Exchange(i, min_i)
            self.Min_Heapify(min_i)

    def Delete(self, node):
        i = self.idLocator[node.id]
        self.nodeList[i].isConflicted = True
        self.nodeList[i].dist = np.inf
        self.Min_Heapify(i)

    def Build(self):
        self.len = len(self.nodeList)
        for i in range(int(Parent(self.len - 1)) + 1):
            self.Min_Heapify(i)

    def Print(self):
        i = 0
        level = 1
        ilimit = 0
        while i < self.len:
            print('N(%d, %2.1f)' % (self.nodeList[i].id, self.nodeList[i].dist), end=' ')
            i += 1
            if i > ilimit:
                print('\n')
                level *= 2
                ilimit += level

"""
################################################################################################
######################  IMPLEMENTATION OF PRIM'S ALGO FOR FINDING MST ##########################
#############################  USES HEAP DEFINED ABOVE  ########################################
################################################################################################
"""
def MST(nodelist, WScalarMat, conflicts_Dict, source):
    # WTF Dude!!! This function should not be used... It is running Prim's on a directed graph!!!
    # Doesn't return an MST
    mst_adj_graph = np.zeros(WScalarMat.shape, dtype=bool)  # was np.ndarray(shape, np.bool)*False; np.bool is removed in modern NumPy
    # print(len(nodelist))
    # Reset nodes and put ids
    for id in range(len(nodelist)):
        nodelist[id].id = id
        nodelist[id].dist = np.inf
        nodelist[id].isConflicted = False
        nodelist[id].src = -1

    # Initialize graph and min-heap
    nodelist[source].dist = 0
    for neighbour in range(len(nodelist)):
        if neighbour != source:
            nodelist[neighbour].dist = WScalarMat[source][neighbour]
            nodelist[neighbour].src = source
    h = Heap(nodelist)

    mst_nodes = defaultdict(lambda: [])
|
| 153 |
+
mst_nodes_bool = np.array([False]*len(nodelist))
|
| 154 |
+
# Run MST only until first conflicting node is seen
|
| 155 |
+
# Conflicting node will have np.inf as dist
|
| 156 |
+
while True:
|
| 157 |
+
nextNode = h.Pop()
|
| 158 |
+
if nextNode == None:
|
| 159 |
+
break
|
| 160 |
+
print("next-id:"+str(nextNode.id))
|
| 161 |
+
print('picked by '+str(nodelist[nextNode.id].dist))
|
| 162 |
+
print()
|
| 163 |
+
# print(nextNode.src, nextNode.id, nextNode)
|
| 164 |
+
mst_nodes_bool[nextNode.id] = True
|
| 165 |
+
mst_nodes[nextNode.chunk_id].append(nextNode)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
if nextNode.src != -1:
|
| 169 |
+
mst_adj_graph[nextNode.src, nextNode.id] = True
|
| 170 |
+
# mst_adj_graph[nextNode.id, nextNode.src] = True
|
| 171 |
+
nid = nextNode.id
|
| 172 |
+
for conId in conflicts_Dict[nid]:
|
| 173 |
+
h.Delete(nodelist[conId])
|
| 174 |
+
for neighbour in range(len(nodelist)):
|
| 175 |
+
if neighbour != nextNode.id:
|
| 176 |
+
print(WScalarMat[nextNode.id][neighbour])
|
| 177 |
+
print(nodelist[neighbour].dist)
|
| 178 |
+
h.Decrease_Key(nodelist[neighbour], WScalarMat[nextNode.id][neighbour], nextNode.id)
|
| 179 |
+
|
| 180 |
+
print(mst_nodes_bool)
|
| 181 |
+
# print(mst_nodes_bool)
|
| 182 |
+
print('#'*30)
|
| 183 |
+
mst_nodes = dict(mst_nodes)
|
| 184 |
+
|
| 185 |
+
return (mst_nodes, mst_adj_graph, mst_nodes_bool)
|
| 186 |
+
|
| 187 |
+
def clique(nodelist, WScalarMat, conflicts_Dict, source):
|
| 188 |
+
# WTF Dude!!! This function should not be used... It is running Prim's on a directed graph!!!
|
| 189 |
+
# Doesn't return MST
|
| 190 |
+
mst_adj_graph = np.ndarray(WScalarMat.shape, bool) * False
|
| 191 |
+
# print(len(nodelist))
|
| 192 |
+
# Reset nodes and put ids
|
| 193 |
+
# print('node-ids')
|
| 194 |
+
for id in range(len(nodelist)):
|
| 195 |
+
# print(id)
|
| 196 |
+
nodelist[id].id = id
|
| 197 |
+
nodelist[id].dist = np.inf
|
| 198 |
+
nodelist[id].isConflicted = False
|
| 199 |
+
nodelist[id].src = -1
|
| 200 |
+
# print('*'*40)
|
| 201 |
+
# Initialize Graph and min-Heap
|
| 202 |
+
nodelist[source].dist = 0
|
| 203 |
+
|
| 204 |
+
nodeset=set()
|
| 205 |
+
for neighbour in range(len(nodelist)):
|
| 206 |
+
if neighbour != source:
|
| 207 |
+
nodelist[neighbour].dist = WScalarMat[source][neighbour]
|
| 208 |
+
nodelist[neighbour].src = source
|
| 209 |
+
# nodeset.add((nodelist[neighbour].dist,neighbour))
|
| 210 |
+
|
| 211 |
+
# nodeset = sorted(nodeset)
|
| 212 |
+
nodeset.add((0,source))
|
| 213 |
+
nodesadded=[]
|
| 214 |
+
nodesavailable = np.zeros(len(nodelist),dtype=int) # o if available, 1 if not available
|
| 215 |
+
|
| 216 |
+
mst_nodes = defaultdict(lambda: [])
|
| 217 |
+
mst_nodes_bool = np.array([False]*len(nodelist))
|
| 218 |
+
# Run MST only until first conflicting node is seen
|
| 219 |
+
# Conflicting node will have np.inf as dist
|
| 220 |
+
|
| 221 |
+
it=0
|
| 222 |
+
nextNode=-1
|
| 223 |
+
while True:
|
| 224 |
+
# print(nodeset)
|
| 225 |
+
it+=1
|
| 226 |
+
# print(it)
|
| 227 |
+
if(it>1000):
|
| 228 |
+
break
|
| 229 |
+
if(len(nodeset)==0):
|
| 230 |
+
break
|
| 231 |
+
# print('before nn assign: ')
|
| 232 |
+
# print(nextNode)
|
| 233 |
+
nextNode = next(iter(nodeset))
|
| 234 |
+
# print("after nn assign:")
|
| 235 |
+
# print(nextNode)
|
| 236 |
+
# print("Nextnode is :"+str(nextNode[1])+" Picked by :"+str(nextNode[0]))
|
| 237 |
+
nextNode=nodelist[nextNode[1]]
|
| 238 |
+
# print(type(nextNode))
|
| 239 |
+
# print(st_setr(nextNode.id)+"",)
|
| 240 |
+
|
| 241 |
+
# print(nextNode.id)
|
| 242 |
+
nodesavailable[nextNode.id]=1
|
| 243 |
+
# nodesavailable=1
|
| 244 |
+
if nextNode == None:
|
| 245 |
+
break
|
| 246 |
+
# print(nextNode.src, nextNode.id, nextNode)
|
| 247 |
+
mst_nodes_bool[nextNode.id] = True
|
| 248 |
+
mst_nodes[nextNode.chunk_id].append(nextNode)
|
| 249 |
+
|
| 250 |
+
nodeset = set()
|
| 251 |
+
|
| 252 |
+
if nextNode.src != -1:
|
| 253 |
+
mst_adj_graph[nextNode.src, nextNode.id] = True
|
| 254 |
+
# mst_adj_graph[nextNode.id, nextNode.src] = True
|
| 255 |
+
|
| 256 |
+
nid = nextNode.id
|
| 257 |
+
nodesadded.append(nid)
|
| 258 |
+
for conId in conflicts_Dict[nid]:
|
| 259 |
+
# h.Delete(nodelist[conId])
|
| 260 |
+
nodesavailable[conId]=1
|
| 261 |
+
# print('here')
|
| 262 |
+
|
| 263 |
+
for neighbour in range(len(nodelist)):
|
| 264 |
+
# print(type(nodesavailable))
|
| 265 |
+
# print(type(nodesavailable[0]))
|
| 266 |
+
if(nodesavailable[neighbour]==1):
|
| 267 |
+
continue
|
| 268 |
+
if neighbour != nextNode.id:
|
| 269 |
+
# h.Decrease_Key(nodelist[neighbour], WScalarMat[nextNode.id][neighbour], nextNode.id)
|
| 270 |
+
edgewt=0
|
| 271 |
+
# print(nodesadded)
|
| 272 |
+
for nodepresent in nodesadded:
|
| 273 |
+
edgewt+=WScalarMat[nodepresent][neighbour]
|
| 274 |
+
# print('adding '+str(neighbour))
|
| 275 |
+
nodeset.add((edgewt,neighbour))
|
| 276 |
+
# print(nodeset)
|
| 277 |
+
|
| 278 |
+
nodeset=sorted(nodeset)
|
| 279 |
+
# print(nodeset)
|
| 280 |
+
# print("#"*30)
|
| 281 |
+
# print(mst_nodes_bool)
|
| 282 |
+
# print('-'*20)
|
| 283 |
+
# print('#'*30)
|
| 284 |
+
mst_nodes = dict(mst_nodes)
|
| 285 |
+
if(it>1000):
|
| 286 |
+
print('!!!!*10')
|
| 287 |
+
for i in range(len(mst_nodes_bool)):
|
| 288 |
+
for j in range(len(mst_nodes_bool)):
|
| 289 |
+
if(i==j):
|
| 290 |
+
continue
|
| 291 |
+
if(mst_nodes_bool[i] and mst_nodes_bool[j]):
|
| 292 |
+
mst_adj_graph[i][j]=True
|
| 293 |
+
mst_adj_graph[j][i]=True
|
| 294 |
+
|
| 295 |
+
# print(mst_adj_graph)
|
| 296 |
+
# print("#")
|
| 297 |
+
return (mst_nodes, mst_adj_graph, mst_nodes_bool)
|
| 298 |
+
|
| 299 |
+
def RandomST_GoldOnly(nodelist, WScalarMat, conflicts_Dict, source):
|
| 300 |
+
(mst_nodes, mst_adj_graph, mst_nodes_bool) = MST(nodelist, WScalarMat, conflicts_Dict, source)
|
| 301 |
+
|
| 302 |
+
mst_adj_graph = np.zeros_like(mst_adj_graph)
|
| 303 |
+
nodelen = len(nodelist)
|
| 304 |
+
|
| 305 |
+
## Random mst_adj_graph
|
| 306 |
+
free_set = list(range(nodelen))
|
| 307 |
+
full_set = list(range(nodelen))
|
| 308 |
+
st_set = []
|
| 309 |
+
start_node = np.random.randint(nodelen)
|
| 310 |
+
st_set.append(start_node)
|
| 311 |
+
free_set.remove(start_node)
|
| 312 |
+
for x in range(nodelen - 1):
|
| 313 |
+
a = st_set[np.random.randint(len(st_set))]
|
| 314 |
+
b = free_set[np.random.randint(len(free_set))]
|
| 315 |
+
if b not in st_set:
|
| 316 |
+
st_set.append(b)
|
| 317 |
+
free_set.remove(b)
|
| 318 |
+
mst_adj_graph[a, b] = 1
|
| 319 |
+
# mst_adj_graph[b, a] = 1 # Directed Spanning tree
|
| 320 |
+
|
| 321 |
+
return (mst_nodes, mst_adj_graph, mst_nodes_bool)
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def GetMSTWeight(mst_adj_graph, WScalarMat):
|
| 325 |
+
return np.sum(WScalarMat[mst_adj_graph])
|
dir/heldoutmatchtest.py
ADDED
|
@@ -0,0 +1,37 @@
import os

testfolder = '../NewData/skt_dcs_DS.bz2_4K_bigram_mir_heldout'
print('loading Testing Files')
TestFiles = set()
Allfiles = set()
for f in os.listdir(testfolder):
    if '.ds.bz2' in f:
        f = f.replace('.ds.bz2', '.p2')
        TestFiles.add(f)
        Allfiles.add((f, 1))

bz2_input_folder = '../NewData/skt_dcs_DS.bz2_4K_bigram_mir_10K/'
print('loading Training Files')
TrainFiles = set()
for f in os.listdir(bz2_input_folder):
    if '.ds.bz2' in f:
        f = f.replace('.ds.bz2', '.p2')
        TrainFiles.add(f)
        Allfiles.add((f, 2))

# Files common to the held-out and training sets (should be empty)
print(TestFiles & TrainFiles)

print(len(TrainFiles))
print(len(TestFiles))
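heldoutmatchtest.py checks for train/held-out leakage by normalizing file names on both sides and intersecting the two sets. A minimal sketch of that check with hypothetical file names:

```python
# Hypothetical directory listings; the real script reads these
# with os.listdir from the two NewData folders.
test_names = {'s1.ds.bz2', 's2.ds.bz2'}
train_names = {'s2.ds.bz2', 's3.ds.bz2'}

def to_key(f):
    # Same normalization the script applies to both sides
    return f.replace('.ds.bz2', '.p2')

TestFiles = {to_key(f) for f in test_names}
TrainFiles = {to_key(f) for f in train_names}

# Any overlap indicates held-out sentences leaking into training
overlap = TestFiles & TrainFiles
print(sorted(overlap))  # ['s2.p2']
```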
dir/lemmawise_labeller.py
ADDED
|
@@ -0,0 +1,77 @@
import csv, os, pickle
import bz2
from optparse import OptionParser

def open_dsbz2(filename):
    with bz2.BZ2File(filename, 'r') as f:
        loader = pickle.load(f)

    conflicts_Dict_correct = loader['conflicts_Dict_correct']
    nodelist_to_correct_mapping = loader['nodelist_to_correct_mapping']
    nodelist_correct = loader['nodelist_correct']
    featVMat_correct = loader['featVMat_correct']
    featVMat = loader['featVMat']
    conflicts_Dict = loader['conflicts_Dict']
    nodelist = loader['nodelist']

    return (nodelist_correct, conflicts_Dict_correct, featVMat_correct, nodelist_to_correct_mapping,
            nodelist, conflicts_Dict, featVMat)

def main(small_tag):
    ho_folders = {
        'PR2': 'skt_dcs_DS.bz2_4K_pmi_rfe_heldout',
        'BR2': 'skt_dcs_DS.bz2_4K_bigram_rfe_heldout',
        'PM2': 'skt_dcs_DS.bz2_4K_pmi_mir_heldout',
        'BM2': 'skt_dcs_DS.bz2_4K_bigram_mir_heldout',
        'PR3': 'skt_dcs_DS.bz2_1L_pmi_rfe_heldout',
        'BR3': 'skt_dcs_DS.bz2_1L_bigram_rfe_heldout',
        'PM3': 'skt_dcs_DS.bz2_1L_pmi_mir_heldout_again',
        'BM3': 'skt_dcs_DS.bz2_1L_bigram_heldout'
    }
    bz_folder = '../NewData/{}/'.format(ho_folders[small_tag])
    files = []

    tag = '{}_NLoss_'.format(small_tag)
    outFile = 'outputs/dump_predictions/lemma_label_{}.csv'.format(small_tag)
    print('Writing to ', outFile)

    for f in os.listdir('outputs/dump_predictions/'):
        if tag in f:
            print('Adding ', f)
            files.append(f)

    with open(outFile, 'w') as out_fh:
        out_fh_csv = csv.writer(out_fh)
        fi = 0
        for root_file in files:
            with open(os.path.join('outputs/dump_predictions/', root_file)) as fh:
                print('Processing File: ', root_file)
                fh_csv = csv.reader(fh)
                for lr in fh_csv:
                    if fi % 100 == 0:
                        print('Files done: ', fi)
                    fi += 1
                    sent_id = lr[0]
                    dcs_name = sent_id + '.ds.bz2'
                    (nodelist_correct, _, _, nodelist_to_correct_mapping,
                     _, _, _) = open_dsbz2(os.path.join(bz_folder, dcs_name))
                    for rx in range(5):
                        lr = next(fh_csv)[1:]
                        if rx == 3:
                            iam = [int(x) for x in lr]
                    for i in range(len(nodelist_correct)):
                        out_fh_csv.writerow([sent_id, nodelist_correct[i].lemma,
                                             1 * (nodelist_to_correct_mapping[i] in iam)])

if __name__ == '__main__':
    parser = OptionParser()
    parser.add_option("-t", "--tag", dest="tag",
                      help="Tag for feature set to use", metavar="TAG")

    (options, args) = parser.parse_args()

    options = vars(options)
    _tag = options['tag']
    if _tag is None:
        raise Exception('None is tag')
    print(_tag)
    main(_tag)
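`open_dsbz2` above depends on a bz2-compressed pickle round-trip: the `.ds.bz2` files are pickles written through a `bz2.BZ2File` handle. A minimal self-contained sketch of that round-trip, with placeholder dictionary keys:

```python
import bz2, os, pickle, tempfile

# Hypothetical payload standing in for the sentence-graph dictionaries
payload = {'nodelist': [1, 2, 3], 'conflicts_Dict': {0: [1]}}

path = os.path.join(tempfile.mkdtemp(), 'sample.ds.bz2')

# Write: pickle straight into a bz2-compressed file object
with bz2.BZ2File(path, 'w') as f:
    pickle.dump(payload, f)

# Read back, as open_dsbz2 does
with bz2.BZ2File(path, 'r') as f:
    loader = pickle.load(f)

print(loader['nodelist'])  # [1, 2, 3]
```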
dir/nnet.py
ADDED
|
@@ -0,0 +1,371 @@
import numpy as np

"""
################################################################################################
################### METHODs: SIGMOID and DERIVATIVE OF SIGMOID ################################
################################################################################################
"""

def sigmoid(vec):
    evec = 1 + np.exp(-vec)
    return 1 / evec

def d_sigmoid(output_of_gate):
    return output_of_gate * (1 - output_of_gate)

"""
################################################################################################
################### METHODs: ReLU AND DERIVATIVE OF ReLU ######################################
################################################################################################
"""

def relu(vec_x):
    relu_x = vec_x.copy()
    relu_x[vec_x < 0] = 0
    return relu_x

def lrelu(vec_x):
    # Leaky ReLU with slope 1/100 on the negative side
    relu_x = vec_x.copy()
    relu_x[vec_x < 0] = relu_x[vec_x < 0] / 100
    return relu_x

def d_relu(vec_x):
    d_relu_x = vec_x.copy()
    d_relu_x[vec_x > 0] = 1
    d_relu_x[vec_x <= 0] = 0
    return d_relu_x

def d_lrelu(vec_x):
    d_relu_x = vec_x.copy()
    d_relu_x[vec_x > 0] = 1
    d_relu_x[vec_x <= 0] = 0.01
    return d_relu_x

"""
################################################################################################
################## IMPLEMENTATION OF NEURAL NETWORK ##########################################
################################################################################################
"""

class NN:
    def __init__(self, input_dimension, hidden_layer_size, outer_relu=True, keep_prob=1.0):
        # d: input feature dimension, i.e. the dimension of the edge feature vectors
        # n: hidden layer size
        self.n = hidden_layer_size
        self.d = input_dimension

        rand_init_range = 1e-2
        self.W = np.random.uniform(-rand_init_range, rand_init_range, (self.n, self.d))
        self.B1 = np.random.uniform(-rand_init_range, rand_init_range, (self.n, 1))

        rand_init_range = 1e-1
        self.U = np.random.uniform(-rand_init_range, rand_init_range, (self.n, 1))
        self.B2 = np.random.uniform(-rand_init_range, rand_init_range, (1, 1))

        # Apply relu or sigmoid at the output layer.
        # If relu is applied, it is assumed that log is applied to the
        # feature before passing it to the network; with an outer sigmoid,
        # log would be applied after the network instead.
        self.outer_relu = outer_relu

        # Learning rates
        self.etaW = None
        self.etaB1 = None
        self.etaU = None
        self.etaB2 = None

        self.version = 'h1'

        # Dropout
        self.keep_prob = keep_prob
        self.dropout_prob = 1 - keep_prob
        self.r1 = np.ones((input_dimension, 1))  # mask for input layer
        self.r2 = np.ones(self.B1.shape)         # mask for hidden layer

        self.training_time = True

    def new_dropout(self):
        self.r1 = np.random.binomial(1, self.keep_prob, size=self.r1.shape)
        self.r2 = np.random.binomial(1, self.keep_prob, size=self.r2.shape)

    def ForTraining(self):
        self.training_time = True

    def ForTesting(self):
        self.training_time = False

    def Forward_Prop(self, x):
        if self.training_time:
            # Apply dropout masks at train time
            z2 = np.matmul(self.W, x * self.r1) + self.B1
            a2 = lrelu(z2) * self.r2
            o = np.matmul(self.U.transpose(), a2) + self.B2
        else:
            # Scale weights by keep_prob at test time
            z2 = np.matmul(self.keep_prob * self.W, x) + self.B1
            a2 = lrelu(z2)
            o = np.matmul(self.keep_prob * self.U.transpose(), a2) + self.B2

        if self.outer_relu:
            s = o
        else:
            raise Exception('Support for Non-Outer_Relu removed')

        return (z2, a2, s)

    def Get_Energy(self, x):
        # Truncate the feature vector to its first 1500 dimensions
        x = x[0:1500]
        z2 = np.matmul(self.W, x) + self.B1
        a2 = lrelu(z2)
        o = np.matmul(self.U.transpose(), a2) + self.B2
        if self.outer_relu:
            s = o
        else:
            raise Exception('Support for Non-Outer_Relu removed')
        return s

    # Back-propagate the gradient of loss L, assuming s is the direct output of the network
    def Back_Prop(self, dLdOut, nodeLen, featVMat, _debug=True):
        N = nodeLen
        dLdU = np.zeros(self.U.shape)
        dLdB2 = np.zeros(self.B2.shape)

        dLdW = np.zeros(self.W.shape)
        dLdB1 = np.zeros(self.B1.shape)

        if not self.outer_relu:
            raise Exception('Support for Non-Outer_Relu removed')

        etaW = self.etaW
        etaB1 = self.etaB1
        etaU = self.etaU
        etaB2 = self.etaB2

        if (etaW is None) or (etaB1 is None) or (etaU is None) or (etaB2 is None):
            raise Exception('Learning Rates Not Set...')

        batch_size = 0
        for i in range(N):
            for j in range(N):
                if dLdOut[i, j] != 0 and (featVMat[i][j] is not None):
                    batch_size += 1
                    x = featVMat[i][j][0:1500]
                    (z2, a2, s) = self.Forward_Prop(x)

                    dLdU += dLdOut[i, j] * a2
                    dLdB2 += dLdOut[i, j]

                    dRelu = d_lrelu(z2)
                    dLdW += dLdOut[i, j] * np.matmul((self.U * dRelu), (x * self.r1).transpose())
                    # Elementwise gradient for B1: U ⊙ lrelu'(z2)
                    dLdB1 += dLdOut[i, j] * self.U * dRelu

        if batch_size > 0:
            delW = etaW * dLdW / batch_size
            delU = etaU * dLdU / batch_size
            delB1 = etaB1 * dLdB1 / batch_size
            delB2 = etaB2 * dLdB2 / batch_size
            if _debug:
                print('Max(delW): %10.6f\tMax(delU): %10.6f' % (np.max(np.abs(delW)), np.max(np.abs(delU))))
            self.W -= delW
            self.B1 -= delB1

            self.U -= delU
            self.B2 -= delB2


class NN_2:
    def __init__(self, input_dimension, hidden_layer_1_size, hidden_layer_2_size=None, outer_relu=True):
        # d: input feature dimension; h1, h2: hidden layer sizes
        if hidden_layer_2_size is None:
            hidden_layer_2_size = hidden_layer_1_size

        self.h1 = hidden_layer_1_size
        self.h2 = hidden_layer_2_size
        self.d = input_dimension

        rand_init_range = 1e-2
        self.W1 = np.random.uniform(-rand_init_range, rand_init_range, (self.h1, self.d))
        self.B1 = np.random.uniform(-rand_init_range, rand_init_range, (self.h1, 1))
        self.W2 = np.random.uniform(-rand_init_range, rand_init_range, (self.h2, self.h1))
        self.B2 = np.random.uniform(-rand_init_range, rand_init_range, (self.h2, 1))

        rand_init_range = 1e-1
        self.U = np.random.uniform(-rand_init_range, rand_init_range, (self.h2, 1))
        self.B3 = np.random.uniform(-rand_init_range, rand_init_range, (1, 1))

        # Apply relu or sigmoid at the output layer (see NN above)
        self.outer_relu = outer_relu

        # Learning rates
        self.etaW1 = None
        self.etaB1 = None
        self.etaW2 = None
        self.etaB2 = None

        self.etaU = None
        self.etaB3 = None

        self.version = 'h2'

    def Forward_Prop(self, x):
        z2 = np.matmul(self.W1, x) + self.B1
        a2 = lrelu(z2)

        z3 = np.matmul(self.W2, a2) + self.B2
        a3 = lrelu(z3)

        o = np.matmul(self.U.transpose(), a3) + self.B3
        if self.outer_relu:
            s = o
        else:
            raise Exception('Support for Non-Outer_Relu removed')
        return (z3, a3, z2, a2, s)

    def Get_Energy(self, x):
        z2 = np.matmul(self.W1, x) + self.B1
        a2 = lrelu(z2)

        z3 = np.matmul(self.W2, a2) + self.B2
        a3 = lrelu(z3)

        o = np.matmul(self.U.transpose(), a3) + self.B3
        if self.outer_relu:
            s = o
        else:
            raise Exception('Support for Non-Outer_Relu removed')
        return s

    # Back-propagate the gradient of loss L, assuming s is the direct output of the network
    def Back_Prop(self, dLdOut, nodeLen, featVMat, _debug=True):
        N = nodeLen

        dLdU = np.zeros(self.U.shape)
        dLdB3 = np.zeros(self.B3.shape)

        dLdW2 = np.zeros(self.W2.shape)
        dLdB2 = np.zeros(self.B2.shape)

        dLdW1 = np.zeros(self.W1.shape)
        dLdB1 = np.zeros(self.B1.shape)

        if not self.outer_relu:
            raise Exception('Support for Non-Outer_Relu removed')

        etaW1 = self.etaW1
        etaB1 = self.etaB1

        etaW2 = self.etaW2
        etaB2 = self.etaB2

        etaU = self.etaU
        etaB3 = self.etaB3

        if (etaW1 is None) or (etaB1 is None) or (etaW2 is None) or (etaB2 is None) or (etaU is None) or (etaB3 is None):
            raise Exception('Learning Rates Not Set...')

        batch_size = 0
        for i in range(N):
            for j in range(N):
                if dLdOut[i, j] != 0 and (featVMat[i][j] is not None):
                    batch_size += 1
                    (z3, a3, z2, a2, s) = self.Forward_Prop(featVMat[i][j])

                    dLdU += dLdOut[i, j] * a3
                    dLdB3 += dLdOut[i, j]

                    dRelu_z3 = d_lrelu(z3)
                    dLdW2 += dLdOut[i, j] * np.matmul((self.U * dRelu_z3), a2.transpose())
                    dLdB2 += dLdOut[i, j] * self.U * dRelu_z3

                    dRelu_z2 = d_lrelu(z2)
                    dLdW1 += dLdOut[i, j] * np.matmul(np.matmul(self.W2.transpose(), self.U * dRelu_z3) * dRelu_z2, featVMat[i][j].transpose())
                    dLdB1 += dLdOut[i, j] * np.matmul(self.W2.transpose(), self.U * dRelu_z3) * dRelu_z2

        if batch_size > 0:
            delW1 = etaW1 * dLdW1 / batch_size
            delW2 = etaW2 * dLdW2 / batch_size
            delU = etaU * dLdU / batch_size
            delB1 = etaB1 * dLdB1 / batch_size
            delB2 = etaB2 * dLdB2 / batch_size
            delB3 = etaB3 * dLdB3 / batch_size
            if _debug:
                print('Max(delW2): %10.6f\tMax(delW1): %10.6f\tMax(delU): %10.6f' % (np.max(np.abs(delW2)), np.max(np.abs(delW1)), np.max(np.abs(delU))))

            # Layer 1
            self.W1 -= delW1
            self.B1 -= delB1

            # Layer 2
            self.W2 -= delW2
            self.B2 -= delB2

            # Layer 3
            self.U -= delU
            self.B3 -= delB3
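For reference, the single-hidden-layer energy computation in `NN.Get_Energy` boils down to one affine map, a leaky ReLU, and a second affine map producing a scalar. A self-contained sketch with hypothetical tiny dimensions (d=4 inputs, n=3 hidden units), mirroring the shapes and initialization ranges used above:

```python
import numpy as np

def lrelu(x):
    # Leaky ReLU with slope 1/100 on the negative side, as in nnet.py
    out = x.copy()
    out[x < 0] = out[x < 0] / 100
    return out

rng = np.random.RandomState(0)
W = rng.uniform(-1e-2, 1e-2, (3, 4))   # hidden weights (n, d)
B1 = rng.uniform(-1e-2, 1e-2, (3, 1))  # hidden bias
U = rng.uniform(-1e-1, 1e-1, (3, 1))   # output weights
B2 = rng.uniform(-1e-1, 1e-1, (1, 1))  # output bias

x = rng.rand(4, 1)                      # one edge feature vector
z2 = np.matmul(W, x) + B1               # hidden pre-activation
a2 = lrelu(z2)                          # hidden activation
s = np.matmul(U.transpose(), a2) + B2   # scalar energy (outer_relu path)
print(s.shape)  # (1, 1)
```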
dir/pvb.p
ADDED
Binary file (1.3 kB)
dir/rom.txt
ADDED
|
@@ -0,0 +1,33 @@
a,a
ā,A
i,i
ī,I
u,u
ū,U
ṛ,f
e,e
o,o
ṃ,M
k,k
g,g
ṅ,N
c,c
j,j
ñ,Y
ṭ,w
ḍ,q
ṇ,R
t,t
d,d
n,n
p,p
b,b
y,y
r,r
l,l
v,v
ś,S
ṣ,z
s,s
h,h
ḥ,H
dir/rom2.txt
ADDED
|
@@ -0,0 +1,12 @@
ai,E
au,O
kh,K
gh,G
ch,C
jh,J
ṭh,W
ḍh,Q
th,T
dh,D
ph,P
bh,B
dir/romtoslp.py
ADDED
|
@@ -0,0 +1,35 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Tue Apr 5 19:14:27 2016
+
+@author: puneet
+"""
+
+
+def rom_slp(a):
+    """Transliterate an IAST (Roman) string to SLP1 encoding."""
+    # Two-character sequences (rom2.txt) must be replaced before single
+    # characters (rom.txt), otherwise e.g. 'kh' would be split into 'k' + 'h'.
+    double_dict = {}
+    f = open('rom2.txt', 'r')
+    for line in f.readlines():
+        words = line.split(',')
+        words[1] = words[1].replace('\n', '')
+        double_dict[words[0]] = words[1]
+    f.close()
+    single_dict = {}
+    q = open('rom.txt', 'r')
+    for line in q.readlines():
+        words = line.split(',')
+        words[1] = words[1].replace('\n', '')
+        single_dict[words[0]] = words[1]
+    q.close()
+
+    for elem in double_dict:
+        if elem in a:
+            a = a.replace(elem, double_dict[elem])
+    for elem in single_dict:
+        if elem in a:
+            a = a.replace(elem, single_dict[elem])
+    return a
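The digraph-before-single replacement order that `rom_slp` relies on can be sketched self-contained. This is a minimal illustration with the mapping tables inlined (only a subset of entries) instead of read from `rom.txt`/`rom2.txt`; `rom_slp_inline` is a hypothetical name, not part of the repository.

```python
# Minimal sketch of the two-pass IAST -> SLP1 replacement used by rom_slp,
# with a subset of the mapping tables inlined.
DOUBLES = {'ai': 'E', 'au': 'O', 'kh': 'K', 'gh': 'G', 'ch': 'C',
           'jh': 'J', 'th': 'T', 'dh': 'D', 'ph': 'P', 'bh': 'B'}
SINGLES = {'ā': 'A', 'ī': 'I', 'ū': 'U', 'ṛ': 'f', 'ṃ': 'M',
           'ś': 'S', 'ṣ': 'z', 'ḥ': 'H', 'ñ': 'Y', 'ṅ': 'N'}

def rom_slp_inline(text):
    # Digraphs first, so that 'kh' is not consumed as 'k' + 'h'.
    for src, dst in DOUBLES.items():
        text = text.replace(src, dst)
    for src, dst in SINGLES.items():
        text = text.replace(src, dst)
    return text
```

Running the two passes in the opposite order would leave aspirated consonants untranslated, which is why the original keeps the digraph table in a separate file.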
dir/romtoslp.pyc ADDED
Binary file (934 Bytes).

dir/sandhiRules.p ADDED
Binary file (95.8 kB).
dir/sentences.py ADDED
@@ -0,0 +1,319 @@
+# Loading of SKT Pickles
+from romtoslp import rom_slp
+from json import *
+import pprint
+from utilities import *
+
+class word_new():
+    def __init__(self, names):
+        self.lemmas = []
+        self.names = names
+        self.urls = []
+        self.forms = []
+
+class chunks:
+    def __init__(self, chunk_name):
+        self.chunk_name = chunk_name
+        self.chunk_words = {}
+
+class sentences:
+    def __init__(self, sent_id, sentence):
+        self.sent_id = sent_id
+        self.sentence = sentence
+        self.chunk = []
+
+class SentenceError(Exception):
+    def __init__(self, message):
+        # Call the base class constructor with the parameters it needs
+        super(SentenceError, self).__init__(message)
+
+def SeeSentence(sentenceObj):
+    print('SKT ANALYZE')
+    print('-' * 15)
+    print(sentenceObj.sentence)
+    for chunk in sentenceObj.chunk:
+        print("Analyzing ", rom_slp(chunk.chunk_name))
+        for pos in chunk.chunk_words.keys():
+            for word_sense in chunk.chunk_words[pos]:
+                word_sense = fix_w_new(word_sense)
+                print(pos, ": ", rom_slp(word_sense.names), word_sense.lemmas, word_sense.forms)
+        print()
+
+def getWord(sentenceObj, cid, pos, kii):
+    ch = sentenceObj.chunk[cid]
+    word = ch.chunk_words[pos][kii]
+    return {'lemmas': word.lemmas, 'forms': word.forms, 'names': word.names}
+
+# ---------------------------------------------------------------------------------------------------------------------
+from wordTypeCheckFunction import *
+import pickle
+import numpy as np
+
+"""
+SentencePreprocess:
+-------------------
+Read a sentence obj and create + return the following objects
+
+-> chunkDict: chunk_id -> position -> index in lemmaList (nested dictionary)
+-> lemmaList: list of possible words as a result of word segmentation
+-> revMap2Chunk: Map word in wordList to (cid, position) in chunkDict
+-> qu: Possible query nodes
+"""
+v2t = pickle.load(open('verbs_vs_cngs_matrix_countonly.p', 'rb'), encoding=u'utf8')
+
+def wtc_recursive(form, c):
+    if type(c) == list:
+        for cc in c:
+            return wtc_recursive(form, cc)
+    else:
+        return wordTypeCheck(form, c)
+
+def CanBeQuery(chunk):
+    allLemmas = []
+    for pos, words in chunk.chunk_words.items():
+        for word in words:
+            for lemma in word.lemmas:
+                if lemma != '':
+                    allLemmas.append(lemma)
+    if len(allLemmas) == 1:
+        return True
+
+def Get_QCs(tuplesMain, chunkDict):
+    # Form NON-competitor dictionary - Query-Candidate pairs
+    qc_pairs = {}
+    nodeList = [t for ts in tuplesMain for t in ts]
+
+    # Start with every other node as a potential candidate
+    for ni in range(len(nodeList)):
+        qc_pairs[ni] = set(range(len(nodeList))) - set([ni])
+
+    for cid in chunkDict.keys():
+        # Neighbours within the same chunk
+        for pos1 in chunkDict[cid].keys():
+            for pos2 in chunkDict[cid].keys():
+                if pos1 <= pos2:
+                    nList1 = []
+                    for ti1 in chunkDict[cid][pos1]:
+                        for tup1 in tuplesMain[ti1]:
+                            nList1.append(tup1[0])
+                    nList2 = []
+                    for ti2 in chunkDict[cid][pos2]:
+                        for tup2 in tuplesMain[ti2]:
+                            nList2.append(tup2[0])
+                    nList1 = set(nList1)
+                    nList2 = set(nList2)
+                    # Nodes at the same position compete with each other
+                    for n1 in nList1:
+                        qc_pairs[n1] = qc_pairs[n1] - nList1
+                    for n2 in nList2:
+                        qc_pairs[n2] = qc_pairs[n2] - nList2
+
+                    if pos1 < pos2:
+                        for n1 in nList1:
+                            for n2 in nList2:
+                                if not CanCoExist_sandhi(pos1, pos2, nodeList[n1][1], nodeList[n2][1]):
+                                    qc_pairs[n1] = qc_pairs[n1] - set([n2])
+                                    qc_pairs[n2] = qc_pairs[n2] - set([n1])
+
+    return qc_pairs
+
+'''
+===================
+SentencePreprocess
+===================
+forceQuery: Setting it to True will make the longest available word a query
+            if no other query is available
+'''
+def SentencePreprocess(sentenceObj, forceQuery=False):
+    """
+    Considering word names only
+    ***{Word forms or cngs can also be used}
+    """
+    def getCNGs(formsDict):
+        if type(formsDict) == int or type(formsDict) == str:
+            return [int(formsDict)]
+        else:
+            l = []
+            for form, configs in formsDict.items():
+                for c in configs:
+                    if form == 'verbform':
+                        continue
+                    else:
+                        l.append(wtc_recursive(form, c))
+            return list(set(l))
+
+    chunkDict = {}
+    lemmaList = []
+    wordList = []
+    cngList = []
+    revMap2Chunk = []
+    qu = []
+    tuplesMain = []
+
+    cid = -1
+    tidExclusive = 0
+
+    ## Traverse sentence and form data-structures
+    for chunk in sentenceObj.chunk:
+        cid = cid + 1
+        chunkDict[cid] = {}
+        for pos in chunk.chunk_words.keys():
+            tupleSet = {}
+            chunkDict[cid][pos] = []
+            for word_sense in chunk.chunk_words[pos]:
+                nama = rom_slp(word_sense.names)
+                if nama == '':
+                    raise SentenceError('Empty Name Detected')
+                if len(word_sense.lemmas) > 0 and len(word_sense.forms) > 0:
+                    tuples = []
+                    for lemmaI in range(len(word_sense.lemmas)):
+                        lemma = word_sense.lemmas[lemmaI]
+                        if lemma == '':
+                            continue
+                        tempCNGs = getCNGs(word_sense.forms[lemmaI])
+                        for cng in tempCNGs:
+                            # UPDATE LISTS
+                            newT_Key = (lemma, cng)
+                            newT = (tidExclusive, nama, lemma, cng)
+                            if newT_Key not in tupleSet:
+                                tupleSet[newT_Key] = 1
+                                tuples.append(newT)  # Remember the order
+                                lemmaList.append(lemma)
+                                wordList.append(nama)
+                                cngList.append(cng)
+                                revMap2Chunk.append((cid, pos, len(tuplesMain)))
+                                tidExclusive += 1
+
+                    if len(tuples) > 0:
+                        k = len(tuplesMain)
+                        chunkDict[cid][pos].append(k)
+                        tuplesMain.append(tuples)
+
+    ## Find QUERY nodes now
+    for cid in chunkDict.keys():
+        tuples = []
+        for pos in chunkDict[cid].keys():
+            tupIds = chunkDict[cid][pos]
+            for tupId in tupIds:
+                [tuples.append((pos, tup[0], tup[1])) for tup in tuplesMain[tupId]]
+        for u in range(len(tuples)):
+            tup1 = tuples[u]
+            quFlag = True
+            for v in range(len(tuples)):
+                if u == v:
+                    continue
+                tup2 = tuples[v]
+                if tup1[0] < tup2[0]:
+                    if not CanCoExist_sandhi(tup1[0], tup2[0], tup1[2], tup2[2]):
+                        ## Found a competing node - hence can't be a query
+                        quFlag = False
+                        break
+                elif tup1[0] > tup2[0]:
+                    if not CanCoExist_sandhi(tup2[0], tup1[0], tup2[2], tup1[2]):
+                        ## Found a competing node - hence can't be a query
+                        quFlag = False
+                        break
+                else:
+                    quFlag = False
+                    break
+
+            if quFlag:
+                qu.append(tup1[1])
+
+    verbs = []
+    i = -1
+    for w in lemmaList:
+        i += 1
+        if w in list(v2t.keys()):
+            verbs.append(i)
+
+    qc_pairs = Get_QCs(tuplesMain, chunkDict)
+
+    if len(qu) == 0 and len(lemmaList) > 0:
+        # Fall back: force the longest word with the fewest competitors as query
+        lens = np.array([len(t[1]) for ts in tuplesMain for t in ts])
+        round1 = np.where(lens == np.max(lens))[0]
+        hits = [len(qc_pairs[r]) for r in round1]
+        finalist = round1[np.where(hits == np.min(hits))][0]
+        qu.append(finalist)
+
+    return (chunkDict, lemmaList, wordList, revMap2Chunk, qu, cngList, verbs, tuplesMain, qc_pairs)
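The core idea in `Get_QCs` above is set subtraction: every node starts with all other nodes as potential query-candidate partners, and nodes that compete (same chunk position, or sandhi-incompatible overlap) are removed from each other's sets. A toy sketch of that elimination pattern, with `build_qc_pairs` and the `groups` representation as hypothetical simplifications of the chunk structures:

```python
# Sketch of the Get_QCs elimination idea on toy data: each node starts with
# all other nodes as query-candidate partners; nodes in the same competing
# group are then removed from each other's candidate sets.
def build_qc_pairs(groups, n_nodes):
    """groups: list of sets of node ids that mutually compete."""
    qc_pairs = {n: set(range(n_nodes)) - {n} for n in range(n_nodes)}
    for group in groups:
        for n in group:
            # competitors can never form a query-candidate pair
            qc_pairs[n] -= group
    return qc_pairs
```

With nodes 0/1 competing and nodes 2/3 competing, node 0's only remaining candidates are 2 and 3, mirroring how the real function leaves only cross-chunk, sandhi-compatible partners.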
dir/sh_TestPool_MP_clique.py ADDED
@@ -0,0 +1,168 @@
+import multiprocessing as mp
+import TestPool_Unit
+from shutil import copyfile
+import numpy as np
+import time
+import sys
+from optparse import OptionParser
+
+from collections import defaultdict
+
+def Evaluate(result_arr):
+    print('Files Processed: ', len(result_arr))
+    recalls = []
+    recalls_of_word = []
+    precisions = []
+    precisions_of_words = []
+    fully_Correct_l = 0
+    fully_Correct_w = 0
+    for entry in result_arr:
+        (word_match, lemma_match, n_dcsWords, n_output_nodes) = entry
+        recalls.append(lemma_match / n_dcsWords)
+        recalls_of_word.append(word_match / n_dcsWords)
+        precisions.append(lemma_match / n_output_nodes)
+        precisions_of_words.append(word_match / n_output_nodes)
+        if lemma_match == n_dcsWords:
+            fully_Correct_l += 1
+        if word_match == n_dcsWords:
+            fully_Correct_w += 1
+    print('Avg. Micro Recall of Lemmas: {}'.format(np.mean(np.array(recalls))))
+    print('Avg. Micro Recall of Words: {}'.format(np.mean(np.array(recalls_of_word))))
+    print('Avg. Micro Precision of Lemmas: {}'.format(np.mean(np.array(precisions))))
+    print('Avg. Micro Precision of Words: {}'.format(np.mean(np.array(precisions_of_words))))
+    rl = np.mean(np.array(recalls))
+    pl = np.mean(np.array(precisions))
+    print('F-Score of Lemmas: ', (2 * pl * rl) / (pl + rl))
+    print('Fully Correct Lemmawise: {}'.format(fully_Correct_l / len(recalls_of_word)))
+    print('Fully Correct Wordwise: {}'.format(fully_Correct_w / len(recalls_of_word)))
+    print('[{:0.2f}, {:0.2f}, {:0.2f}, {:0.2f}, {:0.2f}, {:0.2f}, {:0.2f}]'.format(
+        100 * np.mean(np.array(recalls)), 100 * np.mean(np.array(recalls_of_word)),
+        100 * np.mean(np.array(precisions)), 100 * np.mean(np.array(precisions_of_words)),
+        100 * (2 * pl * rl) / (pl + rl), 100 * fully_Correct_l / len(recalls_of_word),
+        100 * fully_Correct_w / len(recalls_of_word)))
+    sys.stdout.flush()
+
+tag = None
+proc_count = 4
+
+def main():
+    global proc_count, tag
+    ho_folders = {
+        'PR2': 'skt_dcs_DS.bz2_4K_pmi_rfe_heldout',
+        'BR2': 'skt_dcs_DS.bz2_4K_bigram_rfe_heldout',
+        'PM2': 'skt_dcs_DS.bz2_4K_pmi_mir_heldout',
+        'BM2': 'skt_dcs_DS.bz2_4K_bigram_mir_heldout',
+        'PR3': 'skt_dcs_DS.bz2_1L_pmi_rfe_heldout',
+        'BR3': 'skt_dcs_DS.bz2_1L_bigram_rfe_heldout',
+        'PM3': 'skt_dcs_DS.bz2_1L_pmi_mir_heldout_again',
+        'BM3': 'skt_dcs_DS.bz2_1L_bigram_heldout'
+    }
+    modelList = {
+        'PR2': 'outputs/train_{}/nnet_e1_i400.p'.format('t2788294192566'),
+        'BR2': 'outputs/train_{}/nnet_e1_i400.p'.format('t2789415023871'),
+        'PM2': 'outputs/train_{}/nnet_e1_i400.p'.format('t2753954441900'),
+        'BM2': 'outputs/train_{}/nnet_e1_i400.p'.format('t3401216067518'),
+        'PR3': 'outputs/train_{}/nnet_e1_i400.p'.format('t2761370242287'),
+        'BR3': 'outputs/train_{}/nnet_e1_i400.p'.format('t2779114903467'),
+        'PM3': 'outputs/train_{}/nnet_e1_i400.p'.format('t2756013734745'),
+        'BM3': 'outputs/train_{}/nnet_e1_i400.p'.format('t3471903174862')
+    }
+    modelFile = modelList[tag]
+    print('Tag: {}, ModelFile: {}'.format(tag, modelFile))
+    print('ProcCount: {}'.format(proc_count))
+    _dump = True
+    if _dump:
+        _outFile = 'outputs/dump_predictions/{}_NLoss'.format(tag)
+    else:
+        _outFile = None
+    print('OutFile: ', _outFile)
+
+    # Backup the model file
+    copyfile(modelFile, modelFile + '.bk')
+
+    # Create Queue, Result array
+    queue = mp.Queue()
+    result_arr = []
+
+    print('Source: ', '../NewData/{}/'.format(ho_folders[tag]))
+    # Start the workers - more than 6 slows down the machine
+    procs = [None] * proc_count
+    for i in range(proc_count):
+        vpid = i
+        procs[i] = mp.Process(target=TestPool_Unit.pooled_Test, args=\
+            (modelFile, vpid, queue, '../NewData/{}/'.format(ho_folders[tag]), int(9600 / proc_count), _dump, _outFile))
+    # Start Processes
+    for i in range(proc_count):
+        procs[i].start()
+
+    # Fetch partial results
+    stillRunning = True
+    printer_timer = 100
+    while stillRunning:
+        stillRunning = False
+        for i in range(proc_count):
+            p = procs[i]
+            if p.is_alive():
+                stillRunning = True
+
+        if printer_timer == 0:
+            printer_timer = 100
+            while not queue.empty():
+                result_arr.append(queue.get())
+            # Evaluate results till now
+            if len(result_arr) > 0:
+                Evaluate(result_arr)
+
+        printer_timer -= 1
+        time.sleep(1)
+    while not queue.empty():
+        result_arr.append(queue.get())
+    Evaluate(result_arr)
+    for i in range(proc_count):
+        procs[i].join()
+
+def setArgs(_tag, _pc):
+    global proc_count, tag
+    tag = _tag
+    proc_count = _pc
+    print('Tag, ProcCount: {}, {}'.format(tag, proc_count))
+
+if __name__ == '__main__':
+    parser = OptionParser()
+    parser.add_option("-t", "--tag", dest="tag",
+                      help="Tag for feature set to use", metavar="TAG")
+    parser.add_option("-p", "--procs", dest="proc_count", default=4,
+                      help="Number of child processes", metavar="PROCS")
+
+    (options, args) = parser.parse_args()
+
+    options = vars(options)
+    _tag = options['tag']
+    if _tag is None:
+        raise Exception('No tag supplied')
+    pc = int(options['proc_count'])
+    setArgs(_tag, pc)
+
+    main()
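`Evaluate` above averages per-sentence recall and precision of lemmas and combines them into an F-score. The arithmetic can be isolated as a small sketch; `f_score` is a hypothetical helper operating on the same `(word_match, lemma_match, n_dcsWords, n_output_nodes)` tuples the worker processes put on the queue.

```python
# Isolated F-score arithmetic from Evaluate, on result tuples of the form
# (word_match, lemma_match, n_dcsWords, n_output_nodes).
def f_score(result_arr):
    recalls = [lm / nd for (_, lm, nd, _) in result_arr]      # lemma recall per sentence
    precisions = [lm / no for (_, lm, _, no) in result_arr]   # lemma precision per sentence
    rl = sum(recalls) / len(recalls)
    pl = sum(precisions) / len(precisions)
    return (2 * pl * rl) / (pl + rl)                          # harmonic mean
```

For example, one fully correct sentence (4/4) and one half-correct sentence (2/4 against 4 output nodes) give recall = precision = 0.75, so the F-score is also 0.75.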
dir/test_clique.py ADDED
@@ -0,0 +1,174 @@
+import multiprocessing as mp
+import TestPool_Unit_clique
+import csv
+from shutil import copyfile
+import numpy as np
+import time
+import sys
+from optparse import OptionParser
+
+from collections import defaultdict
+
+rwfinal = 0
+pwfinal = 0
+rlfinal = 0
+plfinal = 0
+fnum = 0
+
+def Evaluate(result_arr):
+    print('Files Processed: ', len(result_arr))
+    recalls = []
+    recalls_of_word = []
+    precisions = []
+    precisions_of_words = []
+    fully_Correct_l = 0
+    fully_Correct_w = 0
+    for entry in result_arr:
+        (word_match, lemma_match, n_dcsWords, n_output_nodes) = entry
+        recalls.append(lemma_match / n_dcsWords)
+        recalls_of_word.append(word_match / n_dcsWords)
+        precisions.append(lemma_match / n_output_nodes)
+        precisions_of_words.append(word_match / n_output_nodes)
+        if lemma_match == n_dcsWords:
+            fully_Correct_l += 1
+        if word_match == n_dcsWords:
+            fully_Correct_w += 1
+    print('Avg. Micro Recall of Words: {}'.format(np.mean(np.array(recalls))))
+    print('Avg. Micro Recall of Word++s: {}'.format(np.mean(np.array(recalls_of_word))))
+    print('Avg. Micro Precision of Words: {}'.format(np.mean(np.array(precisions))))
+    print('Avg. Micro Precision of Word++s: {}'.format(np.mean(np.array(precisions_of_words))))
+
+    rl = np.mean(np.array(recalls))
+    pl = np.mean(np.array(precisions))
+    print('F-Score of Words: ', (2 * pl * rl) / (pl + rl))
+    print('Fully Correct Wordwise: {}'.format(fully_Correct_l / len(recalls_of_word)))
+    print('Fully Correct Word++wise: {}'.format(fully_Correct_w / len(recalls_of_word)))
+    print('[{:0.2f}, {:0.2f}, {:0.2f}, {:0.2f}, {:0.2f}, {:0.2f}, {:0.2f}]'.format(
+        100 * np.mean(np.array(recalls)), 100 * np.mean(np.array(recalls_of_word)),
+        100 * np.mean(np.array(precisions)), 100 * np.mean(np.array(precisions_of_words)),
+        100 * (2 * pl * rl) / (pl + rl), 100 * fully_Correct_l / len(recalls_of_word),
+        100 * fully_Correct_w / len(recalls_of_word)))
+    sys.stdout.flush()
+
+tag = None
+proc_count = 4
+
+def main():
+    global proc_count, tag
+    ho_folders = {
+        'PR2': 'skt_dcs_DS.bz2_4K_pmi_rfe_heldout',
+        'BR2': 'skt_dcs_DS.bz2_4K_bigram_rfe_heldout',
+        'PM2': 'skt_dcs_DS.bz2_4K_pmi_mir_heldout',
+        'BM2': 'skt_dcs_DS.bz2_4K_bigram_mir_heldout',
+        'PR3': 'skt_dcs_DS.bz2_1L_pmi_rfe_heldout',
+        'BR3': 'skt_dcs_DS.bz2_1L_bigram_rfe_heldout',
+        'PM3': 'skt_dcs_DS.bz2_1L_pmi_mir_heldout2',
+        'BM3': 'skt_dcs_DS.bz2_1L_bigram_heldout'
+    }
+    modelList = {
+        'PR2': 'outputs/train_{}/nnet.p'.format('t8006684774222'),
+        'BR2': 'outputs/train_{}/nnet.p'.format('t7978761528557'),
+        'PM2': 'outputs/train_{}/nnet.p'.format('t7323235797178'),
+        'BM2': 'outputs/train_{}/nnet.p'.format('t7978754709018'),
+        'PR3': 'outputs/train_{}/nnet.p'.format('t8006711065860'),
+        'BR3': 'outputs/train_{}/nnet.p'.format('t8103694133496'),
+        'PM3': 'outputs/train_{}/nnet.p'.format('t8006607913382'),
+        'BM3': 'outputs/train_{}/nnet.p'.format('t7274036680592')
+    }
+    modelFile = modelList[tag]
+    print('Tag: {}, ModelFile: {}'.format(tag, modelFile))
+    print('ProcCount: {}'.format(proc_count))
+    _dump = True
+    if _dump:
+        _outFile = 'outputs/{}_NLoss'.format(tag)
+    else:
+        _outFile = None
+    print('OutFile: ', _outFile)
+
+    # Backup the model file
+    copyfile(modelFile, modelFile + '.bk')
+
+    # Create Queue, Result array
+    queue = mp.Queue()
+    result_arr = []
+
+    print('Source: ', '../wordsegmentation/{}/'.format(ho_folders[tag]))
+    # Start the workers - more than 6 slows down the machine
+    procs = [None] * proc_count
+    for i in range(proc_count):
+        vpid = i
+        procs[i] = mp.Process(target=TestPool_Unit_clique.pooled_Test, args=\
+            (modelFile, vpid, queue, '../wordsegmentation/{}/'.format(ho_folders[tag]), int(9600 / proc_count), _dump, _outFile))
+    # Start Processes
+    for i in range(proc_count):
+        procs[i].start()
+
+    # Fetch partial results
+    stillRunning = True
+    printer_timer = 100
+    while stillRunning:
+        stillRunning = False
+        for i in range(proc_count):
+            p = procs[i]
+            if p.is_alive():
+                stillRunning = True
+
+        if printer_timer == 0:
+            printer_timer = 100
+            while not queue.empty():
+                result_arr.append(queue.get())
+            # Evaluate results till now
+            if len(result_arr) > 0:
+                Evaluate(result_arr)
+
+        printer_timer -= 1
+        time.sleep(1)
+    while not queue.empty():
+        result_arr.append(queue.get())
+    Evaluate(result_arr)
+    for i in range(proc_count):
+        procs[i].join()
+
+
+def setArgs(_tag, _pc):
+    global proc_count, tag
+    tag = _tag
+    proc_count = _pc
+    print('Tag, ProcCount: {}, {}'.format(tag, proc_count))
+
+if __name__ == '__main__':
+    parser = OptionParser()
+    parser.add_option("-t", "--tag", dest="tag",
+                      help="Tag for feature set to use", metavar="TAG")
+    parser.add_option("-p", "--procs", dest="proc_count", default=4,
+                      help="Number of child processes", metavar="PROCS")
+
+    (options, args) = parser.parse_args()
+
+    options = vars(options)
+    _tag = options['tag']
+    if _tag is None:
+        raise Exception('No tag supplied')
+    pc = int(options['proc_count'])
+    setArgs(_tag, pc)
+
+    main()
dir/unpack.py ADDED
@@ -0,0 +1,39 @@
+from Train_clique import *
+from heap_n_clique import *
+from nnet import *
+from TestPool_Unit_clique import *
+from sentences import *
+
+bz2_input_folder = '../NewData/skt_dcs_DS.bz2_4K_bigram_mir_10K/'  # bm2
+# bz2_input_folder = '../NewData/skt_dcs_DS.bz2_1L_bigram_mir_10K/'  # bm3
+# bz2_input_folder = '../NewData/skt_dcs_DS.bz2_4K_bigram_rfe_10K/'  # br2
+# bz2_input_folder = '../NewData/skt_dcs_DS.bz2_1L_bigram_rfe_10K/'  # br3
+# bz2_input_folder = '../NewData/skt_dcs_DS.bz2_4K_pmi_mir_10K/'  # pm2
+# bz2_input_folder = '../NewData/skt_dcs_DS.bz2_1L_pmi_mir_10K2/'  # pm3
+# bz2_input_folder = '../NewData/skt_dcs_DS.bz2_4K_pmi_rfe_10K/'  # pr2
+# bz2_input_folder = '../NewData/skt_dcs_DS.bz2_1L_pmi_rfe_10K/'  # pr3
+loaded_SKT = pickle.load(open('../Simultaneous_CompatSKT_10K.p', 'rb'), encoding=u'utf-8')
+loaded_DCS = pickle.load(open('../Simultaneous_DCS_10K.p', 'rb'), encoding=u'utf-8')
+
+dsbz2_name = '4442.ds.bz2'
+
+(nodelist_correct, conflicts_Dict_correct, featVMat_correct, nodelist_to_correct_mapping,
+ nodelist, conflicts_Dict, featVMat) = open_dsbz2(bz2_input_folder + dsbz2_name)
+
+sentenceObj = loaded_SKT['4442.p2']
+
+# SeeSentence(sentenceObj)
+# FIXME: 'self.neuralnet' is undefined at module level - a trained NNet
+# instance must be supplied here instead.
+WScalarMat_correct = Get_W_Scalar_Matrix_from_FeatVect_Matrix(featVMat_correct, nodelist_correct,
+    conflicts_Dict_correct, self.neuralnet)
+source = 0
+
+(min_st_gold_ndict, min_st_adj_gold_small, _) = MST(nodelist_correct, WScalarMat_correct, conflicts_Dict_correct, source)
+energy_gold_max_ST = np.sum(WScalarMat_correct[min_st_adj_gold_small])
+
+print(min_st_gold_ndict)
dir/utilities.py ADDED
@@ -0,0 +1,323 @@
import sys as Sys
import pickle, re
import numpy as np
from romtoslp import *

# Print iterations progress
def printProgress(iteration, total, prefix='', suffix='', decimals=2, barLength=100):
    """
    Call in a loop to create a terminal progress bar
    @params:
        iteration - Required : current iteration (Int)
        total     - Required : total iterations (Int)
        prefix    - Optional : prefix string (Str)
        suffix    - Optional : suffix string (Str)
    """
    filledLength = int(round(barLength * iteration / float(total)))
    percents = round(100.00 * (iteration / float(total)), decimals)
    bar = '#' * filledLength + '-' * (barLength - filledLength)
    Sys.stdout.write('%s [%s] %s%s %s\r' % (prefix, bar, percents, '%', suffix))
    Sys.stdout.flush()
    if iteration == total:
        print("\n")


def pickleFixLoad(filename):
    return pickle.load(open(filename, 'rb'), encoding=u'utf-8')


def validatePickleName(fName):
    m = re.search(r"^[\w]*\.p$", fName)
    if m is not None:
        return m.group(0)
    return ""


sandhiRules = pickle.load(open('sandhiRules.p', 'rb'))

def CanCoExist_sandhi(p1, p2, name1, name2):
    # p1 must be less than p2 -- send the arguments in the proper order
    if p1 < p2:
        overlap = max((p1 + len(name1)) - p2, 0)
        if overlap == 0:
            return True
        if overlap == 1 or overlap == 2:
            pair1 = (name1[len(name1) - overlap:], name2[0])
            pair2 = (name1[-1], name2[0:overlap])
            # print(name1, name2, pair1, pair2)
            if pair1 in sandhiRules:
                if sandhiRules[pair1]['length'] < len(pair1[0]) + len(pair1[1]):
                    return True
            if pair2 in sandhiRules:
                if sandhiRules[pair2]['length'] < len(pair2[0]) + len(pair2[1]):
                    return True
    return False


def fix_w_new(word_new_obj):
    # Earlier version of the mapping also contained 'tad':'sa':
    # dicto = {'asmad':'mad','yuzmad':'tvad','ayam':'idam','agn':'agni','ya':'yad','eza':'etad',
    #          'tad':'sa','vd':'vid','va':'vE','-tva':'tva','ptta':'pitta','mahat':'mahant','ndra':'indra',
    #          'ap':'api','h':'hi','t':'iti','tr':'tri','va':'iva'}

    dicto = {'asmad': 'mad', 'yuzmad': 'tvad', 'ayam': 'idam', 'agn': 'agni', 'ya': 'yad', 'eza': 'etad',
             'vd': 'vid', 'va': 'vE', '-tva': 'tva', 'ptta': 'pitta', 'mahat': 'mahant', 'ndra': 'indra',
             'ap': 'api', 'h': 'hi', 't': 'iti', 'tr': 'tri', 'va': 'iva'}

    for i in range(len(word_new_obj.lemmas)):
        word_new_obj.lemmas[i] = rom_slp(word_new_obj.lemmas[i])
        word_new_obj.lemmas[i] = word_new_obj.lemmas[i].split('_')[0]

        if word_new_obj.lemmas[i] in dicto:
            # print('CHANGED', word_new_obj.lemmas[i])
            word_new_obj.lemmas[i] = dicto[word_new_obj.lemmas[i]]

        if word_new_obj.lemmas[i] == 'yad':
            if word_new_obj.names == 'yadi':
                word_new_obj.lemmas[i] = 'yadi'

    return word_new_obj


def FixSentence(sentenceObj):
    for ci in range(len(sentenceObj.chunk)):
        for pos in sentenceObj.chunk[ci].chunk_words.keys():
            for wsi in range(len(sentenceObj.chunk[ci].chunk_words[pos])):
                sentenceObj.chunk[ci].chunk_words[pos][wsi] = fix_w_new(sentenceObj.chunk[ci].chunk_words[pos][wsi])
    return sentenceObj


def FillMissing(sentenceObj, dcsObj):
    for ci in range(len(sentenceObj.chunk)):
        corrLemmas = dcsObj.lemmas[ci]
        cli = 0
        iamdone = False
        for pos in sentenceObj.chunk[ci].chunk_words.keys():
            for wsi in range(len(sentenceObj.chunk[ci].chunk_words[pos])):
                ws = sentenceObj.chunk[ci].chunk_words[pos][wsi]
                for li in range(len(ws.lemmas)):
                    if ws.lemmas[li] == rom_slp(corrLemmas[cli]):
                        # print('MATCHED:', ws.lemmas[li], rom_slp(corrLemmas[cli]))
                        # print('CNG LIST:', ws.forms[li] if li < len(ws.forms) else [dcsObj.cng[ci][cli]])
                        if li >= len(ws.forms):
                            a = [''] * (li + 1)
                            for i in range(len(ws.forms)):
                                a[i] = ws.forms[i]
                            a[li] = int(dcsObj.cng[ci][cli])
                            sentenceObj.chunk[ci].chunk_words[pos][wsi].forms = a

                        cli += 1
                        if cli == len(corrLemmas):
                            iamdone = True
                            break
                if iamdone:
                    break
            if iamdone:
                break
    return sentenceObj


# def loadSentence(fName, sntcPath):
#     try:
#         dcsObj = pickleFixLoad('../Text Segmentation/DCS_pick/' + fName)
#         sentenceObj = pickleFixLoad(sntcPath)
#         sentenceObj = FixSentence(sentenceObj)
#         sentenceObj = FillMissing(sentenceObj, dcsObj)
#     except (KeyError, EOFError, pickle.UnpicklingError) as e:
#         print('Failed to load', sntcPath)
#         return None, None
#     return (sentenceObj, dcsObj)

def loadSentence_with_rom_slp(fName, sntcPath):
    try:
        try:
            if fName[-1] == '2':
                dcsObj = pickleFixLoad('../Text Segmentation/DCS_pick/' + fName[:-1])
            else:
                dcsObj = pickleFixLoad('../Text Segmentation/DCS_pick/' + fName)
        except FileNotFoundError:
            dcsObj = None
        sentenceObj = pickleFixLoad(sntcPath)
        sentenceObj = FixSentence(sentenceObj)
    except (KeyError, EOFError, pickle.UnpicklingError):
        print('Failed to load', sntcPath)
        return None, None
    return (sentenceObj, dcsObj)


def loadSentence_nopre(fName, sntcPath):
    try:
        if fName[-1] == '2':
            dcsObj = pickleFixLoad('../Text Segmentation/DCS_pick/' + fName[:-1])
        else:
            dcsObj = pickleFixLoad('../Text Segmentation/DCS_pick/' + fName)
        sentenceObj = pickleFixLoad(sntcPath)
    except (KeyError, EOFError, pickle.UnpicklingError):
        print('Failed to load', sntcPath)
        return None, None
    return (sentenceObj, dcsObj)


preList = pickle.load(open('pvb.p', 'rb'))

def removePrefix(lemma):
    for pre in preList:
        m = re.match(pre, lemma)
        if m is not None:
            s = m.span()
            pat = lemma[s[0]:s[1]]
            return lemma.split(pat)[1]
    return lemma


def GetSolutions(dcsObj):
    solution = [rom_slp(c) for arr in dcsObj.lemmas for c in arr]
    solution_no_pvb = [removePrefix(l) for l in solution]
    return (solution, solution_no_pvb)


def Accuracy(prediction, dcsObj):
    solution, solution_no_pvb = GetSolutions(dcsObj)
    # print('Solution:', solution)
    # print('Solution No Pvb:', solution_no_pvb)
    ac = 0
    for x in range(len(solution)):
        if solution[x] in prediction:
            ac += 1
        # elif solution_no_pvb[x] in prediction:
        #     ac += 1

    ac = 100 * ac / len(solution)
    return ac


def FullCoverage(skt, dcs):
    goodFlag = True
    for ci in range(len(dcs.lemmas)):
        dlemmas = [rom_slp(l) for l in dcs.lemmas[ci]]
        slemmas = []
        chunk = skt.chunk[ci]
        for pos in chunk.chunk_words.keys():
            for wsi in range(len(chunk.chunk_words[pos])):
                ws = chunk.chunk_words[pos][wsi]
                slemmas.extend(ws.lemmas)
        # print('DCS:', dlemmas)
        # print('SKT:', slemmas)
        for l in dlemmas:
            if l not in slemmas:
                # print(l, 'not found')
                goodFlag = False
                break
        if not goodFlag:
            break
    return goodFlag


def GetFeatNameSet():
    mat_cngCount_1D = pickle.load(open('../NewData/gauravs/Temporary_1D/mat_cngCount_1D.p', 'rb'), encoding=u'utf-8')

    _full_cnglist = list(mat_cngCount_1D)
    _cg_count = len(mat_cngCount_1D)

    feats = {}
    fIndex = 0
    feats[fIndex] = ('L', 'L'); fIndex += 1
    feats[fIndex] = ('L', 'C'); fIndex += 1
    feats[fIndex] = ('L', 'T'); fIndex += 1

    feats[fIndex] = ('C', 'L'); fIndex += 1
    feats[fIndex] = ('C', 'C'); fIndex += 1
    feats[fIndex] = ('C', 'T'); fIndex += 1

    feats[fIndex] = ('T', 'L'); fIndex += 1
    feats[fIndex] = ('T', 'C'); fIndex += 1
    feats[fIndex] = ('T', 'T'); fIndex += 1

    # Path constraint - length 2 - _cg_count features per pattern

    # LEMMA->CNG->LEMMA
    for k in range(_cg_count):
        feats[fIndex + k] = ('L', _full_cnglist[k], 'L')
    fIndex += _cg_count

    # LEMMA->CNG->CNG
    for k in range(_cg_count):
        feats[fIndex + k] = ('L', _full_cnglist[k], 'C')
    fIndex += _cg_count

    # LEMMA->CNG->TUP
    for k in range(_cg_count):
        feats[fIndex + k] = ('L', _full_cnglist[k], 'T')
    fIndex += _cg_count

    # CNG->CNG->LEMMA
    for k in range(_cg_count):
        feats[fIndex + k] = ('C', _full_cnglist[k], 'L')
    fIndex += _cg_count

    # CNG->CNG->CNG
    for k in range(_cg_count):
        feats[fIndex + k] = ('C', _full_cnglist[k], 'C')
    fIndex += _cg_count

    # CNG->CNG->TUP
    for k in range(_cg_count):
        feats[fIndex + k] = ('C', _full_cnglist[k], 'T')
    fIndex += _cg_count

    # TUP->CNG->LEMMA :: TOO MANY ZEROS
    for k in range(_cg_count):
        feats[fIndex + k] = ('T', _full_cnglist[k], 'L')
    fIndex += _cg_count

    # TUP->CNG->CNG :: TOO MANY ZEROS
    for k in range(_cg_count):
        feats[fIndex + k] = ('T', _full_cnglist[k], 'C')
    fIndex += _cg_count

    # TUP->CNG->TUP :: TOO MANY ZEROS
    for k in range(_cg_count):
        feats[fIndex + k] = ('T', _full_cnglist[k], 'T')
    fIndex += _cg_count

    # Path constraint - length 3 - _cg_count^2 features per pattern

    # LEMMA->CGS->CGS->LEMMA
    for k1 in range(_cg_count):
        for k2 in range(_cg_count):
            feats[fIndex + k1 * _cg_count + k2] = ('L', _full_cnglist[k1], _full_cnglist[k2], 'L')
    fIndex += _cg_count ** 2

    # LEMMA->CGS->CGS->TUP
    for k1 in range(_cg_count):
        for k2 in range(_cg_count):
            feats[fIndex + k1 * _cg_count + k2] = ('L', _full_cnglist[k1], _full_cnglist[k2], 'T')
    fIndex += _cg_count ** 2

    # TUP->CGS->CGS->LEMMA
    for k1 in range(_cg_count):
        for k2 in range(_cg_count):
            feats[fIndex + k1 * _cg_count + k2] = ('T', _full_cnglist[k1], _full_cnglist[k2], 'L')
    fIndex += _cg_count ** 2

    # TUP->CGS->CGS->TUP
    for k1 in range(_cg_count):
        for k2 in range(_cg_count):
            feats[fIndex + k1 * _cg_count + k2] = ('T', _full_cnglist[k1], _full_cnglist[k2], 'T')
    fIndex += _cg_count ** 2
    return feats
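GetFeatNameSet enumerates 9 lemma/CNG/tuple pair features, nine length-2 path blocks of n features each, and four length-3 path blocks of n² features each, so the table holds 9 + 9n + 4n² names in total. A self-contained sketch of that indexing scheme with a hypothetical two-entry CNG list (the real list is loaded from `mat_cngCount_1D.p` and is far larger):

```python
# Toy stand-in for the CNG-group list loaded from mat_cngCount_1D.p.
_full_cnglist = ['29', '71']
_cg_count = len(_full_cnglist)

feats = {}
fIndex = 0

# 9 pair features over {Lemma, Cng, Tup}
for a in ('L', 'C', 'T'):
    for b in ('L', 'C', 'T'):
        feats[fIndex] = (a, b)
        fIndex += 1

# length-2 paths: one block of _cg_count entries per (src, dst) pair
for a in ('L', 'C', 'T'):
    for b in ('L', 'C', 'T'):
        for k in range(_cg_count):
            feats[fIndex + k] = (a, _full_cnglist[k], b)
        fIndex += _cg_count

# length-3 paths: _cg_count**2 entries, only for the four L/T endpoint pairs
for a, b in (('L', 'L'), ('L', 'T'), ('T', 'L'), ('T', 'T')):
    for k1 in range(_cg_count):
        for k2 in range(_cg_count):
            feats[fIndex + k1 * _cg_count + k2] = (a, _full_cnglist[k1], _full_cnglist[k2], b)
    fIndex += _cg_count ** 2

print(len(feats))  # → 43, i.e. 9 + 9*2 + 4*2**2
```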
dir/verbs_vs_cngs_matrix_countonly.p
ADDED
The diff for this file is too large to render. See raw diff
dir/weighted.py
ADDED
@@ -0,0 +1,81 @@
# coding: utf-8

# In[1]:

import sys
import pandas
from collections import defaultdict


# In[7]:

fils = {
    'BM3': ["BM3_NLoss_proc0.csv", "BM3_NLoss_proc2.csv", "BM3_NLoss_proc1.csv", "BM3_NLoss_proc3.csv"],
    'BM2': ["BM2_NLoss_proc0.csv", "BM2_NLoss_proc2.csv", "BM2_NLoss_proc1.csv", "BM2_NLoss_proc3.csv"],
    'BR2': ["BR2_NLoss_proc0.csv", "BR2_NLoss_proc2.csv", "BR2_NLoss_proc1.csv", "BR2_NLoss_proc3.csv"],
    'BR3': ["BR3_NLoss_proc0.csv", "BR3_NLoss_proc2.csv", "BR3_NLoss_proc1.csv", "BR3_NLoss_proc3.csv"],
    'PM2': ["PM2_NLoss_proc0.csv", "PM2_NLoss_proc2.csv", "PM2_NLoss_proc1.csv", "PM2_NLoss_proc3.csv"],
    'PM3': ["PM3_NLoss_proc0.csv", "PM3_NLoss_proc2.csv", "PM3_NLoss_proc1.csv", "PM3_NLoss_proc3.csv"],
    'PR2': ["PR2_NLoss_proc0.csv", "PR2_NLoss_proc2.csv", "PR2_NLoss_proc1.csv", "PR2_NLoss_proc3.csv"],
    'PR3': ["PR3_NLoss_proc0.csv", "PR3_NLoss_proc2.csv", "PR3_NLoss_proc1.csv", "PR3_NLoss_proc3.csv"]
}


def predLoss(tag):
    gt = defaultdict(dict)

    for item in fils[tag]:
        fil = open('outputs/' + str(item)).read().splitlines()
        for i, line in enumerate(fil):
            # records come in blocks of 6 lines, each line prefixed by the sentence id
            if i % 6 == 0:
                setCol = line.split(',')
                gt[setCol[0]]['predLemma'] = setCol[1:]
            if i % 6 == 1:
                gt[setCol[0]]['predCNG'] = line.split(',')[1:]
                if len(gt[setCol[0]]['predLemma']) != len(gt[setCol[0]]['predCNG']):
                    print(gt[setCol[0]])
            if i % 6 == 2:
                gt[setCol[0]]['chunkID'] = line.split(',')[1:]
                if len(gt[setCol[0]]['predLemma']) != len(gt[setCol[0]]['chunkID']):
                    print(gt[setCol[0]])
            if i % 6 == 3:
                gt[setCol[0]]['chunkIDCNG'] = line.split(',')[1:]
                if len(gt[setCol[0]]['predLemma']) != len(gt[setCol[0]]['chunkIDCNG']):
                    print(gt[setCol[0]])
            if i % 6 == 4:
                gt[setCol[0]]['idInNodeID'] = line.split(',')[1:]
                if len(gt[setCol[0]]['predLemma']) != len(gt[setCol[0]]['idInNodeID']):
                    print(gt[setCol[0]])
            if i % 6 == 5:
                gt[setCol[0]]['params'] = line.split(',')[1:]

                if line.split(',')[0] != setCol[0]:
                    print(i, setCol, line)
                    print('breakin')
                    break
    return gt


def pdframe(gt):
    params = defaultdict(dict)
    for item in gt.keys():
        tatkal = gt[item]['params']
        params[item]['corrWords'], params[item]['corrLemma'] = int(tatkal[0]), int(tatkal[1])
        params[item]['dcsSize'], params[item]['predictions'] = int(tatkal[2]), int(tatkal[3])
        params[item]['wordPrec'] = params[item]['corrWords'] * 1.0 / params[item]['predictions']
        params[item]['wordReca'] = params[item]['corrWords'] * 1.0 / params[item]['dcsSize']
        params[item]['lemmaPrec'] = params[item]['corrLemma'] * 1.0 / params[item]['predictions']
        params[item]['lemmaReca'] = params[item]['corrLemma'] * 1.0 / params[item]['dcsSize']

    initRes = pandas.DataFrame.from_dict(params, orient='index')
    return initRes


# In[8]:

if len(sys.argv) < 2:
    print("Provide an argument for the feature to be evaluated")
else:
    BM2gt = predLoss(str(sys.argv[1]))
    BM2pd = pdframe(BM2gt)
    print(BM2pd.mean())
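pdframe turns each sentence's (corrWords, corrLemma, dcsSize, predictions) counters into word- and lemma-level precision and recall, then averages them with `DataFrame.mean()`. A self-contained sketch with two made-up sentence ids and counts (the real counters come from the `*_NLoss_proc*.csv` files):

```python
import pandas
from collections import defaultdict

# Hypothetical per-sentence counters: (corrWords, corrLemma, dcsSize, predictions)
raw = {'4442.p2': (7, 8, 10, 9),
       '105.p1':  (5, 6, 6, 6)}

params = defaultdict(dict)
for item, (corrWords, corrLemma, dcsSize, predictions) in raw.items():
    # precision divides by the number of predictions, recall by the gold size
    params[item]['wordPrec'] = corrWords / predictions
    params[item]['wordReca'] = corrWords / dcsSize
    params[item]['lemmaPrec'] = corrLemma / predictions
    params[item]['lemmaReca'] = corrLemma / dcsSize

frame = pandas.DataFrame.from_dict(params, orient='index')
print(frame.mean())  # macro-average over sentences, one value per metric
```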
dir/wordTypeCheckFunction.py
ADDED
@@ -0,0 +1,281 @@
def splitTillPeriod(config, listInput):
    # config must be a non-empty string.
    # Returns config minus its first '.'-delimited part; that first part
    # (with spaces removed) is appended to listInput.
    configList = list(config)
    out = ''
    periodIndex = 0
    val = ''
    for i, val in enumerate(configList):
        periodIndex = i
        if val == '.':
            break
        if val != " ":
            out = out + val
    if val != ".":
        # no period found: push the whitespace-normalised remainder and stop
        config1 = " ".join(config.split())
        listInput.append(config1)
        return ""
    else:
        config1 = "".join(configList[(periodIndex + 1):])
        listInput.append(out)
        return config1


def wordTypeCheck(form, config):
    # form is 'noun', 'verb', 'indeclinable', 'compound' or 'undetermined';
    # if it is a noun, config is assumed to have 3 parts.

    # print(form, config)

    nounMapping = {
        28: 'xt?', 29: 'Nom. sg. masc.', 30: 'Nom. sg. fem.', 31: 'Nom. sg. neutr.', 32: 'Nom. sg. adj.',
        33: 'xt?', 34: 'Nom. du. masc.', 35: 'Nom. du. fem.', 36: 'Nom. du. neutr.', 37: 'Nom. du. adj.',
        38: 'xt?', 39: 'Nom. pl. masc.', 40: 'Nom. pl. fem.', 41: 'Nom. pl. neutr.', 42: 'Nom. pl. adj.',
        48: 'xt?', 49: 'Voc. sg. masc.', 50: 'Voc. sg. fem.', 51: 'Voc. sg. neutr.',
        54: 'Voc. du. masc.', 55: 'Voc. du. fem.', 56: 'Voc. du. neutr.',
        58: 'xt?', 59: 'Voc. pl. masc.', 60: 'Voc. pl. fem.', 61: 'Voc. pl. neutr.',
        68: 'xt?', 69: 'Acc. sg. masc.', 70: 'Acc. sg. fem.', 71: 'Acc. sg. neutr.', 72: 'Acc. sg. adj.',
        73: 'xt?', 74: 'Acc. du. masc.', 75: 'Acc. du. fem.', 76: 'Acc. du. neutr.', 77: 'Acc. du. adj.',
        78: 'xt?', 79: 'Acc. pl. masc.', 80: 'Acc. pl. fem.', 81: 'Acc. pl. neutr.', 82: 'Acc. pl. adj.',
        88: 'xt?', 89: 'Instr. sg. masc.', 90: 'Instr. sg. fem.', 91: 'Instr. sg. neutr.', 92: 'Instr. sg. adj.',
        93: 'xt?', 94: 'Instr. du. masc.', 95: 'Instr. du. fem.', 96: 'Instr. du. neutr.', 97: 'Instr. du. adj.',
        98: 'xt?', 99: 'Instr. pl. masc.', 100: 'Instr. pl. fem.', 101: 'Instr. pl. neutr.', 102: 'Instr. pl. adj.',
        108: 'xt?', 109: 'Dat. sg. masc.', 110: 'Dat. sg. fem.', 111: 'Dat. sg. neutr.', 112: 'Dat. sg. adj.',
        114: 'Dat. du. masc.', 115: 'Dat. du. fem.', 116: 'Dat. du. neutr.', 117: 'Dat. du. adj.',
        118: 'xt?', 119: 'Dat. pl. masc.', 120: 'Dat. pl. fem.', 121: 'Dat. pl. neutr.', 122: 'Dat. pl. adj.',
        128: 'xt?', 129: 'Abl. sg. masc.', 130: 'Abl. sg. fem.', 131: 'Abl. sg. neutr.', 132: 'Abl. sg. adj.',
        134: 'Abl. du. masc.', 135: 'Abl. du. fem.', 136: 'Abl. du. neutr.', 137: 'Abl. du. adj.',
        138: 'xt?', 139: 'Abl. pl. masc.', 140: 'Abl. pl. fem.', 141: 'Abl. pl. neutr.', 142: 'Abl. pl. adj.',
        148: 'xt?', 149: 'Gen. sg. masc.', 150: 'Gen. sg. fem.', 151: 'Gen. sg. neutr.', 152: 'Gen. sg. adj.',
        153: 'xt?', 154: 'Gen. du. masc.', 155: 'Gen. du. fem.', 156: 'Gen. du. neutr.', 157: 'Gen. du. adj.',
        158: 'xt?', 159: 'Gen. pl. masc.', 160: 'Gen. pl. fem.', 161: 'Gen. pl. neutr.', 162: 'Gen. pl. adj.',
        168: 'xt?', 169: 'Loc. sg. masc.', 170: 'Loc. sg. fem.', 171: 'Loc. sg. neutr.', 172: 'Loc. sg. adj.',
        173: 'xt?', 174: 'Loc. du. masc.', 175: 'Loc. du. fem.', 176: 'Loc. du. neutr.', 177: 'Loc. du. adj.',
        178: 'xt?', 179: 'Loc. pl. masc.', 180: 'Loc. pl. fem.', 181: 'Loc. pl. neutr.', 182: 'Loc. pl. adj.',
    }
    verbMapping1 = {
        1: 'pr. [*] ac.', 2: 'opt. [*] ac.', 3: 'imp. [*] ac', 4: 'impft. [*] ac.', 5: 'fut. ac/ps.',
        6: 'cond. ac/ps.', 7: 'per. fut. ac/ps.', 8: 'aor. [1] ac/ps.', 9: 'aor. [2] ac/ps.',
        10: 'aor. [3] ac/ps.', 11: 'aor. [4] ac/ps.', 12: 'aor. [5] ac/ps.', 13: 'aor. [7] ac/ps.',
        14: 'ben. ac/ps.', 15: 'pft. ac.', 16: 'per. pft.', 19: 'pp.', 20: 'ppa.', 21: 'pfp.',
        22: 'inf.', 23: 'abs.', 24: 'pr. ps.', 26: 'imp. ps.', 27: 'impft. ps.', 28: 'aor. ps.',
        29: 'opt. ps.', 30: 'ou',
    }
    verbMapping2 = {1: 'sg. 1', 2: 'sg. 2', 3: 'sg. 3', 4: 'du. 1', 5: 'du. 2', 6: 'du. 3',
                    7: 'pl. 1', 8: 'pl. 2', 9: 'pl. 3'}

    if form == 'indeclinable':
        if config == 'part.':
            return 2
        elif config == 'conj.':
            return 2
        elif config == 'abs.':
            return -230
        elif config == 'prep.':
            return 2
        elif config == 'ind.':
            return 2
        elif config == 'ca. abs.':
            return -230
        else:
            return 'config is invalid'

    elif form == 'compound':
        if config == 'iic.':
            return 3
        elif config == 'iiv.':
            return 3
        else:
            return 'config is invalid'

    elif form == 'undetermined':
        if config == 'adv.':
            return 2
        elif config == 'und.':
            return 1
        elif config == 'tasil':
            return 1
        else:
            return 'config is invalid'

    elif form == 'noun':
        # print("entered noun")
        config1 = config
        x = []
        config1 = splitTillPeriod(config1, x)
        one = x[0]
        x = []
        config1 = splitTillPeriod(config1, x)
        two = x[0]
        x = []
        config1 = splitTillPeriod(config1, x)
        three = x[0]

        isAdj = 0
        if three == '*':
            three = 'n'
            isAdj = 1

        for i in nounMapping.keys():
            if one != 'i' and one != 'g':
                if one[len(one) - 2:] in nounMapping[i]:
                    if two in nounMapping[i]:
                        if three in nounMapping[i]:
                            if isAdj == 0:
                                return i
                            else:
                                return i + 1

            elif one == 'i':
                if 'Instr' in nounMapping[i]:
                    if two in nounMapping[i]:
                        if three == 'n':
                            if 'neutr' in nounMapping[i]:
                                if isAdj == 0:
                                    return i
                                else:
                                    return i + 1
                        elif three in nounMapping[i]:
                            return i
            elif one == 'g':
                if 'Gen' in nounMapping[i]:
                    if two in nounMapping[i]:
                        if three == 'n':
                            if 'neutr' in nounMapping[i]:
                                if isAdj == 0:
                                    return i
                                else:
                                    return i + 1
                        elif three in nounMapping[i]:
                            return i

    elif form == 'verb':
        unit = 0
        ten = 0
        # strip a leading 'ca.', 'des.' or 'int.' marker
        x = []
        configActual = config
        config = splitTillPeriod(config, x)
        if x[0] == 'ca' or x[0] == 'des' or x[0] == 'int':
            pass  # keep the stripped config
        else:
            config = configActual
        # drop a 'vn.' marker if present
        if 'vn.' in config:
            config = config.replace('vn.', '')

        x = []
        config = splitTillPeriod(config, x)

        one = x[0]
        two = ''
        three = ''
        ONE = ''
        TWO = ''

        if config != '':
            x = []
            config = splitTillPeriod(config, x)
            temp = x[0]
            if temp != 'sg' and temp != 'pl' and temp != 'du':
                two = temp
            else:
                ONE = temp

        if config != '':
            x = []
            config = splitTillPeriod(config, x)
            temp = x[0]
            if temp != 'sg' and temp != 'pl' and temp != 'du':
                if ONE == '':
                    three = temp
                elif ONE != '':
                    TWO = temp
            else:
                ONE = temp
        if config != '':
            x = []
            config = splitTillPeriod(config, x)
            temp = x[0]
            if temp == 'sg' or temp == 'pl' or temp == 'du':
                ONE = temp
            elif temp == '1' or temp == '2' or temp == '3':
                TWO = temp

        if config != '':
            x = []
            config = splitTillPeriod(config, x)
            temp = x[0]
            if temp == '1' or temp == '2' or temp == '3':
                TWO = temp

        for i in verbMapping2.keys():
            if ONE != '':
                if ONE in verbMapping2[i] and TWO in verbMapping2[i]:
                    unit = i
                    break

        if one == 'pp':
            ten = 19
        if one == 'ppa':
            ten = 20
        if one == 'pfp':
            ten = 21
        if one == 'inf':
            ten = 22
        if one == 'abs':
            ten = 23
        if one == 'inj':
            ten = 30

        if one == 'pr' or one == 'ppr':
            if two == 'ps':
                ten = 24
        if one == 'imp':
            if two == 'ps':
                ten = 26
        if one == 'impft':
            if two == 'ps':
                ten = 27
        if one == 'aor':
            if two == 'ps':
                ten = 28
        if one == 'opt':
            if two == 'ps':
                ten = 29

        if one == 'pr' or one == 'ppr':
            if 'ac' in two or 'md' in two:
                ten = 1
        if one == 'opt':
            if 'ac' in two or 'md' in two:
                ten = 2
        if one == 'imp':
            if 'ac' in two or 'md' in two:
                ten = 3
        if one == 'impft':
            if 'ac' in two or 'md' in two:
                ten = 4
        if one == 'pft' or one == 'ppf':
            if 'ac' in two or 'md' in two:
                ten = 15

        if one == 'per':
            if two == 'pft':
                ten = 16

        if one == 'fut' or one == 'pfu':
            if 'ac' in two or 'ps' in two or 'md' in two:
                ten = 5
        if one == 'cond':
            if 'ac' in two or 'ps' in two or 'md' in two:
                ten = 6
        if one == 'ben':
            if 'ac' in two or 'ps' in two or 'md' in two:
                ten = 14

        if one == 'aor':
            if 'ac' in two or 'ps' in two or 'md' in two:
                if '1' in two:
                    ten = 8
                if '2' in two:
                    ten = 9
                if '3' in two:
                    ten = 10
                if '4' in two:
                    ten = 11
                if '5' in two or '6' in two:
                    ten = 12
                if '7' in two:
                    ten = 13

        if one == 'per':
            if two == 'fut':
                if ('ac' in three) or ('ps' in three) or ('md' in three):
                    ten = 7

        if ten != 0:
            output = -1 * (ten * 10 + unit)
            return output
        # otherwise fall through with no return: unrecognised verb config

    else:
        return 'none'
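splitTillPeriod is the tokenizer all the noun/verb branches above rely on: it peels off the text before the first '.', spaces stripped, and returns the remainder. A simplified, self-contained restatement of that behaviour (using `str.partition` instead of the character loop; names here are illustrative, not the repo's API):

```python
def split_till_period(config, parts):
    """Pop the text before the first '.' (spaces removed) into parts;
    return the remainder. If no '.' is found, push the whitespace-normalised
    string and return '' -- mirroring splitTillPeriod's two branches."""
    head, sep, tail = config.partition('.')
    if sep == '':
        parts.append(' '.join(config.split()))
        return ''
    parts.append(head.replace(' ', ''))
    return tail

# Tokenise a sample noun config string part by part.
config = 'Nom. sg. masc.'
parts = []
while config:
    config = split_till_period(config, parts)
print(parts)  # → ['Nom', 'sg', 'masc']
```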
dir/word_definite.py
ADDED
The diff for this file is too large to render. See raw diff

dir/word_definite[d_1500_BM2_v12].py
ADDED
The diff for this file is too large to render. See raw diff

requirements.txt
ADDED
@@ -0,0 +1,68 @@
asttokens==3.0.0
bz2file==0.98
certifi==2025.7.14
charset-normalizer==3.4.2
comm==0.2.3
debugpy==1.8.15
decorator==5.2.1
dill==0.4.0
exceptiongroup==1.3.0
executing==2.2.0
filelock==3.18.0
fsspec==2025.7.0
hf-xet==1.1.5
huggingface-hub==0.34.1
idna==3.10
ipykernel==6.30.0
ipython==8.37.0
jedi==0.19.2
Jinja2==3.1.6
jupyter_client==8.6.3
jupyter_core==5.8.1
matplotlib-inline==0.1.7
mpmath==1.3.0
nest-asyncio==1.6.0
networkx==3.4.2
numpy==1.26.4
nvidia-cublas-cu12==12.6.4.1
nvidia-cuda-cupti-cu12==12.6.80
nvidia-cuda-nvrtc-cu12==12.6.77
nvidia-cuda-runtime-cu12==12.6.77
nvidia-cudnn-cu12==9.5.1.17
nvidia-cufft-cu12==11.3.0.4
nvidia-cufile-cu12==1.11.1.6
nvidia-curand-cu12==10.3.7.77
nvidia-cusolver-cu12==11.7.1.2
nvidia-cusparse-cu12==12.5.4.2
nvidia-cusparselt-cu12==0.6.3
nvidia-nccl-cu12==2.26.2
nvidia-nvjitlink-cu12==12.6.85
nvidia-nvtx-cu12==12.6.77
packaging==25.0
pandas==2.3.1
parso==0.8.4
prompt_toolkit==3.0.51
pure_eval==0.2.3
Pygments==2.19.2
python-dateutil==2.9.0.post0
pytz==2025.2
PyYAML==6.0.2
pyzmq==27.0.0
regex==2024.11.6
requests==2.32.4
safetensors==0.5.3
scipy==1.15.3
sentencepiece==0.2.0
stack-data==0.6.3
sympy==1.14.0
tokenizers==0.21.2
torch==2.7.1
tornado==6.5.1
tqdm==4.67.1
traitlets==5.14.3
transformers==4.54.0
triton==3.3.1
typing_extensions==4.14.1
tzdata==2025.2
urllib3==2.5.0
wcwidth==0.2.13