| | import pickle |
| | from pathlib import Path |
| | from face_feature.core import pathex |
| | import numpy as np |
| |
|
| | from face_feature.core.leras import nn |
| |
|
# Module-level shorthand for the `tf` attribute exposed by `nn`.
tf = nn.tf
| |
|
class Saveable():
    """Base class for named objects whose weights can be read, set, saved and loaded.

    Subclasses override ``get_weights`` to return their weight variables.
    ``save_weights`` and ``load_weights`` require ``name`` to be set and every
    weight's ``name`` to start with ``"<self.name>/"``; the remainder of the
    weight name is used as the key in the saved dictionary.
    """

    def __init__(self, name=None):
        """Store ``name``, the prefix expected on every weight name."""
        self.name = name

    def get_weights(self):
        """Return the list of weight variables; the base class has none."""
        return []

    def get_weights_np(self):
        """Evaluate all weights with ``nn.tf_sess.run`` and return the result.

        Returns an empty list, without touching the session, when there are
        no weights.
        """
        weights = self.get_weights()
        if len(weights) == 0:
            return []
        return nn.tf_sess.run(weights)

    def set_weights(self, new_weights):
        """Assign ``new_weights`` to the weights, paired by position.

        Each value must expose ``shape`` and ``reshape`` (e.g. a NumPy array).
        A value whose shape differs from its weight's shape is reshaped to the
        weight's shape before assignment.

        Raises:
            ValueError: if ``new_weights`` and the weights differ in length.
        """
        weights = self.get_weights()
        if len(weights) != len(new_weights):
            raise ValueError('len of lists mismatch')

        tuples = []
        for w, new_w in zip(weights, new_weights):
            # Compare real shapes. The previous check compared the weight's
            # rank (an int) against a shape tuple, which was always unequal.
            target_shape = w.shape.as_list()
            if list(new_w.shape) != target_shape:
                new_w = new_w.reshape(target_shape)

            tuples.append((w, new_w))

        nn.batch_set_value(tuples)

    def save_weights(self, filename, force_dtype=None):
        """Pickle the current weight values to ``filename``.

        The file holds a dict mapping each weight name, with the leading
        ``"<self.name>/"`` removed, to its evaluated value.

        Args:
            filename: destination path; written via ``pathex.write_bytes_safe``.
            force_dtype: if not None, every value is cast with ``astype`` to
                this dtype before saving.

        Raises:
            Exception: if ``self.name`` is None, or a weight name does not
                start with ``self.name``.
        """
        weights = self.get_weights()

        if self.name is None:
            raise Exception("name must be defined.")

        name = self.name
        d = {}
        for w, w_val in zip(weights, nn.tf_sess.run(weights)):
            # Split once: "<name>/<rest>" -> ["<name>", "<rest>"].
            w_name_split = w.name.split('/', 1)
            if name != w_name_split[0]:
                raise Exception("weight first name != Saveable.name")

            if force_dtype is not None:
                w_val = w_val.astype(force_dtype)

            d[w_name_split[1]] = w_val

        # Pickle protocol 4.
        d_dumped = pickle.dumps(d, 4)
        pathex.write_bytes_safe(Path(filename), d_dumped)

    def load_weights(self, filename):
        """Load weight values from a file written by ``save_weights``.

        Weights with no entry in the file are paired with their own
        ``initializer`` instead of a stored value.

        Returns:
            True if the file exists and the weights were assigned; False if
            the file does not exist or an error occurred while matching or
            assigning the weights (including a weight-name prefix mismatch).

        Raises:
            Exception: if the file exists but ``self.name`` is None.
        """
        filepath = Path(filename)
        if not filepath.exists():
            return False

        # NOTE(security): pickle.loads can execute arbitrary code while
        # deserializing. Only load weight files from trusted sources.
        d = pickle.loads(filepath.read_bytes())

        weights = self.get_weights()

        if self.name is None:
            raise Exception("name must be defined.")

        try:
            tuples = []
            for w in weights:
                w_name_split = w.name.split('/')
                if self.name != w_name_split[0]:
                    raise Exception("weight first name != Saveable.name")

                sub_w_name = "/".join(w_name_split[1:])

                w_val = d.get(sub_w_name, None)

                if w_val is None:
                    # Nothing stored under this key: hand over the variable's
                    # initializer in place of a value.
                    # NOTE(review): presumably nn.batch_set_value runs it to
                    # re-initialize the weight -- confirm against nn.
                    tuples.append((w, w.initializer))
                else:
                    w_val = np.reshape(w_val, w.shape.as_list())
                    tuples.append((w, w_val))

            nn.batch_set_value(tuples)
        except Exception:
            # Best-effort load: any ordinary failure is reported as False.
            # KeyboardInterrupt / SystemExit are deliberately not swallowed.
            return False

        return True

    def init_weights(self):
        """Pass all weights to ``nn.init_weights``."""
        nn.init_weights(self.get_weights())
| | |
# Attach the class to `nn` so it is reachable as `nn.Saveable`.
nn.Saveable = Saveable
| |
|