SalazarPevelll
commited on
Commit
·
f291f4a
1
Parent(s):
8fcf809
be
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- saved_models/codesearch_simp/__pycache__/context.cpython-310.pyc +0 -0
- saved_models/codesearch_simp/__pycache__/context.cpython-37.pyc +0 -0
- saved_models/codesearch_simp/__pycache__/strategy.cpython-310.pyc +0 -0
- saved_models/codesearch_simp/__pycache__/strategy.cpython-37.pyc +0 -0
- saved_models/codesearch_simp/context.py +603 -0
- saved_models/codesearch_simp/dataFeature.ipynb +0 -0
- saved_models/codesearch_simp/gen_label.py +32 -0
- saved_models/codesearch_simp/server/__init__.py +0 -0
- saved_models/codesearch_simp/server/__pycache__/utils.cpython-310.pyc +0 -0
- saved_models/codesearch_simp/server/__pycache__/utils.cpython-37.pyc +0 -0
- saved_models/codesearch_simp/server/admin_API_result.csv +0 -0
- saved_models/codesearch_simp/server/server.py +620 -0
- saved_models/codesearch_simp/server/utils.py +475 -0
- saved_models/codesearch_simp/simplify.py +35 -0
- saved_models/codesearch_simp/singleVis/SingleVisualizationModel.py +188 -0
- saved_models/codesearch_simp/singleVis/__init__.py +0 -0
- saved_models/codesearch_simp/singleVis/__pycache__/SingleVisualizationModel.cpython-37.pyc +0 -0
- saved_models/codesearch_simp/singleVis/__pycache__/SingleVisualizationModel.cpython-39.pyc +0 -0
- saved_models/codesearch_simp/singleVis/__pycache__/__init__.cpython-37.pyc +0 -0
- saved_models/codesearch_simp/singleVis/__pycache__/__init__.cpython-39.pyc +0 -0
- saved_models/codesearch_simp/singleVis/__pycache__/active_sampling.cpython-37.pyc +0 -0
- saved_models/codesearch_simp/singleVis/__pycache__/backend.cpython-37.pyc +0 -0
- saved_models/codesearch_simp/singleVis/__pycache__/backend.cpython-39.pyc +0 -0
- saved_models/codesearch_simp/singleVis/__pycache__/custom_weighted_random_sampler.cpython-37.pyc +0 -0
- saved_models/codesearch_simp/singleVis/__pycache__/custom_weighted_random_sampler.cpython-39.pyc +0 -0
- saved_models/codesearch_simp/singleVis/__pycache__/data.cpython-37.pyc +0 -0
- saved_models/codesearch_simp/singleVis/__pycache__/data.cpython-39.pyc +0 -0
- saved_models/codesearch_simp/singleVis/__pycache__/edge_dataset.cpython-37.pyc +0 -0
- saved_models/codesearch_simp/singleVis/__pycache__/edge_dataset.cpython-39.pyc +0 -0
- saved_models/codesearch_simp/singleVis/__pycache__/intrinsic_dim.cpython-37.pyc +0 -0
- saved_models/codesearch_simp/singleVis/__pycache__/intrinsic_dim.cpython-39.pyc +0 -0
- saved_models/codesearch_simp/singleVis/__pycache__/jj1sk.cpython-37.pyc +0 -0
- saved_models/codesearch_simp/singleVis/__pycache__/jj51sk.cpython-37.pyc +0 -0
- saved_models/codesearch_simp/singleVis/__pycache__/jj551sk.cpython-37.pyc +0 -0
- saved_models/codesearch_simp/singleVis/__pycache__/jjsk.cpython-37.pyc +0 -0
- saved_models/codesearch_simp/singleVis/__pycache__/kcenter_greedy.cpython-37.pyc +0 -0
- saved_models/codesearch_simp/singleVis/__pycache__/kcenter_greedy.cpython-39.pyc +0 -0
- saved_models/codesearch_simp/singleVis/__pycache__/losses.cpython-37.pyc +0 -0
- saved_models/codesearch_simp/singleVis/__pycache__/losses.cpython-39.pyc +0 -0
- saved_models/codesearch_simp/singleVis/__pycache__/projector.cpython-37.pyc +0 -0
- saved_models/codesearch_simp/singleVis/__pycache__/sVis.cpython-37.pyc +0 -0
- saved_models/codesearch_simp/singleVis/__pycache__/s_Vis.cpython-37.pyc +0 -0
- saved_models/codesearch_simp/singleVis/__pycache__/segmenter.cpython-37.pyc +0 -0
- saved_models/codesearch_simp/singleVis/__pycache__/skeVis.cpython-37.pyc +0 -0
- saved_models/codesearch_simp/singleVis/__pycache__/skeleVis.cpython-37.pyc +0 -0
- saved_models/codesearch_simp/singleVis/__pycache__/skele_Vis.cpython-37.pyc +0 -0
- saved_models/codesearch_simp/singleVis/__pycache__/skele_viser.cpython-37.pyc +0 -0
- saved_models/codesearch_simp/singleVis/__pycache__/skeletonVis.cpython-37.pyc +0 -0
- saved_models/codesearch_simp/singleVis/__pycache__/skeletonViser.cpython-37.pyc +0 -0
- saved_models/codesearch_simp/singleVis/__pycache__/skeletonVisualizer.cpython-37.pyc +0 -0
saved_models/codesearch_simp/__pycache__/context.cpython-310.pyc
ADDED
|
Binary file (17.9 kB). View file
|
|
|
saved_models/codesearch_simp/__pycache__/context.cpython-37.pyc
ADDED
|
Binary file (18 kB). View file
|
|
|
saved_models/codesearch_simp/__pycache__/strategy.cpython-310.pyc
ADDED
|
Binary file (38.4 kB). View file
|
|
|
saved_models/codesearch_simp/__pycache__/strategy.cpython-37.pyc
ADDED
|
Binary file (44.5 kB). View file
|
|
|
saved_models/codesearch_simp/context.py
ADDED
|
@@ -0,0 +1,603 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''This class serves as an intermediate layer between the tensorboard frontend and the DeepDebugger backend'''
|
| 2 |
+
from abc import ABC, abstractmethod
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import json
|
| 6 |
+
import time
|
| 7 |
+
import torch
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pickle
|
| 10 |
+
import shutil
|
| 11 |
+
|
| 12 |
+
import torch.nn
|
| 13 |
+
|
| 14 |
+
from scipy.special import softmax
|
| 15 |
+
|
| 16 |
+
from strategy import StrategyAbstractClass
|
| 17 |
+
|
| 18 |
+
from singleVis.utils import *
|
| 19 |
+
from singleVis.trajectory_manager import Recommender
|
| 20 |
+
from singleVis.active_sampling import random_sampling, uncerainty_sampling
|
| 21 |
+
|
| 22 |
+
# active_learning_path = "../../ActiveLearning"
|
| 23 |
+
# sys.path.append(active_learning_path)
|
| 24 |
+
|
| 25 |
+
'''the context for different dataset setting'''
|
| 26 |
+
class Context(ABC):
    """
    Base class that holds the visualization strategy used by a context.

    Users of the visualization method talk to a context; the context keeps a
    strategy object and exposes it through the ``strategy`` property.
    """

    def __init__(self, strategy: StrategyAbstractClass) -> None:
        """
        Remember the visualization strategy the context starts with.

        The strategy can be replaced later by assigning to ``strategy``.
        """
        self._strategy = strategy

    @property
    def strategy(self) -> StrategyAbstractClass:
        """The visualization strategy currently held by this context."""
        return self._strategy

    @strategy.setter
    def strategy(self, strategy: StrategyAbstractClass) -> None:
        # Allows swapping the strategy at runtime.
        self._strategy = strategy
|
| 44 |
+
|
| 45 |
+
class VisContext(Context):
|
| 46 |
+
'''Normal setting'''
|
| 47 |
+
#################################################################################################################
|
| 48 |
+
# #
|
| 49 |
+
# Adapter #
|
| 50 |
+
# #
|
| 51 |
+
#################################################################################################################
|
| 52 |
+
|
| 53 |
+
def train_representation_data(self, EPOCH):
    """Return what the data provider's ``train_representation`` gives for ``EPOCH``."""
    provider = self.strategy.data_provider
    return provider.train_representation(EPOCH)
|
| 55 |
+
|
| 56 |
+
def test_representation_data(self, EPOCH):
|
| 57 |
+
return self.strategy.data_provider.test_representation(EPOCH)
|
| 58 |
+
|
| 59 |
+
def train_labels(self, EPOCH):
    """Return what the data provider's ``train_labels`` gives for ``EPOCH``."""
    provider = self.strategy.data_provider
    return provider.train_labels(EPOCH)
|
| 61 |
+
|
| 62 |
+
def test_labels(self, EPOCH):
|
| 63 |
+
return self.strategy.data_provider.test_labels(EPOCH)
|
| 64 |
+
|
| 65 |
+
def suggest_abnormal(self, strategy, acc_idxs, rej_idxs, budget):
    """
    Ask the detector for a batch of suggested samples and look up their labels.

    :param strategy: str, "TBSampling" or "Feedback"
    :param acc_idxs: accepted indices, forwarded to the detector
    :param rej_idxs: rejected indices, forwarded to the detector
    :param budget: forwarded to the detector
    :return: (suggest_idxs, scores, suggest_labels)
    :raises NotImplementedError: for any other strategy name
    """
    # NOTE(review): `_init_detection()` without arguments and `clean_labels`
    # are not defined in this class; presumably a subclass provides them.
    detector = self._init_detection()
    if strategy == "TBSampling":
        sampler = detector.sample_batch_init
    elif strategy == "Feedback":
        sampler = detector.sample_batch
    else:
        raise NotImplementedError
    suggest_idxs, scores = sampler(acc_idxs, rej_idxs, budget)
    return suggest_idxs, scores, self.clean_labels[suggest_idxs]
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
#################################################################################################################
|
| 78 |
+
# #
|
| 79 |
+
# data Panel #
|
| 80 |
+
# #
|
| 81 |
+
#################################################################################################################
|
| 82 |
+
|
| 83 |
+
def batch_inv_preserve(self, epoch, data):
    """
    Compare predictions before and after a project/inverse round trip for a batch.

    Each row of ``data`` is projected, reconstructed with the inverse
    projection, and both versions are passed to the data provider's
    ``get_pred``; the outputs are turned into probabilities with softmax.

    :param epoch: int
    :param data: numpy.ndarray, one sample per row
    :return l: numpy.ndarray of bool, True where the reconstruction keeps the
        same predicted class as the original sample
    :return conf_diff: numpy.ndarray, probability of the originally predicted
        class on the original sample minus its probability on the
        reconstruction (negative when the reconstruction is more confident)
    """
    projector = self.strategy.projector
    provider = self.strategy.data_provider

    embedding = projector.batch_project(epoch, data)
    recon = projector.batch_inverse(epoch, embedding)

    ori_pred = softmax(provider.get_pred(epoch, data), axis=1)
    new_pred = softmax(provider.get_pred(epoch, recon), axis=1)

    old_label = ori_pred.argmax(-1)
    new_label = new_pred.argmax(-1)
    l = old_label == new_label

    # Gather, for every row, the probability assigned to the ORIGINAL label.
    rows = np.arange(len(old_label))
    conf_diff = ori_pred[rows, old_label] - new_pred[rows, old_label]
    return l, conf_diff
|
| 110 |
+
|
| 111 |
+
#################################################################################################################
|
| 112 |
+
# #
|
| 113 |
+
# Search Panel #
|
| 114 |
+
# #
|
| 115 |
+
#################################################################################################################
|
| 116 |
+
|
| 117 |
+
# TODO: fix bugs according to new api
|
| 118 |
+
# customized features
|
| 119 |
+
def filter_label(self, label, epoch_id):
    """
    Find the positions of all samples whose label is ``label``.

    Train labels come first, followed by test labels, so returned positions
    index into that concatenation.

    :param label: class name looked up in the data provider's ``classes``
    :param epoch_id: epoch passed to ``train_labels`` / ``test_labels``
    :return: 1-D numpy.ndarray of matching positions; an unknown class name is
        mapped to index -1 and normally matches nothing
    """
    provider = self.strategy.data_provider
    try:
        index = provider.classes.index(label)
    except ValueError:
        # Unknown class name: fall back to an index no sample is expected to carry.
        index = -1
    train_labels = provider.train_labels(epoch_id)
    test_labels = provider.test_labels(epoch_id)
    labels = np.concatenate((train_labels, test_labels), 0)
    # flatnonzero always yields a 1-D array, even for exactly one match
    # (argwhere + squeeze collapsed that case to a 0-d array).
    # NOTE(review): assumes labels are 1-D — confirm against the data provider.
    return np.flatnonzero(labels == index)
|
| 130 |
+
|
| 131 |
+
def filter_type(self, type, epoch_id):
    """
    List sample indices belonging to one data split.

    :param type: "train" (indices from ``get_epoch_index``), "test"
        (``train_num`` .. ``train_num + test_num - 1``), "unlabel" (training
        indices not returned by ``get_epoch_index``); anything else selects
        every train and test index
    :param epoch_id: epoch passed to ``get_epoch_index``
    :return: the selected indices
    """
    if type == "train":
        return self.get_epoch_index(epoch_id)

    if type == "test":
        provider = self.strategy.data_provider
        first_test = provider.train_num
        return list(range(first_test, first_test + provider.test_num))

    if type == "unlabel":
        labeled = np.array(self.get_epoch_index(epoch_id))
        pool = np.arange(self.strategy.data_provider.train_num)
        return np.setdiff1d(pool, labeled).tolist()

    # all data
    provider = self.strategy.data_provider
    return list(range(provider.train_num + provider.test_num))
|
| 150 |
+
|
| 151 |
+
def filter_conf(self, conf_min, conf_max, epoch_id):
    """
    List samples whose top softmax probability lies in [conf_min, conf_max].

    Train representations come first, followed by test representations, so
    returned positions index into that concatenation.

    :param conf_min: lower bound (inclusive)
    :param conf_max: upper bound (inclusive)
    :param epoch_id: epoch passed to the data provider
    :return: list of int positions; always a list, including when exactly one
        sample matches
    """
    provider = self.strategy.data_provider
    train_data = provider.train_representation(epoch_id)
    test_data = provider.test_representation(epoch_id)
    data = np.concatenate((train_data, test_data), axis=0)
    pred = provider.get_pred(epoch_id, data)
    scores = np.amax(softmax(pred, axis=1), axis=1)
    in_range = np.logical_and(scores <= conf_max, scores >= conf_min)
    # flatnonzero keeps the result 1-D; argwhere().squeeze() turned a single
    # match into a 0-d array whose tolist() is a bare int.
    return np.flatnonzero(in_range).tolist()
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
#################################################################################################################
|
| 162 |
+
# #
|
| 163 |
+
# Helper Functions #
|
| 164 |
+
# #
|
| 165 |
+
#################################################################################################################
|
| 166 |
+
|
| 167 |
+
def save_acc_and_rej(self, acc_idxs, rej_idxs, file_name):
    """
    Write the accepted and rejected indices to ``<file_name>_acc_rej.json``
    inside the data provider's content path.
    """
    content_dir = self.strategy.data_provider.content_path
    target = os.path.join(content_dir, "{}_acc_rej.json".format(file_name))
    payload = dict(acc_idxs=acc_idxs, rej_idxs=rej_idxs)
    with open(target, "w") as out:
        json.dump(payload, out)
    print("Successfully save the acc and rej idxs selected by user...")
|
| 176 |
+
|
| 177 |
+
def get_epoch_index(self, epoch_id):
    """Load the index stored in ``Epoch_<epoch_id>/index.json`` under the model path."""
    model_dir = self.strategy.data_provider.model_path
    epoch_dir = "Epoch_{:d}".format(epoch_id)
    return load_labelled_data_index(os.path.join(model_dir, epoch_dir, "index.json"))
|
| 182 |
+
|
| 183 |
+
def get_max_iter(self):
    """
    Count the epochs covered by the configured range:
    ``int((EPOCH_END - EPOCH_START) / EPOCH_PERIOD) + 1``.
    """
    config = self.strategy.config
    first, last, step = config["EPOCH_START"], config["EPOCH_END"], config["EPOCH_PERIOD"]
    span = (last - first) / step
    return int(span) + 1
|
| 188 |
+
|
| 189 |
+
def reset(self):
    """Do nothing and return None."""
    return None
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
class ActiveLearningContext(VisContext):
|
| 194 |
+
'''Active learning dataset'''
|
| 195 |
+
def __init__(self, strategy) -> None:
    """Hand the strategy to the parent constructor; nothing else is set up here."""
    super().__init__(strategy)
|
| 197 |
+
|
| 198 |
+
'''Active learning setting'''
|
| 199 |
+
#################################################################################################################
|
| 200 |
+
# #
|
| 201 |
+
# Adapter #
|
| 202 |
+
# #
|
| 203 |
+
#################################################################################################################
|
| 204 |
+
|
| 205 |
+
def train_representation_data(self, iteration):
    """Return what the data provider's ``train_representation_all`` gives for ``iteration``."""
    provider = self.strategy.data_provider
    return provider.train_representation_all(iteration)
|
| 207 |
+
|
| 208 |
+
def train_labels(self, iteration):
    """
    Return the data provider's ``train_labels_all()``.

    ``iteration`` is accepted but not used.
    """
    return self.strategy.data_provider.train_labels_all()
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def save_acc_and_rej(self, iteration, acc_idxs, rej_idxs, file_name):
    """
    Write the accepted and rejected indices to ``<file_name>_acc_rej.json``
    inside the checkpoint directory of ``iteration``.
    """
    checkpoint_dir = self.strategy.data_provider.checkpoint_path(iteration)
    target = os.path.join(checkpoint_dir, "{}_acc_rej.json".format(file_name))
    payload = dict(acc_idxs=acc_idxs, rej_idxs=rej_idxs)
    with open(target, "w") as out:
        json.dump(payload, out)
    print("Successfully save the acc and rej idxs selected by user at Iteration {}...".format(iteration))
|
| 222 |
+
|
| 223 |
+
def reset(self, iteration):
    """
    Remove every cached iteration from ``iteration`` onwards.

    Deletes the checkpoint directories of iterations
    ``iteration .. get_max_iter()`` and drops the matching entries (those
    whose ``"value"`` is not below ``iteration``) from
    ``iteration_structure.json`` in the content path.

    :param iteration: int, first iteration to delete
    """
    provider = self.strategy.data_provider

    # delete [iteration,...)
    max_i = self.get_max_iter()
    for i in range(iteration, max_i + 1):
        # Use the loop index: deleting checkpoint_path(iteration) on every
        # pass removed one directory and then failed on the second pass.
        path = provider.checkpoint_path(i)
        if os.path.isdir(path):
            # Iteration numbers are not guaranteed to be contiguous on disk.
            shutil.rmtree(path)

    iter_structure_path = os.path.join(provider.content_path, "iteration_structure.json")
    with open(iter_structure_path, "r") as f:
        i_s = json.load(f)
    new_is = [item for item in i_s if item["value"] < iteration]
    with open(iter_structure_path, "w") as f:
        json.dump(new_is, f)
    print("Successfully remove cache data!")
|
| 240 |
+
|
| 241 |
+
def get_epoch_index(self, iteration):
    """Load the index stored in ``index.json`` of the checkpoint directory of ``iteration``."""
    checkpoint_dir = self.strategy.data_provider.checkpoint_path(iteration)
    return load_labelled_data_index(os.path.join(checkpoint_dir, "index.json"))
|
| 246 |
+
|
| 247 |
+
def al_query(self, iteration, budget, strategy, acc_idxs, rej_idxs):
    """
    Get the indices of a new selection from one of the query strategies.

    :param iteration: int, iteration whose checkpoint is resumed
    :param budget: number of samples to query
    :param strategy: str, one of "Random", "Uncertainty", "TBSampling", "Feedback"
    :param acc_idxs: indices accepted by the user, forwarded to the sampler
    :param rej_idxs: indices rejected by the user, forwarded to the sampler
    :return: (new_indices, labels of new_indices, scores)
    :raises NotImplementedError: for any other strategy name
    """
    provider = self.strategy.data_provider
    config = self.strategy.config

    CONTENT_PATH = provider.content_path
    NUM_QUERY = budget
    NET = config["TRAINING"]["NET"]
    DATA_NAME = config["DATASET"]
    TOTAL_EPOCH = config["TRAINING"]["total_epoch"]
    # `Model.model` is imported from the content path; register it only once
    # so repeated queries do not keep growing sys.path.
    if CONTENT_PATH not in sys.path:
        sys.path.append(CONTENT_PATH)

    # loading neural network
    import Model.model as subject_model
    # Resolve the configured network by attribute lookup instead of eval(),
    # so a config value can never be executed as arbitrary code.
    factory = subject_model
    for part in NET.split("."):
        factory = getattr(factory, part)
    task_model = factory()

    # start experiment
    n_pool = config["TRAINING"]["train_num"]  # 50000
    n_test = config["TRAINING"]['test_num']  # 10000

    resume_path = provider.checkpoint_path(iteration)

    with open(os.path.join(resume_path, "index.json"), "r") as f:
        idxs_lb = np.array(json.load(f))

    state_dict = torch.load(os.path.join(resume_path, "subject_model.pth"), map_location=torch.device('cpu'))
    # The model is not used below; loading still verifies the checkpoint
    # matches the configured architecture.
    task_model.load_state_dict(state_dict)
    NUM_INIT_LB = len(idxs_lb)

    print('resume from iteration {}'.format(iteration))
    print('number of labeled pool: {}'.format(NUM_INIT_LB))
    print('number of unlabeled pool: {}'.format(n_pool - NUM_INIT_LB))
    print('number of testing pool: {}'.format(n_test))

    if strategy not in ("Random", "Uncertainty", "TBSampling", "Feedback"):
        raise NotImplementedError

    print(DATA_NAME)
    print(strategy)
    print('================Round {:d}==============='.format(iteration+1))

    if strategy == "Random":
        # query new samples
        t0 = time.time()
        new_indices, scores = random_sampling(n_pool, idxs_lb, acc_idxs, rej_idxs, NUM_QUERY)
    elif strategy == "Uncertainty":
        # NOTE(review): other methods here call train_representation_all /
        # get_pred with an extra epoch argument — confirm these two calls
        # match the data provider's API.
        samples = provider.train_representation(iteration)
        pred = provider.get_pred(iteration, samples)
        confidence = np.amax(softmax(pred, axis=1), axis=1)
        uncertainty = 1-confidence
        # query new samples
        t0 = time.time()
        new_indices, scores = uncerainty_sampling(n_pool, idxs_lb, acc_idxs, rej_idxs, NUM_QUERY, uncertainty=uncertainty)
    else:
        # "TBSampling" and "Feedback" share the trajectory-based recommender.
        period = int(2/3*TOTAL_EPOCH)
        t0 = time.time()
        new_indices, scores = self._suggest_abnormal(strategy, iteration, idxs_lb, acc_idxs, rej_idxs, budget, period)
    t1 = time.time()
    print("Query time is {:.2f}".format(t1-t0))

    true_labels = self.train_labels(iteration)

    return new_indices, true_labels[new_indices], scores
|
| 330 |
+
|
| 331 |
+
def al_train(self, iteration, indices):
    """Not available yet: always raises NotImplementedError."""
    # TODO fix
    raise NotImplementedError
|
| 334 |
+
# # customize ....
|
| 335 |
+
# CONTENT_PATH = self.strategy.data_provider.content_path
|
| 336 |
+
# # record output information
|
| 337 |
+
# now = time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime(time.time()))
|
| 338 |
+
# sys.stdout = open(os.path.join(CONTENT_PATH, now+".txt"), "w")
|
| 339 |
+
|
| 340 |
+
# # for reproduce purpose
|
| 341 |
+
# print("New indices:\t{}".format(len(indices)))
|
| 342 |
+
# self.save_human_selection(iteration, indices)
|
| 343 |
+
# lb_idx = self.get_epoch_index(iteration)
|
| 344 |
+
# train_idx = np.hstack((lb_idx, indices))
|
| 345 |
+
# print("Training indices:\t{}".format(len(train_idx)))
|
| 346 |
+
# print("Valid indices:\t{}".format(len(set(train_idx))))
|
| 347 |
+
|
| 348 |
+
# TOTAL_EPOCH = self.strategy.config["TRAINING"]["total_epoch"]
|
| 349 |
+
# NET = self.strategy.config["TRAINING"]["NET"]
|
| 350 |
+
# DEVICE = self.strategy.data_provider.DEVICE
|
| 351 |
+
# NEW_ITERATION = self.get_max_iter() + 1
|
| 352 |
+
# GPU = self.strategy.config["GPU"]
|
| 353 |
+
# DATA_NAME = self.strategy.config["DATASET"]
|
| 354 |
+
# sys.path.append(CONTENT_PATH)
|
| 355 |
+
|
| 356 |
+
# # loading neural network
|
| 357 |
+
# from Model.model import resnet18
|
| 358 |
+
# task_model = resnet18()
|
| 359 |
+
# resume_path = self.strategy.data_provider.checkpoint_path(iteration)
|
| 360 |
+
# state_dict = torch.load(os.path.join(resume_path, "subject_model.pth"), map_location=torch.device("cpu"))
|
| 361 |
+
# task_model.load_state_dict(state_dict)
|
| 362 |
+
|
| 363 |
+
# self.save_iteration_index(NEW_ITERATION, train_idx)
|
| 364 |
+
# task_model_type = "pytorch"
|
| 365 |
+
# # start experiment
|
| 366 |
+
# n_pool = self.strategy.config["TRAINING"]["train_num"] # 50000
|
| 367 |
+
# save_path = self.strategy.data_provider.checkpoint_path(NEW_ITERATION)
|
| 368 |
+
# os.makedirs(save_path, exist_ok=True)
|
| 369 |
+
|
| 370 |
+
# from query_strategies.random import RandomSampling
|
| 371 |
+
# q_strategy = RandomSampling(task_model, task_model_type, n_pool, lb_idx, 10, DATA_NAME, NET, gpu=GPU, **self.hyperparameters["TRAINING"])
|
| 372 |
+
# # print information
|
| 373 |
+
# print('================Round {:d}==============='.format(NEW_ITERATION))
|
| 374 |
+
# # update
|
| 375 |
+
# q_strategy.update_lb_idxs(train_idx)
|
| 376 |
+
# resnet_model = resnet18()
|
| 377 |
+
# train_dataset = torchvision.datasets.CIFAR10(root="..//data//CIFAR10", download=True, train=True, transform=self.hyperparameters["TRAINING"]['transform_tr'])
|
| 378 |
+
# test_dataset = torchvision.datasets.CIFAR10(root="..//data//CIFAR10", download=True, train=False, transform=self.hyperparameters["TRAINING"]['transform_te'])
|
| 379 |
+
# t1 = time.time()
|
| 380 |
+
# q_strategy.train(total_epoch=TOTAL_EPOCH, task_model=resnet_model, complete_dataset=train_dataset,save_path=None)
|
| 381 |
+
# t2 = time.time()
|
| 382 |
+
# print("Training time is {:.2f}".format(t2-t1))
|
| 383 |
+
# self.save_subject_model(NEW_ITERATION, q_strategy.task_model.state_dict())
|
| 384 |
+
|
| 385 |
+
# # compute accuracy at each round
|
| 386 |
+
# accu = q_strategy.test_accu(test_dataset)
|
| 387 |
+
# print('Accuracy {:.3f}'.format(100*accu))
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
def get_max_iter(self):
    """
    Find the highest iteration number stored under ``<content_path>/Model``.

    Entries are expected to be named ``<iteration_name>_<number>``; entries
    without that prefix or with a non-integer suffix are ignored.

    :return: int, the largest iteration number found, or -1 if there is none
    """
    provider = self.strategy.data_provider
    model_root = os.path.join(provider.content_path, "Model")
    entries = os.listdir(model_root)
    prefix = "{}_".format(provider.iteration_name)

    max_iter = -1
    for entry in entries:
        if not entry.startswith(prefix):
            continue
        try:
            number = int(entry[len(prefix):])
        except ValueError:
            # e.g. "<name>_backup": shares the prefix but is not an iteration
            continue
        max_iter = max(max_iter, number)
    return max_iter
|
| 400 |
+
|
| 401 |
+
def save_human_selection(self, iteration, indices):
    """
    Store the selected indices as ``human_select.json`` in the checkpoint
    directory of ``iteration``.

    :param iteration: iteration whose checkpoint directory receives the file
    :param indices: list, selected indices
    :return: None
    """
    checkpoint_dir = self.strategy.data_provider.checkpoint_path(iteration)
    target = os.path.join(checkpoint_dir, "human_select.json")
    with open(target, "w") as out:
        json.dump(indices, out)
|
| 411 |
+
|
| 412 |
+
def save_iteration_index(self, iteration, idxs):
    """
    Store ``idxs`` as ``index.json`` in the checkpoint directory of
    ``iteration``, creating the directory when it does not exist.

    :param iteration: iteration whose checkpoint directory receives the file
    :param idxs: object with a ``tolist()`` method (e.g. numpy.ndarray)
    """
    target_dir = self.strategy.data_provider.checkpoint_path(iteration)
    os.makedirs(target_dir, exist_ok=True)
    with open(os.path.join(target_dir, "index.json"), "w") as out:
        json.dump(idxs.tolist(), out)
|
| 418 |
+
|
| 419 |
+
def save_subject_model(self, iteration, state_dict):
    """Save ``state_dict`` as ``subject_model.pth`` in the checkpoint directory of ``iteration``."""
    checkpoint_dir = self.strategy.data_provider.checkpoint_path(iteration)
    torch.save(state_dict, os.path.join(checkpoint_dir, "subject_model.pth"))
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
def vis_train(self, iteration, resume_iter):
    """Forward both arguments to the strategy's ``visualize_embedding``."""
    visualizer = self.strategy.visualize_embedding
    visualizer(iteration, resume_iter)
|
| 427 |
+
|
| 428 |
+
#################################################################################################################
|
| 429 |
+
# #
|
| 430 |
+
# Sample Selection #
|
| 431 |
+
# #
|
| 432 |
+
#################################################################################################################
|
| 433 |
+
def _save(self, iteration, ftm):
    """
    Pickle ``ftm`` to ``<VIS_METHOD>_sample_recommender.pkl`` in the
    checkpoint directory of ``iteration``.
    """
    checkpoint_dir = self.strategy.data_provider.checkpoint_path(iteration)
    file_name = '{}_sample_recommender.pkl'.format(self.strategy.VIS_METHOD)
    with open(os.path.join(checkpoint_dir, file_name), 'wb') as out:
        pickle.dump(ftm, out, pickle.HIGHEST_PROTOCOL)
|
| 436 |
+
|
| 437 |
+
def _init_detection(self, iteration, lb_idxs, period=80):
    """
    Build (or load from cache) the sample recommender for one iteration.

    Three artefacts live in the checkpoint directory of ``iteration`` and are
    reused when present, otherwise computed and written:
    ``trajectory_embeddings.npy``, ``uncertainties.npy`` and
    ``<VIS_METHOD>_sample_recommender.pkl``.

    :param iteration: int, iteration whose checkpoint directory is used
    :param lb_idxs: labeled indices, passed to ``get_unlabeled_idx``
    :param period: forwarded to ``Recommender`` when a new one is built
    :return: (recommender, unlabeled indices)
    """
    # must be in the dense setting
    assert "Dense" in self.strategy.VIS_METHOD

    provider = self.strategy.data_provider
    checkpoint_dir = provider.checkpoint_path(iteration)

    # prepare high dimensional trajectory
    embedding_path = os.path.join(checkpoint_dir, 'trajectory_embeddings.npy')
    if os.path.exists(embedding_path):
        trajectories = np.load(embedding_path)
        print("Load trajectories from cache!")
    else:
        # extract samples
        training = self.strategy.config["TRAINING"]
        TOTAL_EPOCH = training["total_epoch"]
        EPOCH_START = training["epoch_start"]
        EPOCH_END = training["epoch_end"]
        EPOCH_PERIOD = training["epoch_period"]
        train_num = len(self.train_labels(None))
        # NOTE(review): the buffer has TOTAL_EPOCH slots while the loop fills
        # one slot per visited epoch — confirm the two always agree.
        embeddings_2d = np.zeros((TOTAL_EPOCH, train_num, 2))
        for epoch in range(EPOCH_START, EPOCH_END+1, EPOCH_PERIOD):
            slot = (epoch - EPOCH_START)//EPOCH_PERIOD
            representation = provider.train_representation_all(iteration, epoch)
            embeddings_2d[slot] = self.strategy.projector.batch_project(iteration, epoch, representation)
        # (epoch, sample, 2) -> (sample, epoch, 2): one trajectory per sample
        trajectories = np.transpose(embeddings_2d, [1,0,2])
        np.save(embedding_path, trajectories)

    # prepare uncertainty
    uncertainty_path = os.path.join(checkpoint_dir, 'uncertainties.npy')
    if os.path.exists(uncertainty_path):
        uncertainty = np.load(uncertainty_path)
    else:
        # Only the last epoch is needed here.
        EPOCH_END = self.strategy.config["TRAINING"]["epoch_end"]
        samples = provider.train_representation_all(iteration, EPOCH_END)
        pred = provider.get_pred(iteration, EPOCH_END, samples)
        uncertainty = 1 - np.amax(softmax(pred, axis=1), axis=1)
        np.save(uncertainty_path, uncertainty)
    ulb_idxs = provider.get_unlabeled_idx(len(uncertainty), lb_idxs)

    # prepare sampling manager
    ntd_path = os.path.join(checkpoint_dir, '{}_sample_recommender.pkl'.format(self.strategy.VIS_METHOD))
    if os.path.exists(ntd_path):
        # NOTE: unpickling runs arbitrary code; the checkpoint directory must
        # only contain files written by this application.
        with open(ntd_path, 'rb') as f:
            ntd = pickle.load(f)
    else:
        ntd = Recommender(uncertainty[ulb_idxs], trajectories[ulb_idxs], 30, period=period)
        print("Detecting abnormal....")
        ntd.clustered()
        print("Finish detection!")
        self._save(iteration, ntd)
    return ntd, ulb_idxs
|
| 488 |
+
|
| 489 |
+
def _suggest_abnormal(self, strategy, iteration, lb_idxs, acc_idxs, rej_idxs, budget, period):
    """Suggest unlabeled samples that look abnormal.

    Args:
        strategy: ``"TBSampling"`` (calls ``ntd.sample_batch_init``) or
            ``"Feedback"`` (calls ``ntd.sample_batch``).
        iteration: forwarded to ``self._init_detection``.
        lb_idxs: forwarded to ``self._init_detection``.
        acc_idxs: accepted sample indices; each must occur in the unlabeled pool.
        rej_idxs: rejected sample indices; each must occur in the unlabeled pool.
        budget: forwarded to the recommender's sampling call.
        period: forwarded to ``self._init_detection``.

    Returns:
        ``(ulb_idxs[suggest_idxs], scores)``: the suggested entries of the
        unlabeled pool and the scores returned by the recommender.

    Raises:
        ValueError: if an accepted/rejected index is not in the unlabeled pool.
        NotImplementedError: if ``strategy`` is not one of the two known names.
    """
    ntd, ulb_idxs = self._init_detection(iteration, lb_idxs, period)

    # Build the value -> position table once instead of scanning the pool with
    # list.index() for every accepted/rejected index.  setdefault keeps the
    # first position of a value, which is what list.index() would return.
    position_of = {}
    for position, value in enumerate(ulb_idxs.tolist()):
        position_of.setdefault(value, position)

    try:
        map_acc_idxs = np.array([position_of[i] for i in acc_idxs]).astype(np.int32)
        map_rej_idxs = np.array([position_of[i] for i in rej_idxs]).astype(np.int32)
    except KeyError as err:
        # Keep the exception type callers saw from list.index().
        raise ValueError("{} is not in list".format(err.args[0])) from err

    if strategy == "TBSampling":
        suggest_idxs, scores = ntd.sample_batch_init(map_acc_idxs, map_rej_idxs, budget)
    elif strategy == "Feedback":
        suggest_idxs, scores = ntd.sample_batch(map_acc_idxs, map_rej_idxs, budget)
    else:
        raise NotImplementedError
    return ulb_idxs[suggest_idxs], scores
|
| 501 |
+
|
| 502 |
+
def _suggest_normal(self, strategy, iteration, lb_idxs, acc_idxs, rej_idxs, budget, period):
    """Suggest unlabeled samples that look normal.

    Args:
        strategy: ``"TBSampling"`` (calls ``ntd.sample_batch_normal_init``) or
            ``"Feedback"`` (calls ``ntd.sample_batch_normal``).
        iteration: forwarded to ``self._init_detection``.
        lb_idxs: forwarded to ``self._init_detection``.
        acc_idxs: accepted sample indices; each must occur in the unlabeled pool.
        rej_idxs: rejected sample indices; each must occur in the unlabeled pool.
        budget: forwarded to the recommender's sampling call.
        period: forwarded to ``self._init_detection``.

    Returns:
        ``ulb_idxs[suggest_idxs]``: the suggested entries of the unlabeled
        pool.  The second value returned by the recommender is discarded.

    Raises:
        ValueError: if an accepted/rejected index is not in the unlabeled pool.
        NotImplementedError: if ``strategy`` is not one of the two known names.
    """
    ntd, ulb_idxs = self._init_detection(iteration, lb_idxs, period)

    # One pass over the pool builds the value -> position table; this replaces
    # a list.index() scan per accepted/rejected index.  setdefault keeps the
    # first position of a value, matching list.index().
    position_of = {}
    for position, value in enumerate(ulb_idxs.tolist()):
        position_of.setdefault(value, position)

    try:
        map_acc_idxs = np.array([position_of[i] for i in acc_idxs]).astype(np.int32)
        map_rej_idxs = np.array([position_of[i] for i in rej_idxs]).astype(np.int32)
    except KeyError as err:
        # Keep the exception type callers saw from list.index().
        raise ValueError("{} is not in list".format(err.args[0])) from err

    if strategy == "TBSampling":
        suggest_idxs, _ = ntd.sample_batch_normal_init(map_acc_idxs, map_rej_idxs, budget)
    elif strategy == "Feedback":
        suggest_idxs, _ = ntd.sample_batch_normal(map_acc_idxs, map_rej_idxs, budget)
    else:
        raise NotImplementedError
    return ulb_idxs[suggest_idxs]
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
class AnormalyContext(VisContext):
    """Visualization context that recommends abnormal / normal training samples.

    Trajectory embeddings, uncertainty scores and the fitted ``Recommender``
    are cached as files under ``data_provider.content_path`` and reused on
    later calls when those files already exist.
    """

    def __init__(self, strategy) -> None:
        """Read the epoch range from the strategy config and load ``clean_label.json``."""
        super().__init__(strategy)
        EPOCH_START = self.strategy.config["EPOCH_START"]
        EPOCH_END = self.strategy.config["EPOCH_END"]
        EPOCH_PERIOD = self.strategy.config["EPOCH_PERIOD"]
        # Two thirds of the number of recorded epochs, truncated to an int;
        # passed to Recommender as its period argument.
        self.period = int(2/3*((EPOCH_END-EPOCH_START)/EPOCH_PERIOD+1))
        file_path = os.path.join(self.strategy.data_provider.content_path, 'clean_label.json')
        with open(file_path, "r") as f:
            self.clean_labels = np.array(json.load(f))

    def reset(self):
        """Do nothing; this context has no state to reset."""
        return

    #################################################################################################################
    #                                                                                                               #
    #                                                Anormaly Detection                                             #
    #                                                                                                               #
    #################################################################################################################

    def _save(self, ntd):
        """Pickle the recommender ``ntd`` to ``<content_path>/<VIS_METHOD>_sample_recommender.pkl``."""
        with open(os.path.join(self.strategy.data_provider.content_path, '{}_sample_recommender.pkl'.format(self.strategy.VIS_METHOD)), 'wb') as f:
            pickle.dump(ntd, f, pickle.HIGHEST_PROTOCOL)

    def _init_detection(self):
        """Load or build the recommender used for abnormal/normal suggestions.

        Each of the three artifacts (trajectories, uncertainties, recommender)
        is loaded from its cache file when present, otherwise computed and
        saved.

        Returns:
            The ``Recommender`` instance (freshly clustered or unpickled).
        """
        provider = self.strategy.data_provider

        # prepare trajectories
        embedding_path = os.path.join(provider.content_path, 'trajectory_embeddings.npy')
        if os.path.exists(embedding_path):
            trajectories = np.load(embedding_path)
        else:
            # One 2-D projection of the training representation per recorded epoch.
            epochs = range(provider.s, provider.e + 1, provider.p)
            embeddings_2d = np.zeros((len(epochs), provider.train_num, 2))
            for slot, epoch in enumerate(epochs):
                embeddings_2d[slot] = self.strategy.projector.batch_project(epoch, provider.train_representation(epoch))
            # (epoch, sample, 2) -> (sample, epoch, 2)
            trajectories = np.transpose(embeddings_2d, [1, 0, 2])
            np.save(embedding_path, trajectories)

        # prepare uncertainty scores
        uncertainty_path = os.path.join(provider.content_path, 'uncertainties.npy')
        if os.path.exists(uncertainty_path):
            uncertainty = np.load(uncertainty_path)
        else:
            # Use the last recorded epoch.  The previous code passed the
            # *number* of epochs as the epoch id, which only coincides with
            # the last epoch when s == 1 and p == 1.
            last_epoch = range(provider.s, provider.e + 1, provider.p)[-1]
            samples = provider.train_representation(last_epoch)
            pred = provider.get_pred(last_epoch, samples)
            # 1 - max softmax probability per sample.
            uncertainty = 1 - np.amax(softmax(pred, axis=1), axis=1)
            np.save(uncertainty_path, uncertainty)

        # prepare sampling manager
        ntd_path = os.path.join(provider.content_path, '{}_sample_recommender.pkl'.format(self.strategy.VIS_METHOD))
        if os.path.exists(ntd_path):
            # NOTE(security): pickle.load executes arbitrary code from the file;
            # content_path must only ever point at trusted data.
            with open(ntd_path, 'rb') as f:
                ntd = pickle.load(f)
        else:
            ntd = Recommender(uncertainty, trajectories, 30, self.period)
            print("Detecting abnormal....")
            ntd.clustered()
            print("Finish detection!")
            self._save(ntd)
        return ntd

    def suggest_abnormal(self, strategy, acc_idxs, rej_idxs, budget):
        """Suggest samples that look abnormal.

        Args:
            strategy: ``"TBSampling"`` (calls ``sample_batch_init``) or
                ``"Feedback"`` (calls ``sample_batch``).
            acc_idxs: accepted sample indices, forwarded to the recommender.
            rej_idxs: rejected sample indices, forwarded to the recommender.
            budget: forwarded to the recommender's sampling call.

        Returns:
            ``(suggest_idxs, scores, suggest_labels)`` where ``suggest_labels``
            are the entries of ``self.clean_labels`` at ``suggest_idxs``.

        Raises:
            NotImplementedError: for any other ``strategy``.
        """
        ntd = self._init_detection()
        if strategy == "TBSampling":
            suggest_idxs, scores = ntd.sample_batch_init(acc_idxs, rej_idxs, budget)
        elif strategy == "Feedback":
            suggest_idxs, scores = ntd.sample_batch(acc_idxs, rej_idxs, budget)
        else:
            raise NotImplementedError
        suggest_labels = self.clean_labels[suggest_idxs]
        return suggest_idxs, scores, suggest_labels

    def suggest_normal(self, strategy, acc_idxs, rej_idxs, budget):
        """Suggest samples that look normal.

        Args:
            strategy: ``"TBSampling"`` (calls ``sample_batch_normal_init``) or
                ``"Feedback"`` (calls ``sample_batch_normal``).
            acc_idxs: accepted sample indices, forwarded to the recommender.
            rej_idxs: rejected sample indices, forwarded to the recommender.
            budget: forwarded to the recommender's sampling call.

        Returns:
            ``(suggest_idxs, suggest_labels)`` where ``suggest_labels`` are
            the entries of ``self.clean_labels`` at ``suggest_idxs``.

        Raises:
            NotImplementedError: for any other ``strategy``.
        """
        ntd = self._init_detection()
        if strategy == "TBSampling":
            suggest_idxs, _ = ntd.sample_batch_normal_init(acc_idxs, rej_idxs, budget)
        elif strategy == "Feedback":
            suggest_idxs, _ = ntd.sample_batch_normal(acc_idxs, rej_idxs, budget)
        else:
            raise NotImplementedError
        suggest_labels = self.clean_labels[suggest_idxs]
        return suggest_idxs, suggest_labels
|
| 602 |
+
|
| 603 |
+
|
saved_models/codesearch_simp/dataFeature.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
saved_models/codesearch_simp/gen_label.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import torch
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
# file_path = '/home/yiming/ContrastDebugger/EXP/codesearch_query/Model/Epoch_1/index.json'
|
| 6 |
+
|
| 7 |
+
# # Open and read the JSON file
|
| 8 |
+
# with open(file_path, 'r') as file:
|
| 9 |
+
# json_data = json.load(file)
|
| 10 |
+
|
| 11 |
+
# testset_label = None
|
| 12 |
+
# for i in range(len(json_data)):
|
| 13 |
+
# if testset_label != None:
|
| 14 |
+
# testset_label = torch.cat((testset_label, torch.tensor([0])), 0)
|
| 15 |
+
# else:
|
| 16 |
+
# testset_label = torch.tensor([0])
|
| 17 |
+
|
| 18 |
+
# torch.save(testset_label, "/home/yiming/ContrastDebugger/EXP/codesearch_query/Training_data/training_dataset_label.pth")
|
| 19 |
+
|
| 20 |
+
input_file = "/home/yiming/ContrastDebugger/EXP/codesearch/Model/label_list.json"
output_file = "/home/yiming/ContrastDebugger/EXP/codesearch/Model/new_label_list.json"  # replace with the path of the output file

# Number of leading characters kept from every label.
MAX_LABEL_CHARS = 30

# Read the input file
with open(input_file, "r") as f:
    data = json.load(f)

# Keep only the first MAX_LABEL_CHARS characters of each item
processed_data = [item[:MAX_LABEL_CHARS] for item in data]

# Save to a new JSON file
with open(output_file, "w") as f:
    json.dump(processed_data, f)
|
saved_models/codesearch_simp/server/__init__.py
ADDED
|
File without changes
|
saved_models/codesearch_simp/server/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (9.66 kB). View file
|
|
|
saved_models/codesearch_simp/server/__pycache__/utils.cpython-37.pyc
ADDED
|
Binary file (9.53 kB). View file
|
|
|
saved_models/codesearch_simp/server/admin_API_result.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
saved_models/codesearch_simp/server/server.py
ADDED
|
@@ -0,0 +1,620 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from flask import request, Flask, jsonify, make_response
|
| 2 |
+
from flask_cors import CORS, cross_origin
|
| 3 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 4 |
+
import base64
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
import json
|
| 8 |
+
import pickle
|
| 9 |
+
import io
|
| 10 |
+
import numpy as np
|
| 11 |
+
import gc
|
| 12 |
+
import shutil
|
| 13 |
+
from utils import update_epoch_projection, initialize_backend, add_line,getCriticalChangeIndices, getConfChangeIndices, getContraVisChangeIndices, getContraVisChangeIndicesSingle
|
| 14 |
+
import time
|
| 15 |
+
|
| 16 |
+
# flask for API server
|
| 17 |
+
app = Flask(__name__)
|
| 18 |
+
cors = CORS(app, supports_credentials=True)
|
| 19 |
+
app.config['CORS_HEADERS'] = 'Content-Type'
|
| 20 |
+
|
| 21 |
+
API_result_path = "./admin_API_result.csv"
|
| 22 |
+
|
| 23 |
+
@app.route('/updateProjection', methods=["POST", "GET"])
@cross_origin()
def update_projection():
    """Return the 2-D projection and related view data for one iteration.

    Reads ``path``, ``vis_method``, ``setting``, ``iteration``,
    ``predicates``, ``username`` and ``isContraVis`` from the JSON request
    body, initializes the backend context and delegates to
    ``update_epoch_projection``.

    Returns:
        A 200 JSON response carrying the values produced by
        ``update_epoch_projection``.
    """
    res = request.get_json()
    start_time = time.time()

    CONTENT_PATH = os.path.normpath(res['path'])
    VIS_METHOD = res['vis_method']
    SETTING = res["setting"]
    iteration = int(res['iteration'])
    predicates = res["predicates"]
    # Not used below (API logging is disabled) but still required in the body.
    username = res['username']
    isContraVis = res['isContraVis']

    context = initialize_backend(CONTENT_PATH, VIS_METHOD, SETTING)
    # The frontend iteration is used directly as the epoch.
    EPOCH = int(iteration)
    embedding_2d, grid, decision_view, label_name_dict, label_color_list, label_list, max_iter, training_data_index, \
    testing_data_index, eval_new, prediction_list, selected_points, properties, highlightedPointIndices = update_epoch_projection(context, EPOCH, predicates, isContraVis)

    # Convert whenever the value is array-like.  The previous length check
    # skipped the conversion for an *empty* ndarray, which jsonify cannot
    # serialize.
    if hasattr(highlightedPointIndices, "tolist"):
        highlightedPointIndices = highlightedPointIndices.tolist()
    end_time = time.time()
    elapsed_time = end_time - start_time
    print("updateprojection", elapsed_time)
    return make_response(jsonify({'result': embedding_2d,
                                  'grid_index': grid.tolist(),
                                  'grid_color': 'data:image/png;base64,' + decision_view,
                                  'label_name_dict': label_name_dict,
                                  'label_color_list': label_color_list,
                                  'label_list': label_list,
                                  'maximum_iteration': max_iter,
                                  'training_data': training_data_index,
                                  'testing_data': testing_data_index,
                                  'evaluation': eval_new,
                                  'prediction_list': prediction_list,
                                  "selectedPoints": selected_points.tolist(),
                                  "properties": properties.tolist(),
                                  "highlightedPointIndices": highlightedPointIndices,
                                  }), 200)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
@app.route('/highlightCriticalChange', methods=["POST", "GET"])
@cross_origin()
def highlight_critical_change():
    """Return the indices reported by ``getCriticalChangeIndices`` for two iterations.

    The JSON body must provide ``path``, ``vis_method``, ``setting``,
    ``iteration``, ``last_iteration`` and ``username``.
    """
    payload = request.get_json()
    content_path = os.path.normpath(payload['path'])
    vis_method = payload['vis_method']
    setting = payload["setting"]
    current = int(payload['iteration'])
    previous = int(payload['last_iteration'])
    # Required in the request body even though it is not used here.
    payload['username']

    backend = initialize_backend(content_path, vis_method, setting)
    changed = getCriticalChangeIndices(backend, current, previous)

    body = {"predChangeIndices": changed.tolist()}
    return make_response(jsonify(body), 200)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@app.route('/contraVisHighlight', methods=["POST", "GET"])
@cross_origin()
def contravis_highlight():
    """Return the indices reported by ``getContraVisChangeIndices`` for a left/right iteration pair.

    The JSON body must provide ``path``, ``vis_method``, ``setting``,
    ``iterationLeft``, ``iterationRight``, ``method`` and ``username``.
    """
    payload = request.get_json()
    content_path = os.path.normpath(payload['path'])
    vis_method = payload['vis_method']
    setting = payload["setting"]
    left_iteration = int(payload['iterationLeft'])
    right_iteration = int(payload['iterationRight'])
    compare_method = payload['method']
    # Required in the request body even though it is not used here.
    payload['username']

    backend = initialize_backend(content_path, vis_method, setting)
    changed = getContraVisChangeIndices(backend, left_iteration, right_iteration, compare_method)
    print(len(changed))

    body = {"contraVisChangeIndices": changed}
    return make_response(jsonify(body), 200)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
@app.route('/contraVisHighlightSingle', methods=["POST", "GET"])
@cross_origin()
def contravis_highlight_single():
    """Return the six index lists from ``getContraVisChangeIndicesSingle`` for selected points.

    The JSON body must provide ``path``, ``vis_method``, ``setting``,
    ``iterationLeft``, ``iterationRight``, ``method``,
    ``selectedPointLeft``, ``selectedPointRight`` and ``username``.
    The elapsed handling time is printed.
    """
    began = time.time()

    payload = request.get_json()
    content_path = os.path.normpath(payload['path'])
    vis_method = payload['vis_method']
    setting = payload["setting"]
    left_iteration = int(payload['iterationLeft'])
    right_iteration = int(payload['iterationRight'])
    compare_method = payload['method']
    picked_left = payload['selectedPointLeft']
    picked_right = payload['selectedPointRight']
    # Required in the request body even though it is not used here.
    payload['username']

    backend = initialize_backend(content_path, vis_method, setting)

    results = getContraVisChangeIndicesSingle(
        backend, left_iteration, right_iteration, compare_method, picked_left, picked_right)
    left, right, left_left, left_right, right_left, right_right = results

    print(time.time() - began)
    body = {
        "contraVisChangeIndicesLeft": left,
        "contraVisChangeIndicesRight": right,
        "contraVisChangeIndicesLeftLeft": left_left,
        "contraVisChangeIndicesLeftRight": left_right,
        "contraVisChangeIndicesRightLeft": right_left,
        "contraVisChangeIndicesRightRight": right_right,
    }
    return make_response(jsonify(body), 200)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
@app.route('/highlightConfChange', methods=["POST", "GET"])
@cross_origin()
def highlight_conf_change():
    """Return the indices reported by ``getConfChangeIndices`` for two iterations.

    The JSON body must provide ``path``, ``vis_method``, ``setting``,
    ``iteration``, ``last_iteration``, ``confChangeInput`` and ``username``.
    The parsed threshold and the resulting indices are printed.
    """
    payload = request.get_json()
    content_path = os.path.normpath(payload['path'])
    vis_method = payload['vis_method']
    setting = payload["setting"]
    current = int(payload['iteration'])
    previous = int(payload['last_iteration'])
    threshold = float(payload['confChangeInput'])
    print(threshold)
    # Required in the request body even though it is not used here.
    payload['username']

    backend = initialize_backend(content_path, vis_method, setting)

    changed = getConfChangeIndices(backend, current, previous, threshold)
    print(changed)

    body = {"confChangeIndices": changed.tolist()}
    return make_response(jsonify(body), 200)
|
| 177 |
+
|
| 178 |
+
@app.route('/query', methods=["POST"])
@cross_origin()
def filter():
    """Return the sample indices that satisfy every predicate in the request.

    The JSON body must provide ``content_path``, ``vis_method``,
    ``setting``, ``iteration``, ``predicates`` and ``username``.
    Recognized predicate keys are ``label``, ``type`` and ``confidence``;
    any other key leaves the selection unchanged.

    Returns:
        A 200 JSON response ``{"selectedPoints": [...]}``.
    """
    start_time = time.time()
    res = request.get_json()
    CONTENT_PATH = os.path.normpath(res['content_path'])
    VIS_METHOD = res['vis_method']
    SETTING = res["setting"]

    iteration = int(res['iteration'])
    predicates = res["predicates"]
    username = res['username']

    sys.path.append(CONTENT_PATH)
    try:
        context = initialize_backend(CONTENT_PATH, VIS_METHOD, SETTING)
        # TODO: fix when active learning
        EPOCH = (iteration-1)*context.strategy.data_provider.p + context.strategy.data_provider.s

        training_data_number = context.strategy.config["TRAINING"]["train_num"]
        testing_data_number = context.strategy.config["TRAINING"]["test_num"]

        # Start from the training samples selected by get_epoch_index plus
        # all testing samples, then intersect with each predicate's result.
        current_index = context.get_epoch_index(EPOCH)
        selected_points = np.arange(training_data_number)[current_index]
        selected_points = np.concatenate((selected_points, np.arange(training_data_number, training_data_number + testing_data_number, 1)), axis=0)
        for key, value in predicates.items():
            if key == "label":
                tmp = np.array(context.filter_label(value, int(EPOCH)))
            elif key == "type":
                tmp = np.array(context.filter_type(value, int(EPOCH)))
            elif key == "confidence":
                tmp = np.array(context.filter_conf(value[0], value[1], int(EPOCH)))
            else:
                tmp = np.arange(training_data_number + testing_data_number)
            selected_points = np.intersect1d(selected_points, tmp)
    finally:
        # Always undo the sys.path change, even when the backend raises;
        # otherwise the entry accumulates across failed requests.
        sys.path.remove(CONTENT_PATH)
    add_line(API_result_path,['SQ',username])
    end_time = time.time()
    elapsed_time = end_time - start_time
    print("query", elapsed_time)
    return make_response(jsonify({"selectedPoints": selected_points.tolist()}), 200)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
# base64
|
| 222 |
+
@app.route('/spriteImage', methods=["POST","GET"])
@cross_origin()
def sprite_image():
    """Return ``<path>/sprites/<index>.png`` as a base64 PNG data URL.

    ``path``, ``index`` and ``username`` are read from the query string.
    """
    args = request.args
    raw_path = args.get("path")
    raw_index = args.get("index")
    user = args.get("username")

    content_path = os.path.normpath(raw_path)
    print('index', raw_index)
    sprite_file = os.path.join(content_path, "sprites", "{}.png".format(int(raw_index)))
    with open(sprite_file, 'rb') as handle:
        encoded = base64.b64encode(handle.read()).decode()
    add_line(API_result_path,['SI',user])
    body = {"imgUrl": 'data:image/png;base64,' + encoded}
    return make_response(jsonify(body), 200)
|
| 239 |
+
|
| 240 |
+
@app.route('/spriteText', methods=["POST","GET"])
@cross_origin()
def sprite_text():
    """Return the text label stored for one sample at one iteration.

    ``path``, ``index`` and ``iteration`` are read from the query string.
    The label is read from
    ``<path>/Model/Epoch_<iteration>/labels/text_<index>.txt``.

    Returns:
        A 200 JSON response ``{"texts": <file contents>}``.  When the file
        does not exist, ``texts`` is an empty string.
    """
    path = request.args.get("path")
    index = request.args.get("index")
    iteration = request.args.get("iteration")

    CONTENT_PATH = os.path.normpath(path)

    idx = int(index)
    start = time.time()
    text_save_dir_path = os.path.join(CONTENT_PATH, f"Model/Epoch_{iteration}/labels", f"text_{idx}.txt")
    # Default for the missing-file case; previously the name stayed unbound
    # and the handler crashed on the print below.
    sprite_texts = ""
    if os.path.exists(text_save_dir_path):
        with open(text_save_dir_path, 'r') as text_f:
            # Read the contents of the file and store it in sprite_texts
            sprite_texts = text_f.read()
    else:
        print("File does not exist:", text_save_dir_path)

    print(sprite_texts)
    response_data = {
        "texts": sprite_texts
    }
    end = time.time()
    print("processTime", end-start)
    return make_response(jsonify(response_data), 200)
|
| 305 |
+
# img_stream = ''
|
| 306 |
+
# with open(text_save_dir_path, 'rb') as img_f:
|
| 307 |
+
# img_stream = img_f.read()
|
| 308 |
+
# img_stream = base64.b64encode(img_stream).decode()
|
| 309 |
+
# img_stream = base64.b64encode(img_io.getvalue()).decode()
|
| 310 |
+
|
| 311 |
+
# Return the base64-encoded image as JSON
|
| 312 |
+
# return make_response(jsonify({"imgUrl": 'data:image/png;base64,' + img_stream}), 200)
|
| 313 |
+
|
| 314 |
+
# @app.route('/spriteList', methods=["POST"])
|
| 315 |
+
# @cross_origin()
|
| 316 |
+
# def sprite_list_image():
|
| 317 |
+
# data = request.get_json()
|
| 318 |
+
# indices = data["index"]
|
| 319 |
+
# path = data["path"]
|
| 320 |
+
|
| 321 |
+
# CONTENT_PATH = os.path.normpath(path)
|
| 322 |
+
# length = len(indices)
|
| 323 |
+
# urlList = {}
|
| 324 |
+
# start_time = time.time()
|
| 325 |
+
# for i in range(length):
|
| 326 |
+
# idx = indices[i]
|
| 327 |
+
# pic_save_dir_path = os.path.join(CONTENT_PATH, "sprites", "{}.png".format(idx))
|
| 328 |
+
# img_stream = ''
|
| 329 |
+
# with open(pic_save_dir_path, 'rb') as img_f:
|
| 330 |
+
# img_stream = img_f.read()
|
| 331 |
+
# img_stream = base64.b64encode(img_stream).decode()
|
| 332 |
+
# urlList[idx] = 'data:image/png;base64,' + img_stream
|
| 333 |
+
# # urlList.append('data:image/png;base64,' + img_stream)
|
| 334 |
+
|
| 335 |
+
# end_time = time.time()
|
| 336 |
+
# elapsed_time = end_time - start_time
|
| 337 |
+
# print("Spritelist", elapsed_time)
|
| 338 |
+
# return make_response(jsonify({"urlList":urlList}), 200)
|
| 339 |
+
@app.route('/spriteList', methods=["POST"])
@cross_origin()
def sprite_list_image():
    """Return base64 PNG data URLs for a list of sprite indices.

    The JSON body must provide ``index`` (the list of sprite indices) and
    ``path``.  Each image is read from ``<path>/sprites/<idx>.png``.
    The elapsed handling time is printed.
    """
    payload = request.get_json()
    wanted = payload["index"]
    content_path = os.path.normpath(payload["path"])

    # Same length lookup the index loop performed before reading any file.
    count = len(wanted)
    url_by_index = {}
    began = time.time()
    for position in range(count):
        idx = wanted[position]
        sprite_file = os.path.join(content_path, "sprites", "{}.png".format(idx))
        with open(sprite_file, 'rb') as handle:
            raw = handle.read()
        url_by_index[idx] = 'data:image/png;base64,' + base64.b64encode(raw).decode()

    print("Spritelist", time.time() - began)
    return make_response(jsonify({"urlList": url_by_index}), 200)
|
| 364 |
+
|
| 365 |
+
@app.route('/al_query', methods=["POST"])
@cross_origin()
def al_query():
    """Run ``context.al_query`` and return its suggestions sorted by descending score.

    The JSON body must provide ``content_path``, ``vis_method``,
    ``setting``, ``iteration``, ``strategy``, ``budget``, ``accIndices``,
    ``rejIndices``, ``username`` and ``isRecommend``.

    Returns:
        A 200 JSON response with ``selectedPoints``, ``scores`` and
        ``suggestLabels``.
    """
    data = request.get_json()
    CONTENT_PATH = os.path.normpath(data['content_path'])
    VIS_METHOD = data['vis_method']
    SETTING = data["setting"]

    # TODO fix iteration, align with frontend
    iteration = data["iteration"]
    strategy = data["strategy"]
    budget = int(data["budget"])
    acc_idxs = data["accIndices"]
    rej_idxs = data["rejIndices"]
    user_name = data["username"]
    isRecommend = data["isRecommend"]

    sys.path.append(CONTENT_PATH)
    try:
        context = initialize_backend(CONTENT_PATH, VIS_METHOD, SETTING, dense=True)
        # TODO add new sampling rule
        indices, labels, scores = context.al_query(iteration, budget, strategy, np.array(acc_idxs).astype(np.int64), np.array(rej_idxs).astype(np.int64))

        # Highest score first.
        sort_i = np.argsort(-scores)
        indices = indices[sort_i]
        labels = labels[sort_i]
        scores = scores[sort_i]
    finally:
        # Always undo the sys.path change, even when the backend raises.
        sys.path.remove(CONTENT_PATH)

    if not isRecommend:
        add_line(API_result_path,['Feedback', user_name])
    else:
        add_line(API_result_path,['Recommend', user_name])
    return make_response(jsonify({"selectedPoints": indices.tolist(), "scores": scores.tolist(), "suggestLabels":labels.tolist()}), 200)
|
| 398 |
+
|
| 399 |
+
@app.route('/anomaly_query', methods=["POST"])
@cross_origin()
def anomaly_query():
    """Save the user's feedback, then return abnormal suggestions plus one normal suggestion.

    The JSON body must provide ``content_path``, ``vis_method``,
    ``setting``, ``budget``, ``strategy``, ``accIndices``, ``rejIndices``,
    ``username`` and ``isRecommend``.  Abnormal suggestions are sorted by
    descending score.

    Returns:
        A 200 JSON response with ``selectedPoints``, ``scores``,
        ``suggestLabels`` and ``cleanList``.
    """
    data = request.get_json()
    CONTENT_PATH = os.path.normpath(data['content_path'])
    VIS_METHOD = data['vis_method']
    SETTING = data["setting"]

    budget = int(data["budget"])
    strategy = data["strategy"]
    acc_idxs = data["accIndices"]
    rej_idxs = data["rejIndices"]
    user_name = data["username"]
    isRecommend = data["isRecommend"]

    sys.path.append(CONTENT_PATH)
    try:
        context = initialize_backend(CONTENT_PATH, VIS_METHOD, SETTING)

        context.save_acc_and_rej(acc_idxs, rej_idxs, user_name)
        indices, scores, labels = context.suggest_abnormal(strategy, np.array(acc_idxs).astype(np.int64), np.array(rej_idxs).astype(np.int64), budget)
        # Budget of 1 for the normal suggestion; its labels are discarded.
        clean_list,_ = context.suggest_normal(strategy, np.array(acc_idxs).astype(np.int64), np.array(rej_idxs).astype(np.int64), 1)

        # Highest score first.
        sort_i = np.argsort(-scores)
        indices = indices[sort_i]
        labels = labels[sort_i]
        scores = scores[sort_i]
    finally:
        # Always undo the sys.path change, even when the backend raises.
        sys.path.remove(CONTENT_PATH)

    if not isRecommend:
        add_line(API_result_path,['Feedback', user_name])
    else:
        add_line(API_result_path,['Recommend', user_name])
    return make_response(jsonify({"selectedPoints": indices.tolist(), "scores": scores.tolist(), "suggestLabels":labels.tolist(),"cleanList":clean_list.tolist()}), 200)
|
| 432 |
+
|
| 433 |
+
@app.route('/al_train', methods=["POST"])
@cross_origin()
def al_train():
    """Run one active-learning round and return the projection of the new iteration.

    Reads from the JSON request body: ``content_path``, ``vis_method``,
    ``setting``, ``accIndices``, ``rejIndices``, ``iteration`` and
    ``username``.  After training, a node for the new iteration (with the
    requested iteration as its ``pid``) is appended to
    ``iteration_structure.json`` under the content path.

    Returns:
        A 200 JSON response carrying the values produced by
        ``update_epoch_projection`` for the new iteration.
    """
    data = request.get_json()
    CONTENT_PATH = os.path.normpath(data['content_path'])
    VIS_METHOD = data['vis_method']
    SETTING = data["setting"]

    acc_idxs = data["accIndices"]
    rej_idxs = data["rejIndices"]
    iteration = data["iteration"]
    user_name = data["username"]

    sys.path.append(CONTENT_PATH)
    try:
        # default setting al_train is light version, we only save the last epoch
        context = initialize_backend(CONTENT_PATH, VIS_METHOD, SETTING)
        context.save_acc_and_rej(iteration, acc_idxs, rej_idxs, user_name)
        context.al_train(iteration, acc_idxs)
        NEW_ITERATION = context.get_max_iter()
        context.vis_train(NEW_ITERATION, iteration)

        # update iteration projection
        # The trailing ``*_`` tolerates helpers that return extra values
        # after ``properties`` (a fixed 13-name unpack would raise).
        # NOTE(review): confirm the imported update_epoch_projection accepts
        # this three-argument call.
        (embedding_2d, grid, decision_view, label_name_dict, label_color_list,
         label_list, _, training_data_index, testing_data_index, eval_new,
         prediction_list, selected_points, properties,
         *_) = update_epoch_projection(context, NEW_ITERATION, dict())

        # rewrite json =========
        res_json_path = os.path.join(CONTENT_PATH, "iteration_structure.json")
        with open(res_json_path, encoding='utf8') as fp:
            json_data = json.load(fp)

        json_data.append({'value': NEW_ITERATION, 'name': 'iteration', 'pid': iteration})
        print('json_data',json_data)
        # The ``with`` block closes the file; no explicit close() is needed.
        with open(res_json_path, 'w') as r:
            json.dump(json_data, r)
        # rewrite json =========
    finally:
        # Always undo the sys.path change, even when training fails.
        sys.path.remove(CONTENT_PATH)

    # The former ``del config`` was dropped: no ``config`` is ever assigned in
    # this function, so the statement raised UnboundLocalError.
    gc.collect()

    add_line(API_result_path,['al_train', user_name])
    return make_response(jsonify({'result': embedding_2d, 'grid_index': grid, 'grid_color': 'data:image/png;base64,' + decision_view,
                                  'label_name_dict': label_name_dict,
                                  'label_color_list': label_color_list, 'label_list': label_list,
                                  'maximum_iteration': NEW_ITERATION, 'training_data': training_data_index,
                                  'testing_data': testing_data_index, 'evaluation': eval_new,
                                  'prediction_list': prediction_list,
                                  "selectedPoints":selected_points.tolist(),
                                  "properties":properties.tolist()}), 200)
|
| 485 |
+
|
| 486 |
+
def clear_cache(con_paths):
    """Drop cached active-learning iterations above 2 for every content path.

    For each value of ``con_paths`` the ``Model`` sub-directory is scanned.
    Entries whose name contains ``Iteration_`` and whose number is greater
    than 2 are removed.  If at least one such entry was seen,
    ``iteration_structure.json`` is rewritten keeping only the items whose
    ``value`` is below 3.
    """
    marker = "Iteration_"
    for CONTENT_PATH in con_paths.values():
        model_dir = os.path.join(CONTENT_PATH, "Model")
        iteration_dirs = [entry for entry in os.listdir(model_dir) if marker in entry]

        for entry in iteration_dirs:
            if int(entry.replace(marker, "")) > 2:
                shutil.rmtree(os.path.join(model_dir, entry))

        # Nothing iteration-like was found: leave the structure file alone.
        if not iteration_dirs:
            continue

        structure_file = os.path.join(CONTENT_PATH, "iteration_structure.json")
        with open(structure_file, "r") as f:
            nodes = json.load(f)
        kept = [node for node in nodes if node["value"] < 3]
        with open(structure_file, "w") as f:
            json.dump(kept, f)
    print("Successfully remove cache data!")
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
@app.route('/login', methods=["POST"])
@cross_origin()
def login():
    """Echo the requested ``content_path`` back as both content paths.

    No credential check is performed here; the ``content_path`` field of
    the JSON body is returned under ``normal_content_path`` and
    ``unormaly_content_path``.
    """
    payload = request.get_json()
    requested_path = payload["content_path"]
    body = {"normal_content_path": requested_path, "unormaly_content_path": requested_path}
    return make_response(jsonify(body), 200)
|
| 522 |
+
|
| 523 |
+
@app.route('/boundingbox_record', methods=["POST"])
@cross_origin()
def record_bb():
    """Log a ``boundingbox`` event for the requesting user and reply with an empty object."""
    payload = request.get_json()
    add_line(API_result_path, ['boundingbox', payload['username']])
    return make_response(jsonify({}), 200)
|
| 530 |
+
|
| 531 |
+
@app.route('/all_result_list', methods=["POST"])
@cross_origin()
def get_res():
    """Return embeddings, grids and background images for every configured epoch.

    Reads from the JSON request body: ``content_path``, ``vis_method``,
    ``setting`` and ``username``.  Epochs run from ``EPOCH_START`` to
    ``EPOCH_END`` in steps of ``EPOCH_PERIOD`` (taken from the strategy
    config); results are keyed by the 1-based position of the epoch,
    rendered as a string.

    Returns:
        A 200 JSON response with ``results``, ``bgimgList`` and ``grid``.
    """
    data = request.get_json()
    CONTENT_PATH = os.path.normpath(data['content_path'])
    VIS_METHOD = data['vis_method']
    SETTING = data["setting"]
    username = data["username"]

    predicates = dict() # placeholder

    results = dict()
    imglist = dict()
    gridlist = dict()

    sys.path.append(CONTENT_PATH)
    try:
        # One backend serves every epoch.  The former per-epoch
        # ``initialize_backend(CONTENT_PATH)`` call omitted the visualization
        # method and setting arguments.
        context = initialize_backend(CONTENT_PATH, VIS_METHOD, SETTING)

        EPOCH_START = context.strategy.config["EPOCH_START"]
        EPOCH_PERIOD = context.strategy.config["EPOCH_PERIOD"]
        EPOCH_END = context.strategy.config["EPOCH_END"]

        epoch_num = (EPOCH_END - EPOCH_START) // EPOCH_PERIOD + 1

        for i in range(1, epoch_num + 1):
            EPOCH = (i - 1) * EPOCH_PERIOD + EPOCH_START
            key = str(i)

            # detect whether we have query before
            checkpoint_path = context.strategy.data_provider.checkpoint_path(EPOCH)
            bgimg_path = os.path.join(checkpoint_path, "bgimg.png")
            embedding_path = os.path.join(checkpoint_path, "embedding.npy")
            grid_path = os.path.join(checkpoint_path, "grid.pkl")
            if os.path.exists(bgimg_path) and os.path.exists(embedding_path) and os.path.exists(grid_path):
                # Load exactly the files whose existence was just checked.
                results[key] = np.load(embedding_path).tolist()
                # NOTE(review): pickle.load on a path derived from the
                # request's content_path; only safe for trusted directories.
                with open(grid_path, "rb") as f:
                    grid = pickle.load(f)
                gridlist[key] = grid
            else:
                # Only the first two returned values are needed; ``*_``
                # absorbs however many follow.
                # NOTE(review): confirm the imported update_epoch_projection
                # accepts this three-argument call.
                embedding_2d, grid, *_ = update_epoch_projection(context, EPOCH, predicates)
                results[key] = embedding_2d
                gridlist[key] = grid

            # read background img
            with open(bgimg_path, 'rb') as img_f:
                img_stream = img_f.read()
            img_stream = base64.b64encode(img_stream).decode()
            imglist[key] = 'data:image/png;base64,' + img_stream
    finally:
        # Always undo the sys.path change, even when an epoch fails.
        sys.path.remove(CONTENT_PATH)

    # The former ``del config`` was dropped: no ``config`` is ever assigned in
    # this function, so the statement raised UnboundLocalError.
    gc.collect()

    add_line(API_result_path,['animation', username])
    return make_response(jsonify({"results":results,"bgimgList":imglist, "grid": gridlist}), 200)
|
| 590 |
+
|
| 591 |
+
@app.route('/get_itertaion_structure', methods=["POST", "GET"])
@cross_origin()
def get_tree():
    """Return the parsed ``iteration_structure.json`` found under the ``path`` query argument."""
    structure_file = os.path.join(request.args.get("path"), "iteration_structure.json")
    with open(structure_file, encoding='utf8') as handle:
        structure = json.load(handle)
    return make_response(jsonify({"structure": structure}), 200)
|
| 600 |
+
|
| 601 |
+
def check_port_inuse(port, host):
    """Report whether something accepts TCP connections on ``host:port``.

    Args:
        port (int): TCP port to probe.
        host (str): host name or address to probe.

    Returns:
        bool: True if a connection succeeds within one second, False if the
        attempt fails with a socket error (refused, timed out, ...).
    """
    # Imported here because the module only imports ``socket`` inside its
    # ``__main__`` guard, so the name is otherwise missing on plain import.
    import socket

    try:
        # The context manager closes the socket on every path and avoids
        # touching an unbound name if socket creation itself fails.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.settimeout(1)
            s.connect((host, port))
        return True
    except socket.error:
        return False
|
| 612 |
+
|
| 613 |
+
if __name__ == "__main__":
    import socket

    # Serve on this machine's resolved address, using the first port from
    # 5000 upward that nothing is already listening on.
    hostname = socket.gethostname()
    ip_address = socket.gethostbyname(hostname)
    port = 5000
    while check_port_inuse(port, ip_address):
        port += 1
    app.run(host=ip_address, port=int(port))
|
saved_models/codesearch_simp/server/utils.py
ADDED
|
@@ -0,0 +1,475 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import time
|
| 4 |
+
import csv
|
| 5 |
+
import numpy as np
|
| 6 |
+
import sys
|
| 7 |
+
import pickle
|
| 8 |
+
import base64
|
| 9 |
+
from scipy.special import softmax
|
| 10 |
+
vis_path = ".."
|
| 11 |
+
sys.path.append(vis_path)
|
| 12 |
+
from context import VisContext, ActiveLearningContext, AnormalyContext
|
| 13 |
+
from strategy import DeepDebugger, TimeVis, tfDeepVisualInsight, DVIAL, tfDVIDenseAL, TimeVisDenseAL, TrustActiveLearningDVI,DeepVisualInsight, TrustProxyDVI
|
| 14 |
+
from singleVis.eval.evaluate import evaluate_isAlign, evaluate_isNearestNeighbour, evaluate_isAlign_single, evaluate_isNearestNeighbour_single
|
| 15 |
+
"""Interface align"""
|
| 16 |
+
|
| 17 |
+
def initialize_strategy(CONTENT_PATH, VIS_METHOD, SETTING, dense=False):
    """Build the visualization strategy for a content directory.

    The ``VIS_METHOD`` section of ``CONTENT_PATH/config.json`` is passed to
    the strategy class selected by ``SETTING`` and ``VIS_METHOD``.

    Args:
        CONTENT_PATH (str): directory containing ``config.json``.
        VIS_METHOD (str): key into the config and name of the strategy.
        SETTING (str): ``"normal"``, ``"abnormal"`` or ``"active learning"``.
        dense (bool): for active learning, pick the dense variant of the
            strategy; otherwise ``DVIAL`` is used whatever ``VIS_METHOD`` is.

    Raises:
        NotImplementedError: for an unsupported setting/method combination.
    """
    # initialize strategy (visualization method)
    config_file = os.path.join(CONTENT_PATH, "config.json")
    with open(config_file, "r") as f:
        config = json.load(f)[VIS_METHOD]

    # todo support timevis, currently only support dvi
    if SETTING in ("normal", "abnormal"):
        if VIS_METHOD == "TrustVisActiveLearning":
            return TrustActiveLearningDVI(CONTENT_PATH, config)
        if VIS_METHOD == "TrustVisProxy":
            return TrustProxyDVI(CONTENT_PATH, config)
        if VIS_METHOD == "DVI":
            return DeepVisualInsight(CONTENT_PATH, config)
        if VIS_METHOD == "TimeVis":
            return TimeVis(CONTENT_PATH, config)
        if VIS_METHOD == "DeepDebugger":
            return DeepDebugger(CONTENT_PATH, config)
        raise NotImplementedError

    if SETTING != "active learning":
        raise NotImplementedError

    if not dense:
        return DVIAL(CONTENT_PATH, config)
    if VIS_METHOD == "DVI":
        return tfDVIDenseAL(CONTENT_PATH, config)
    if VIS_METHOD == "TimeVis":
        return TimeVisDenseAL(CONTENT_PATH, config)
    raise NotImplementedError
|
| 55 |
+
|
| 56 |
+
# todo remove unnecessary parts
|
| 57 |
+
def initialize_context(strategy, setting):
    """Wrap ``strategy`` in the context class that matches ``setting``.

    ``"normal"`` gives a ``VisContext``, ``"active learning"`` an
    ``ActiveLearningContext`` and ``"abnormal"`` an ``AnormalyContext``.

    Raises:
        NotImplementedError: for any other setting.
    """
    if setting == "normal":
        return VisContext(strategy)
    if setting == "active learning":
        return ActiveLearningContext(strategy)
    if setting == "abnormal":
        return AnormalyContext(strategy)
    raise NotImplementedError
|
| 67 |
+
|
| 68 |
+
def initialize_backend(CONTENT_PATH, VIS_METHOD, SETTING, dense=False):
    """Create the backend context used to serve visualization requests.

    Args:
        CONTENT_PATH (str): the directory of the training process.
        VIS_METHOD (str): visualization strategy, e.g. "DVI", "TimeVis",
            "DeepDebugger".
        SETTING (str): context type: "normal", "active learning" or
            "abnormal".
        dense (bool): forwarded to ``initialize_strategy``.

    Raises:
        NotImplementedError: when the strategy or context cannot be
            resolved for the given arguments.

    Returns:
        A context wrapping the selected strategy.
    """
    return initialize_context(
        strategy=initialize_strategy(CONTENT_PATH, VIS_METHOD, SETTING, dense),
        setting=SETTING,
    )
|
| 87 |
+
|
| 88 |
+
def get_train_test_data(context, EPOCH):
    """Return the train and test representations of ``EPOCH`` stacked along axis 0, train rows first."""
    parts = (
        context.train_representation_data(EPOCH),
        context.test_representation_data(EPOCH),
    )
    return np.concatenate(parts, axis=0)
|
| 94 |
+
def get_train_test_label(context, EPOCH):
    """Return the train labels followed by the test labels of ``EPOCH``, cast to ``int``."""
    parts = (context.train_labels(EPOCH), context.test_labels(EPOCH))
    stacked = np.concatenate(parts, axis=0)
    return stacked.astype(int)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# def get_strategy_by_setting(CONTENT_PATH, config, VIS_METHOD, SETTING, dense=False):
|
| 102 |
+
# if SETTING == "normal" or SETTING == "abnormal":
|
| 103 |
+
# if VIS_METHOD == "DVI":
|
| 104 |
+
# strategy = tfDeepVisualInsight(CONTENT_PATH, config)
|
| 105 |
+
# elif VIS_METHOD == "TimeVis":
|
| 106 |
+
# strategy = TimeVis(CONTENT_PATH, config)
|
| 107 |
+
# elif VIS_METHOD == "DeepDebugger":
|
| 108 |
+
# strategy = DeepDebugger(CONTENT_PATH, config)
|
| 109 |
+
# else:
|
| 110 |
+
# raise NotImplementedError
|
| 111 |
+
# elif SETTING == "active learning":
|
| 112 |
+
# if dense:
|
| 113 |
+
# if VIS_METHOD == "DVI":
|
| 114 |
+
# strategy = tfDVIDenseAL(CONTENT_PATH, config)
|
| 115 |
+
# elif VIS_METHOD == "TimeVis":
|
| 116 |
+
# strategy = TimeVisDenseAL(CONTENT_PATH, config)
|
| 117 |
+
# else:
|
| 118 |
+
# raise NotImplementedError
|
| 119 |
+
# else:
|
| 120 |
+
# strategy = DVIAL(CONTENT_PATH, config)
|
| 121 |
+
|
| 122 |
+
# else:
|
| 123 |
+
# raise NotImplementedError
|
| 124 |
+
# return strategy
|
| 125 |
+
|
| 126 |
+
# def update_embeddings(new_strategy, context, EPOCH, all_data, is_focus):
|
| 127 |
+
|
| 128 |
+
# embedding_path = os.path.join(context.strategy.data_provider.checkpoint_path(EPOCH), "embedding.npy")
|
| 129 |
+
# if os.path.exists(embedding_path):
|
| 130 |
+
# original_embedding_2d = np.load(embedding_path)
|
| 131 |
+
|
| 132 |
+
# dd = TimeVis(context.contentpath,new_conf)
|
| 133 |
+
# dd._preprocess()
|
| 134 |
+
# dd._train()
|
| 135 |
+
# embedding_2d = dd.projector.batch_project(EPOCH, all_data)
|
| 136 |
+
# return embedding_2d
|
| 137 |
+
|
| 138 |
+
# def find_and_add_nearest_neighbors(data, subset_indices, num_neighbors=10):
|
| 139 |
+
# dimension = len(data[0]) # Assuming all data points have the same dimension
|
| 140 |
+
# t = AnnoyIndex(dimension, 'euclidean') # 'euclidean' distance metric; you can use 'angular' as well
|
| 141 |
+
|
| 142 |
+
# # Build the index with the entire data
|
| 143 |
+
# for i, vector in enumerate(data):
|
| 144 |
+
# t.add_item(i, vector)
|
| 145 |
+
|
| 146 |
+
# t.build(10) # Number of trees. More trees gives higher precision.
|
| 147 |
+
|
| 148 |
+
# # Use a set for faster look-up and ensuring no duplicates
|
| 149 |
+
# subset_indices_set = set(subset_indices)
|
| 150 |
+
|
| 151 |
+
# for idx in subset_indices:
|
| 152 |
+
# nearest_neighbors = t.get_nns_by_item(idx, num_neighbors)
|
| 153 |
+
# # Use set union operation to merge indices without duplicates
|
| 154 |
+
# subset_indices_set = subset_indices_set.union(nearest_neighbors)
|
| 155 |
+
# # Convert set back to list
|
| 156 |
+
# return list(subset_indices_set)
|
| 157 |
+
|
| 158 |
+
# def get_expanded_subset(context, EPOCH, subset_indices):
|
| 159 |
+
# all_data = get_train_test_data(context, EPOCH)
|
| 160 |
+
# expanded_subset = find_and_add_nearest_neighbors(all_data, subset_indices)
|
| 161 |
+
# return expanded_subset
|
| 162 |
+
|
| 163 |
+
# def update_vis_error_points(new_strategy, context, EPOCH, is_focus):
|
| 164 |
+
# embedding_path = os.path.join(context.strategy.data_provider.checkpoint_path(EPOCH), "embedding.npy")
|
| 165 |
+
# if os.path.exists(embedding_path):
|
| 166 |
+
# original_embedding_2d = np.load(embedding_path)
|
| 167 |
+
# new_strategy._train()
|
| 168 |
+
# new_strategy.projector.batch_project
|
| 169 |
+
# embedding_2d = dd.projector.batch_project(EPOCH, all_data)
|
| 170 |
+
|
| 171 |
+
# update_embeddings(strategy, context, EPOCH, True)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def update_epoch_projection(context, EPOCH, predicates, isContraVis=None):
    """Collect everything needed to render one epoch/iteration.

    The 2-D embedding is read from ``embedding.npy`` in the epoch's
    checkpoint directory, or projected and cached there if missing.  The
    background image and grid come from ``bgimg.png`` / ``scale.npy`` when
    both exist, otherwise from ``context.strategy.vis.get_background``.

    Args:
        context: backend context exposing ``strategy`` plus the data,
            label, filter and index accessors used below.
        EPOCH: epoch (or iteration) to visualize.
        predicates (dict): filters on the selected points.  ``"label"`` and
            ``"type"`` go through ``context.filter_label`` /
            ``context.filter_type``; any other key applies no restriction.
        isContraVis: when equal to the string ``'false'``, the indices whose
            predicted class changes after a project/inverse round trip are
            computed.  Any other value, including the default, skips that
            step and returns an empty list for it.

    Returns:
        tuple: ``(embedding_2d, grid, b_fig, label_name_dict,
        label_color_list, label_list, max_iter, training_data_index,
        testing_data_index, eval_new, prediction_list, selected_points,
        properties, highlightedPointIndices)``.
    """
    # TODO consider active learning setting
    train_data = context.train_representation_data(EPOCH)
    test_data = context.test_representation_data(EPOCH)
    all_data = np.concatenate((train_data, test_data), axis=0)
    print(len(all_data))

    # Only the training labels are used for colours and names.
    # NOTE(review): the embedding covers train+test data; confirm the
    # shorter label lists are what the frontend expects.
    labels = context.train_labels(EPOCH)

    provider = context.strategy.data_provider
    config = context.strategy.config
    checkpoint_dir = provider.checkpoint_path(EPOCH)

    embedding_path = os.path.join(checkpoint_dir, "embedding.npy")
    if os.path.exists(embedding_path):
        embedding_2d = np.load(embedding_path)
    else:
        embedding_2d = context.strategy.projector.batch_project(EPOCH, all_data)
        np.save(embedding_path, embedding_2d)

    training_data_number = config["TRAINING"]["train_num"]
    testing_data_number = config["TRAINING"]["test_num"]
    training_data_index = list(range(training_data_number))
    testing_data_index = list(range(training_data_number, training_data_number + testing_data_number))

    # return the image of background
    # read cache if exists
    bgimg_path = os.path.join(checkpoint_dir, "bgimg.png")
    scale_path = os.path.join(checkpoint_dir, "scale.npy")
    if os.path.exists(bgimg_path) and os.path.exists(scale_path):
        with open(bgimg_path, 'rb') as img_f:
            b_fig = base64.b64encode(img_f.read()).decode()
        # NOTE(review): grid is a numpy array on this branch but a list of
        # floats on the other; callers that serialize it should confirm.
        grid = np.load(scale_path)
    else:
        x_min, y_min, x_max, y_max, b_fig = context.strategy.vis.get_background(EPOCH, config["VISUALIZATION"]["RESOLUTION"])
        # formatting
        grid = [float(v) for v in (x_min, y_min, x_max, y_max)]
        b_fig = str(b_fig, encoding='utf-8')
        # The embedding is already on disk at this point (loaded from or
        # saved to embedding_path above), so it is not written again.

    # TODO fix its structure
    eval_new = {"train_acc": 0, "test_acc": 0}
    file_name = config["VISUALIZATION"]["EVALUATION_NAME"]
    save_eval_dir = os.path.join(provider.model_path, file_name + ".json")
    if os.path.exists(save_eval_dir):
        evaluation = context.strategy.evaluator.get_eval(file_name=file_name)
        eval_new["train_acc"] = evaluation["train_acc"][str(EPOCH)]
        eval_new["test_acc"] = evaluation["test_acc"][str(EPOCH)]

    color = (context.strategy.vis.get_standard_classes_color() * 255).astype(int)

    CLASSES = np.array(config["CLASSES"])
    label_color_list = color[labels].tolist()
    label_list = CLASSES[labels].tolist()
    label_name_dict = dict(enumerate(CLASSES))

    # Predictions are not computed here; one "0" placeholder per training sample.
    prediction_list = ["0"] * len(train_data)

    EPOCH_START = config["EPOCH_START"]
    EPOCH_PERIOD = config["EPOCH_PERIOD"]
    EPOCH_END = config["EPOCH_END"]
    max_iter = (EPOCH_END - EPOCH_START) // EPOCH_PERIOD + 1

    # Start from every sample and narrow down with each predicate.
    total_number = training_data_number + testing_data_number
    selected_points = np.arange(total_number)
    for key, value in predicates.items():
        if key == "label":
            tmp = np.array(context.filter_label(value))
        elif key == "type":
            tmp = np.array(context.filter_type(value, int(EPOCH)))
        else:
            tmp = np.arange(total_number)
        selected_points = np.intersect1d(selected_points, tmp)

    # 0 for training samples returned by get_epoch_index, 1 for the other
    # training samples, 2 for test samples.
    properties = np.concatenate((np.zeros(training_data_number, dtype=np.int16), 2*np.ones(testing_data_number, dtype=np.int16)), axis=0)
    lb = context.get_epoch_index(EPOCH)
    ulb = np.setdiff1d(training_data_index, lb)
    properties[ulb] = 1

    highlightedPointIndices = []
    if isContraVis == 'false':
        # Highlight points whose predicted class differs between the
        # original representation and its project -> inverse reconstruction.
        high_pred = provider.get_pred(EPOCH, all_data).argmax(1)
        inv_high_dim_data = context.strategy.projector.batch_inverse(EPOCH, embedding_2d)
        inv_high_pred = provider.get_pred(EPOCH, inv_high_dim_data).argmax(1)
        highlightedPointIndices = np.where(high_pred != inv_high_pred)[0]

    print("EMBEDDINGLEN", len(embedding_2d))
    return (embedding_2d.tolist(), grid, b_fig, label_name_dict, label_color_list, label_list, max_iter,
            training_data_index, testing_data_index, eval_new, prediction_list, selected_points, properties,
            highlightedPointIndices)
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def getContraVisChangeIndices(context, iterationLeft, iterationRight, method):
    """Compare the 2-D embeddings of two iterations with the chosen method.

    Each embedding is read from ``embedding.npy`` in its checkpoint
    directory, or projected and cached there when the file is missing.

    Args:
        context: backend context giving access to data and the strategy.
        iterationLeft: first iteration to compare.
        iterationRight: second iteration to compare.
        method (str): ``"align"``, ``"nearest neighbour"`` or ``"both"``
            (the intersection of the two results).

    Returns:
        The indices reported by the evaluation function(s); an empty list,
        after printing ``wrong method``, for any other ``method``.
    """
    projections = []
    for iteration in (iterationLeft, iterationRight):
        representations = np.concatenate(
            (context.train_representation_data(iteration), context.test_representation_data(iteration)),
            axis=0,
        )
        cache_file = os.path.join(context.strategy.data_provider.checkpoint_path(iteration), "embedding.npy")
        if os.path.exists(cache_file):
            projected = np.load(cache_file)
        else:
            projected = context.strategy.projector.batch_project(iteration, representations)
            np.save(cache_file, projected)
        projections.append(projected)
    left_2d, right_2d = projections

    if method == "align":
        return evaluate_isAlign(left_2d, right_2d)
    if method == "nearest neighbour":
        return evaluate_isNearestNeighbour(left_2d, right_2d)
    if method == "both":
        aligned = evaluate_isAlign(left_2d, right_2d)
        nearest = evaluate_isNearestNeighbour(left_2d, right_2d)
        return list(set(aligned).intersection(nearest))

    print("wrong method")
    return []
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
def getContraVisChangeIndicesSingle(context, iterationLeft, iterationRight, method, left_selected, right_selected):
    """Compare two iterations' embeddings around the selected points.

    Each embedding is read from ``embedding.npy`` in its checkpoint
    directory, or projected and cached there when the file is missing.

    Args:
        context: backend context giving access to data and the strategy.
        iterationLeft: first iteration to compare.
        iterationRight: second iteration to compare.
        method (str): ``"align"`` fills the first two results,
            ``"nearest neighbour"`` fills the last four; any other value
            leaves all six empty.
        left_selected: selection passed through to the evaluation function.
        right_selected: selection passed through to the evaluation function.

    Returns:
        tuple: ``(left, right, left_left, left_right, right_left,
        right_right)`` index collections.
    """
    projections = []
    for iteration in (iterationLeft, iterationRight):
        representations = np.concatenate(
            (context.train_representation_data(iteration), context.test_representation_data(iteration)),
            axis=0,
        )
        cache_file = os.path.join(context.strategy.data_provider.checkpoint_path(iteration), "embedding.npy")
        if os.path.exists(cache_file):
            projected = np.load(cache_file)
        else:
            projected = context.strategy.projector.batch_project(iteration, representations)
            np.save(cache_file, projected)
        projections.append(projected)
    left_2d, right_2d = projections

    align_left, align_right = [], []
    nn_left_left, nn_left_right, nn_right_left, nn_right_right = [], [], [], []

    if method == "align":
        align_left, align_right = evaluate_isAlign_single(left_2d, right_2d, left_selected, right_selected)
    elif method == "nearest neighbour":
        (nn_left_left, nn_left_right,
         nn_right_left, nn_right_right) = evaluate_isNearestNeighbour_single(left_2d, right_2d, left_selected, right_selected)

    return align_left, align_right, nn_left_left, nn_left_right, nn_right_left, nn_right_right
|
| 377 |
+
|
| 378 |
+
def getCriticalChangeIndices(context, curr_iteration, last_iteration):
    """Return the sample indices whose predicted class differs between two iterations.

    As a side effect, the 2-D embedding of each iteration is projected and
    cached as ``embedding.npy`` in its checkpoint directory when that file
    does not exist yet.  The embeddings themselves are not used for the
    comparison.

    Args:
        context: backend context giving access to data and the strategy.
        curr_iteration: iteration whose predictions are compared.
        last_iteration: iteration to compare against.

    Returns:
        numpy.ndarray: indices where the argmax predictions differ.
    """
    provider = context.strategy.data_provider

    representations = []
    for iteration in (curr_iteration, last_iteration):
        stacked = np.concatenate(
            (context.train_representation_data(iteration), context.test_representation_data(iteration)),
            axis=0,
        )
        cache_file = os.path.join(provider.checkpoint_path(iteration), "embedding.npy")
        if os.path.exists(cache_file):
            np.load(cache_file)
        else:
            np.save(cache_file, context.strategy.projector.batch_project(iteration, stacked))
        representations.append(stacked)
    curr_data, last_data = representations

    curr_classes = provider.get_pred(curr_iteration, curr_data).argmax(1)
    last_classes = provider.get_pred(last_iteration, last_data).argmax(1)

    return np.where(curr_classes != last_classes)[0]
|
| 413 |
+
|
| 414 |
+
def getConfChangeIndices(context, curr_iteration, last_iteration, confChangeInput):
    """Return indices whose class is unchanged but whose confidence moved.

    Predictions for both iterations are turned into probabilities with
    ``softmax(..., axis=1)``. Among the samples whose arg-max class is the
    same in both iterations, those whose top-class probability differs by
    strictly more than ``confChangeInput`` (absolute value) are returned.

    Side effects: creates ``embedding.npy`` in each iteration's checkpoint
    directory when it is missing, and prints intermediate arrays to stdout.

    Args:
        context: object exposing ``train_representation_data``,
            ``test_representation_data`` and ``strategy`` (with
            ``data_provider`` and ``projector``).
        curr_iteration: iteration whose confidences are compared.
        last_iteration: iteration compared against.
        confChangeInput: threshold on the absolute confidence difference.

    Returns:
        1-D array of indices into the concatenated (train + test) data.
    """
    provider = context.strategy.data_provider
    representations = []

    for iteration in (curr_iteration, last_iteration):
        samples = np.concatenate(
            (
                context.train_representation_data(iteration),
                context.test_representation_data(iteration),
            ),
            axis=0,
        )
        cache_file = os.path.join(provider.checkpoint_path(iteration), "embedding.npy")
        if not os.path.exists(cache_file):
            # Populate the on-disk cache for later consumers.
            projected = context.strategy.projector.batch_project(iteration, samples)
            np.save(cache_file, projected)
        else:
            # The cached projection is read (as before) but not needed here.
            np.load(cache_file)
        representations.append(samples)

    curr_samples, last_samples = representations
    curr_scores = provider.get_pred(curr_iteration, curr_samples)
    last_scores = provider.get_pred(last_iteration, last_samples)

    curr_conf = softmax(curr_scores, axis=1)
    last_conf = softmax(last_scores, axis=1)

    # Class with the highest probability for every sample.
    curr_cls = curr_conf.argmax(axis=1)
    last_cls = last_conf.argmax(axis=1)

    same_pred_indices = np.where(curr_cls == last_cls)[0]
    print("same")
    print(same_pred_indices)

    # Absolute change of the top-class probability between the iterations.
    curr_top = curr_conf[np.arange(len(curr_conf)), curr_cls]
    last_top = last_conf[np.arange(len(last_conf)), last_cls]
    conf_diff = np.abs(curr_top - last_top)
    print("conf")
    print(conf_diff)

    exceeds = conf_diff[same_pred_indices] > confChangeInput
    significant_conf_change_indices = same_pred_indices[exceeds]
    print("siginificant")
    print(significant_conf_change_indices)

    return significant_conf_change_indices
|
| 462 |
+
|
| 463 |
+
def add_line(path, data_row):
    """Append ``data_row`` plus a timestamp as one row of the CSV at ``path``.

    Args:
        path: CSV file to append to; it is created when it does not exist.
        data_row: list of leading fields, e.g. ``[API_name, username]``.
            The current local time, formatted ``%Y-%m-%d-%H:%M:%S``, is
            appended to this list in place before the row is written, so
            the caller's list grows by one element.
    """
    now_time = time.strftime('%Y-%m-%d-%H:%M:%S', time.localtime())
    data_row.append(now_time)
    # newline="" hands line-ending control to the csv module; without it,
    # platforms that translate "\n" end up with blank rows between records.
    # Plain append mode is enough because the file is never read here.
    with open(path, "a", newline="") as f:
        csv.writer(f).writerow(data_row)
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
|
saved_models/codesearch_simp/simplify.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
# file_path = "/home/yiming/ContrastDebugger/EXP/codesearch_simp/Model/Epoch_1/train_data.npy"
|
| 5 |
+
|
| 6 |
+
# # Load the ndarray data
|
| 7 |
+
# data = np.load(file_path)
|
| 8 |
+
|
| 9 |
+
# print(len(data))
|
| 10 |
+
# # # Select the first 50000 records
|
| 11 |
+
# # selected_data = data[:50000]
|
| 12 |
+
|
| 13 |
+
# # # Save the selection back to the file
|
| 14 |
+
# # np.save(file_path, selected_data)
|
| 15 |
+
|
| 16 |
+
# idxs = [i for i in range(len(data))]
|
| 17 |
+
|
| 18 |
+
# idxs_path = "/home/yiming/ContrastDebugger/EXP/codesearch_simp/Model/Epoch_1/index.json"
|
| 19 |
+
# json_file = open(idxs_path, mode='w')
|
| 20 |
+
# json.dump(idxs, json_file, indent=4)
|
| 21 |
+
|
| 22 |
+
input_file = "/home/yiming/ContrastDebugger/EXP/codesearch_query_simp/Model/label_list.json"
output_file = "/home/yiming/ContrastDebugger/EXP/codesearch_query_simp/Model/label.txt"

# Load the complete label list from the JSON input file.
with open(input_file, "r") as source:
    data = json.load(source)

# Keep only the first 50000 entries.
selected_data = data[:50000]

# Write the kept entries to the output file, one entry per line.
with open(output_file, "w") as sink:
    sink.writelines(entry + "\n" for entry in selected_data)
|
saved_models/codesearch_simp/singleVis/SingleVisualizationModel.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class SingleVisualizationModel(nn.Module):
    """Fully connected autoencoder with equally wide hidden layers.

    The encoder maps ``input_dims`` features to ``output_dims`` and the
    decoder maps them back. Each network is a first ``Linear`` + ``ReLU``,
    then ``hidden_layer`` further ``Linear(units, units)`` + ``ReLU`` pairs,
    then a final ``Linear`` without activation.
    """

    def __init__(self, input_dims, output_dims, units, hidden_layer=3):
        super().__init__()

        self.input_dims = input_dims
        self.output_dims = output_dims
        self.units = units
        self.hidden_layer = hidden_layer
        self._init_autoencoder()

    def _mlp(self, in_features, out_features):
        """Build one Linear/ReLU stack from ``in_features`` to ``out_features``.

        Children of the returned ``nn.Sequential`` are named "0", "1", ...
        in construction order.
        """
        layers = [nn.Linear(in_features, self.units), nn.ReLU(True)]
        for _ in range(self.hidden_layer):
            layers.append(nn.Linear(self.units, self.units))
            layers.append(nn.ReLU(True))
        layers.append(nn.Linear(self.units, out_features))
        return nn.Sequential(*layers)

    # TODO find the best model architecture
    def _init_autoencoder(self):
        """Create ``self.encoder`` first, then ``self.decoder``."""
        self.encoder = self._mlp(self.input_dims, self.output_dims)
        self.decoder = self._mlp(self.output_dims, self.input_dims)

    def forward(self, edge_to, edge_from):
        """Encode and reconstruct both inputs.

        Returns:
            dict with ``"umap"`` -> (embedding of ``edge_to``, embedding of
            ``edge_from``) and ``"recon"`` -> the two reconstructions in the
            same order.
        """
        embedding_to = self.encoder(edge_to)
        embedding_from = self.encoder(edge_from)
        recon_to = self.decoder(embedding_to)
        recon_from = self.decoder(embedding_from)

        return {
            "umap": (embedding_to, embedding_from),
            "recon": (recon_to, recon_from),
        }
|
| 43 |
+
|
| 44 |
+
class VisModel(nn.Module):
    """Autoencoder whose layer widths are given explicitly by the caller."""

    def __init__(self, encoder_dims, decoder_dims):
        """Define your own visualization model by specifying its structure.

        Parameters
        ----------
        encoder_dims : list of int
            Neuron counts of the encoder. For example ``[100, 50, 2]``
            denotes two fully connected layers with shapes (100, 50) and
            (50, 2).
        decoder_dims : list of int
            Same convention as ``encoder_dims``, for the decoder.
        """
        super().__init__()
        assert len(encoder_dims) > 1
        assert len(decoder_dims) > 1
        self.encoder_dims = encoder_dims
        self.decoder_dims = decoder_dims
        self._init_autoencoder()

    @staticmethod
    def _build_stack(dims):
        """Chain ``Linear`` layers over consecutive ``dims`` entries.

        Every layer except the last is followed by an in-place ``ReLU``.
        Children of the returned ``nn.Sequential`` are named "0", "1", ...
        """
        layers = []
        for fan_in, fan_out in zip(dims[:-2], dims[1:-1]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU(True))
        layers.append(nn.Linear(dims[-2], dims[-1]))
        return nn.Sequential(*layers)

    def _init_autoencoder(self):
        """Create ``self.encoder`` first, then ``self.decoder``."""
        self.encoder = self._build_stack(self.encoder_dims)
        self.decoder = self._build_stack(self.decoder_dims)

    def forward(self, edge_to, edge_from):
        """Encode and reconstruct both inputs.

        Returns:
            dict with ``"umap"`` -> (embedding of ``edge_to``, embedding of
            ``edge_from``) and ``"recon"`` -> the two reconstructions in the
            same order.
        """
        embedding_to = self.encoder(edge_to)
        embedding_from = self.encoder(edge_from)
        recon_to = self.decoder(embedding_to)
        recon_from = self.decoder(embedding_from)

        return {
            "umap": (embedding_to, embedding_from),
            "recon": (recon_to, recon_from),
        }
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
'''
|
| 94 |
+
The visualization model definition class
|
| 95 |
+
'''
|
| 96 |
+
import tensorflow as tf
|
| 97 |
+
from tensorflow import keras
|
| 98 |
+
class tfModel(keras.Model):
    """Keras autoencoder trained with a custom three-part loss.

    ``train_step`` combines a reconstruction loss, a "umap" loss on the pair
    of embeddings and a regularization loss against a previous set of
    weights, each scaled by the matching entry of ``loss_weights``.
    """

    def __init__(self, optimizer, loss, loss_weights, encoder_dims, decoder_dims, batch_size, withoutB=True, attention=True, prev_trainable_variables=None):
        """Build the encoder/decoder and store the training configuration.

        Args:
            optimizer: optimizer whose ``apply_gradients`` is called in
                ``train_step``.
            loss: dict of callables; ``train_step`` reads the keys
                ``"reconstruction"`` (only when ``attention`` is true),
                ``"umap"`` and ``"regularization"``.
            loss_weights: dict with the same three keys holding the scalar
                weight of each loss term.
            encoder_dims: layer widths of the encoder, input width first.
            decoder_dims: layer widths of the decoder, input width first.
            batch_size: stored on the instance; not used inside this class.
            withoutB: stored on the instance; not used inside this class.
            attention: when true the caller-supplied reconstruction loss is
                used, otherwise a plain mean squared error.
            prev_trainable_variables: weights of the previous iteration used
                by the regularization loss, or ``None``.
        """
        super(tfModel, self).__init__()
        self._init_autoencoder(encoder_dims, decoder_dims)
        self.optimizer = optimizer  # optimizer
        self.withoutB = withoutB
        self.attention = attention

        self.loss = loss  # dict of loss callables, see docstring for the keys
        self.loss_weights = loss_weights  # weight of each of the 3 loss terms

        self.prev_trainable_variables = prev_trainable_variables  # weights for previous iteration
        self.batch_size = batch_size

    def _init_autoencoder(self, encoder_dims, decoder_dims):
        """Create ``self.encoder`` and ``self.decoder`` as Dense stacks.

        Hidden layers use ReLU; the last layer of each network is linear.
        """
        self.encoder = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=(encoder_dims[0],)),
            tf.keras.layers.Flatten(),
        ])
        for i in range(1, len(encoder_dims)-1, 1):
            self.encoder.add(tf.keras.layers.Dense(units=encoder_dims[i], activation="relu"))
        self.encoder.add(tf.keras.layers.Dense(units=encoder_dims[-1]),)

        self.decoder = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=(decoder_dims[0],)),
        ])
        for i in range(1, len(decoder_dims)-1, 1):
            self.decoder.add(tf.keras.layers.Dense(units=decoder_dims[i], activation="relu"))
        self.decoder.add(tf.keras.layers.Dense(units=decoder_dims[-1]))
        # summary() prints by itself and returns None, so it is not wrapped
        # in print() (that only added a stray "None" line).
        self.encoder.summary()
        self.decoder.summary()

    def train_step(self, x):
        """Run one optimization step on a batch.

        Args:
            x: batch whose first element unpacks into ``(to_x, from_x,
                to_alpha, from_alpha, n_rate, weight)``; every item is cast
                to float32.

        Returns:
            dict with the scalar values ``"loss"``, ``"umap"``,
            ``"reconstruction"`` and ``"regularization"``.
        """
        to_x, from_x, to_alpha, from_alpha, n_rate, weight = x[0]
        to_x = tf.cast(to_x, dtype=tf.float32)
        from_x = tf.cast(from_x, dtype=tf.float32)
        to_alpha = tf.cast(to_alpha, dtype=tf.float32)
        from_alpha = tf.cast(from_alpha, dtype=tf.float32)
        n_rate = tf.cast(n_rate, dtype=tf.float32)
        weight = tf.cast(weight, dtype=tf.float32)

        # Forward pass. Only one gradient() call follows, so a regular
        # (non-persistent) tape is enough and frees its resources right
        # after that call.
        with tf.GradientTape() as tape:

            # parametric embedding
            embedding_to = self.encoder(to_x)  # embedding of the "to" instance
            embedding_from = self.encoder(from_x)  # embedding of the "from" instance
            embedding_to_recon = self.decoder(embedding_to)  # reconstruction of the "to" instance
            embedding_from_recon = self.decoder(embedding_from)  # reconstruction of the "from" instance

            # concatenate both embeddings and the weight to prepare for umap loss
            embedding_to_from = tf.concat((embedding_to, embedding_from, weight),
                                          axis=1)
            # reconstruction loss
            if self.attention:
                reconstruct_loss = self.loss["reconstruction"](to_x, from_x, embedding_to_recon, embedding_from_recon, to_alpha, from_alpha)
            else:
                # Use a local MSE instead of overwriting the caller-supplied
                # entry in self.loss.
                mse = tf.keras.losses.MeanSquaredError()
                reconstruct_loss = mse(y_true=to_x, y_pred=embedding_to_recon)/2 + mse(y_true=from_x, y_pred=embedding_from_recon)/2

            # umap loss
            umap_loss = self.loss["umap"](None, embed_to_from=embedding_to_from)

            # compute alpha bar
            alpha_mean = tf.cast(tf.reduce_mean(tf.stop_gradient(n_rate)), dtype=tf.float32)
            # Regularization compares the current weights with those of the
            # previous iteration; without a previous iteration the current
            # weights (gradient stopped) are used as the reference.
            if self.prev_trainable_variables is None:
                prev_trainable_variables = [tf.stop_gradient(x) for x in self.trainable_variables]
            else:
                prev_trainable_variables = self.prev_trainable_variables
            regularization_loss = self.loss["regularization"](w_prev=prev_trainable_variables, w_current=self.trainable_variables, to_alpha=alpha_mean)

            # aggregate loss: weighted sum of the three terms
            loss = tf.add(tf.add(tf.math.multiply(tf.constant(self.loss_weights["reconstruction"]), reconstruct_loss),
                                 tf.math.multiply(tf.constant(self.loss_weights["umap"]), umap_loss)),
                          tf.math.multiply(tf.constant(self.loss_weights["regularization"]), regularization_loss))

        # Compute gradients
        trainable_vars = self.trainable_variables
        grads = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(grads, trainable_vars))

        return {"loss": loss, "umap": umap_loss, "reconstruction": reconstruct_loss,
                "regularization": regularization_loss}
|
| 187 |
+
|
| 188 |
+
|
saved_models/codesearch_simp/singleVis/__init__.py
ADDED
|
File without changes
|
saved_models/codesearch_simp/singleVis/__pycache__/SingleVisualizationModel.cpython-37.pyc
ADDED
|
Binary file (5.91 kB). View file
|
|
|
saved_models/codesearch_simp/singleVis/__pycache__/SingleVisualizationModel.cpython-39.pyc
ADDED
|
Binary file (5.93 kB). View file
|
|
|
saved_models/codesearch_simp/singleVis/__pycache__/__init__.cpython-37.pyc
ADDED
|
Binary file (111 Bytes). View file
|
|
|
saved_models/codesearch_simp/singleVis/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (152 Bytes). View file
|
|
|
saved_models/codesearch_simp/singleVis/__pycache__/active_sampling.cpython-37.pyc
ADDED
|
Binary file (860 Bytes). View file
|
|
|
saved_models/codesearch_simp/singleVis/__pycache__/backend.cpython-37.pyc
ADDED
|
Binary file (5.09 kB). View file
|
|
|
saved_models/codesearch_simp/singleVis/__pycache__/backend.cpython-39.pyc
ADDED
|
Binary file (5.12 kB). View file
|
|
|
saved_models/codesearch_simp/singleVis/__pycache__/custom_weighted_random_sampler.cpython-37.pyc
ADDED
|
Binary file (2.01 kB). View file
|
|
|
saved_models/codesearch_simp/singleVis/__pycache__/custom_weighted_random_sampler.cpython-39.pyc
ADDED
|
Binary file (1.12 kB). View file
|
|
|
saved_models/codesearch_simp/singleVis/__pycache__/data.cpython-37.pyc
ADDED
|
Binary file (35.7 kB). View file
|
|
|
saved_models/codesearch_simp/singleVis/__pycache__/data.cpython-39.pyc
ADDED
|
Binary file (32.5 kB). View file
|
|
|
saved_models/codesearch_simp/singleVis/__pycache__/edge_dataset.cpython-37.pyc
ADDED
|
Binary file (5.22 kB). View file
|
|
|
saved_models/codesearch_simp/singleVis/__pycache__/edge_dataset.cpython-39.pyc
ADDED
|
Binary file (5.15 kB). View file
|
|
|
saved_models/codesearch_simp/singleVis/__pycache__/intrinsic_dim.cpython-37.pyc
ADDED
|
Binary file (4.42 kB). View file
|
|
|
saved_models/codesearch_simp/singleVis/__pycache__/intrinsic_dim.cpython-39.pyc
ADDED
|
Binary file (4.44 kB). View file
|
|
|
saved_models/codesearch_simp/singleVis/__pycache__/jj1sk.cpython-37.pyc
ADDED
|
Binary file (16.5 kB). View file
|
|
|
saved_models/codesearch_simp/singleVis/__pycache__/jj51sk.cpython-37.pyc
ADDED
|
Binary file (16.5 kB). View file
|
|
|
saved_models/codesearch_simp/singleVis/__pycache__/jj551sk.cpython-37.pyc
ADDED
|
Binary file (16.5 kB). View file
|
|
|
saved_models/codesearch_simp/singleVis/__pycache__/jjsk.cpython-37.pyc
ADDED
|
Binary file (16.6 kB). View file
|
|
|
saved_models/codesearch_simp/singleVis/__pycache__/kcenter_greedy.cpython-37.pyc
ADDED
|
Binary file (5.29 kB). View file
|
|
|
saved_models/codesearch_simp/singleVis/__pycache__/kcenter_greedy.cpython-39.pyc
ADDED
|
Binary file (4.9 kB). View file
|
|
|
saved_models/codesearch_simp/singleVis/__pycache__/losses.cpython-37.pyc
ADDED
|
Binary file (12.1 kB). View file
|
|
|
saved_models/codesearch_simp/singleVis/__pycache__/losses.cpython-39.pyc
ADDED
|
Binary file (12.1 kB). View file
|
|
|
saved_models/codesearch_simp/singleVis/__pycache__/projector.cpython-37.pyc
ADDED
|
Binary file (17.6 kB). View file
|
|
|
saved_models/codesearch_simp/singleVis/__pycache__/sVis.cpython-37.pyc
ADDED
|
Binary file (16.6 kB). View file
|
|
|
saved_models/codesearch_simp/singleVis/__pycache__/s_Vis.cpython-37.pyc
ADDED
|
Binary file (16.6 kB). View file
|
|
|
saved_models/codesearch_simp/singleVis/__pycache__/segmenter.cpython-37.pyc
ADDED
|
Binary file (3.82 kB). View file
|
|
|
saved_models/codesearch_simp/singleVis/__pycache__/skeVis.cpython-37.pyc
ADDED
|
Binary file (16.4 kB). View file
|
|
|
saved_models/codesearch_simp/singleVis/__pycache__/skeleVis.cpython-37.pyc
ADDED
|
Binary file (16.6 kB). View file
|
|
|
saved_models/codesearch_simp/singleVis/__pycache__/skele_Vis.cpython-37.pyc
ADDED
|
Binary file (16.4 kB). View file
|
|
|
saved_models/codesearch_simp/singleVis/__pycache__/skele_viser.cpython-37.pyc
ADDED
|
Binary file (16.4 kB). View file
|
|
|
saved_models/codesearch_simp/singleVis/__pycache__/skeletonVis.cpython-37.pyc
ADDED
|
Binary file (16.3 kB). View file
|
|
|
saved_models/codesearch_simp/singleVis/__pycache__/skeletonViser.cpython-37.pyc
ADDED
|
Binary file (16.3 kB). View file
|
|
|
saved_models/codesearch_simp/singleVis/__pycache__/skeletonVisualizer.cpython-37.pyc
ADDED
|
Binary file (16.3 kB). View file
|
|
|