| | |
| | |
| | |
| |
|
| | import TextStoryAI.model |
| | import os |
| |
|
def load_model(load_weights=True, debug=False):
    """Build the network and, unless disabled, populate it with stored weights.

    Args:
        load_weights: When true, the freshly created model is passed to
            ``loadWeights`` (using that function's default weights file).
        debug: Forwarded to ``loadWeights`` to enable its diagnostic prints.

    Returns:
        The object produced by ``model.create_model()``.
    """
    network = model.create_model()
    if not load_weights:
        # Caller asked for the bare architecture only.
        return network
    loadWeights(network, debug=debug)
    return network
| |
|
| | |
| |
|
| | import tensorflow as tf |
| | import h5py |
| |
|
def loadWeights(model, filename=os.path.join(__package__, "weights.h5"), debug=False):
    """Copy the weights stored in an HDF5 file into the layers of ``model``.

    Each HDF5 group in the file describes one layer: its ``Name`` attribute
    selects the layer in ``model.layers`` and ``NumVars`` gives the number of
    variables to fill. Each dataset inside the group holds one variable and
    carries ``Name``, ``Shape`` and ``WeightNum`` attributes; ``WeightNum`` is
    the index of the target entry in ``layer.variables``. Top-level entries
    that are not groups are ignored.

    Args:
        model: Object exposing ``layers``; each layer exposes ``name`` and
            ``variables`` whose items support ``assign``.
        filename: Path of the HDF5 file. The default is a relative path built
            from the package name, so it is resolved against the current
            working directory.
        debug: When true, print progress information while loading.

    Raises:
        ValueError: If the file references a layer name that the model does
            not contain, or if a layer group does not provide every weight
            announced by its ``NumVars`` attribute.
    """
    with h5py.File(filename, 'r') as f:
        for g in f:
            group = f[g]
            if not isinstance(group, h5py.Group):
                continue

            layerName = group.attrs['Name']
            numVars = int(group.attrs['NumVars'])
            if debug:
                print("layerName:", layerName)
                print(" numVars:", numVars)

            layerIdx = layerNum(model, layerName)
            if layerIdx < 0:
                # layerNum signals "not found" with -1. Indexing with it would
                # silently pick the LAST layer and overwrite its weights.
                raise ValueError(
                    "Model does not contain a layer named %r" % (layerName,)
                )
            layer = model.layers[layerIdx]
            if debug:
                print(" layerIdx=", layerIdx)

            # Datasets may be iterated in any order; WeightNum gives each one
            # its slot. None marks slots that have not been filled yet.
            weightList = [None] * numVars
            for d in group:
                dataset = group[d]
                varName = dataset.attrs['Name']
                shp = intList(dataset.attrs['Shape'])
                weightNum = int(dataset.attrs['WeightNum'])

                if debug:
                    print(" varName:", varName)
                    print(" shp:", shp)
                    print(" weightNum:", weightNum)
                # dataset[()] reads the whole dataset.
                weightList[weightNum] = tf.constant(dataset[()], shape=shp)

            missing = [idx for idx, value in enumerate(weightList) if value is None]
            if missing:
                raise ValueError(
                    "Layer %r declares %d variables but weights %s are missing "
                    "from the file" % (layerName, numVars, missing)
                )

            for w, value in enumerate(weightList):
                if debug:
                    print("Copying variable of shape:")
                    print(value.shape)
                layer.variables[w].assign(value)
                if debug:
                    print("Assignment successful.")
                    print("Set variable value:")
                    print(layer.variables[w])

            # Some layers expose a finalize_state hook; call it once the
            # layer's variables have been overwritten.
            if hasattr(layer, 'finalize_state'):
                layer.finalize_state()
| |
|
def layerNum(model, layerName):
    """Return the position in ``model.layers`` of the layer named ``layerName``.

    The first layer whose ``name`` equals ``layerName`` wins. When no layer
    matches, a failure notice is printed and ``-1`` is returned; callers must
    check for this sentinel before using the result as an index.
    """
    for position, candidate in enumerate(model.layers):
        if layerName == candidate.name:
            return position

    # No layer matched: report it and hand back the sentinel.
    print("")
    print("WEIGHT LOADING FAILED. MODEL DOES NOT CONTAIN LAYER WITH NAME: ", layerName)
    print("")
    return -1
| |
|
def intList(myList):
    """Return a new list holding ``int(item)`` for every item of ``myList``."""
    return [int(item) for item in myList]
| |
|
| |
|