File size: 2,964 Bytes
a554f94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#    This file was created by
#    MATLAB Deep Learning Toolbox Converter for TensorFlow Models.
#    13-Aug-2025 21:15:39

import TextStoryAI.model
import os

def load_model(load_weights=True, debug=False):
    """Build the converted model and optionally load its stored weights.

    Args:
        load_weights: If True, call ``loadWeights`` on the new model using
            that function's default weights file.
        debug: Passed through to ``loadWeights``; when True it prints
            progress while loading.

    Returns:
        The model returned by ``TextStoryAI.model.create_model()``, with
        weights assigned if ``load_weights`` is True.
    """
    # ``import TextStoryAI.model`` binds only the top-level name
    # ``TextStoryAI``; the bare name ``model`` is not defined in this module,
    # so the submodule must be reached through the package.
    m = TextStoryAI.model.create_model()
    if load_weights:
        loadWeights(m, debug=debug)
    return m

## Utility functions:

import tensorflow as tf
import h5py

def loadWeights(model, filename=None, debug=False):
    """Load layer weights from an HDF5 file into ``model``.

    Each top-level HDF5 group describes one layer. Its ``Name`` attribute
    selects the target layer in ``model.layers`` and its ``NumVars`` attribute
    gives how many variables to assign. Each dataset inside the group holds
    one variable, with ``Shape`` and ``WeightNum`` attributes giving its shape
    and its position in ``layer.variables``. Top-level members that are not
    groups are ignored.

    Args:
        model: Model whose ``layers`` receive the weights.
        filename: Path to the HDF5 weights file. Defaults to ``weights.h5``
            in the directory containing this module.
        debug: If True, print progress information while loading.

    Raises:
        ValueError: If the file names a layer that ``model`` does not
            contain, or a layer group's ``WeightNum`` values do not cover
            exactly ``range(NumVars)``.
    """
    if filename is None:
        # Resolve relative to this module, not the current working directory.
        # Joining ``__package__`` (a dotted package *name*) only worked when
        # run from the package's parent directory.
        filename = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "weights.h5"
        )
    with h5py.File(filename, 'r') as f:
        # Every layer is an h5 group. Ignore non-groups (such as /0)
        for g in f:
            group = f[g]
            if not isinstance(group, h5py.Group):
                continue
            layerName = group.attrs['Name']
            numVars = int(group.attrs['NumVars'])
            if debug:
                print("layerName:", layerName)
                print("    numVars:", numVars)
            # layerNum signals "not found" with -1, which is also a valid
            # Python index (the last layer). Indexing with it would silently
            # load these weights into the wrong layer.
            layerIdx = layerNum(model, layerName)
            if layerIdx < 0:
                raise ValueError(
                    f"Model does not contain a layer named {layerName!r}; "
                    f"cannot load weights from {filename!r}"
                )
            layer = model.layers[layerIdx]
            if debug:
                print("    layerIdx=", layerIdx)
            weightList = _readLayerWeights(group, numVars, layerName, debug)
            # Assign the weights into the layer, in WeightNum order
            for w, weight in enumerate(weightList):
                if debug:
                    print("Copying variable of shape:")
                    print(weight.shape)
                layer.variables[w].assign(weight)
                if debug:
                    print("Assignment successful.")
                    print("Set variable value:")
                    print(layer.variables[w])
            # Finalize layer state
            if hasattr(layer, 'finalize_state'):
                layer.finalize_state()


def _readLayerWeights(group, numVars, layerName, debug=False):
    """Read one layer group's datasets into a list ordered by ``WeightNum``.

    Args:
        group: HDF5 group whose datasets each carry ``Name``, ``Shape`` and
            ``WeightNum`` attributes.
        numVars: Number of variables the layer expects.
        layerName: Layer name, used only in error messages.
        debug: If True, print each dataset's attributes.

    Returns:
        A list of ``numVars`` tensors, where index i holds the dataset whose
        ``WeightNum`` is i.

    Raises:
        ValueError: If a ``WeightNum`` is outside ``range(numVars)`` or some
            index in that range has no dataset.
    """
    weightList = [None] * numVars
    for d in group:
        dataset = group[d]
        varName = dataset.attrs['Name']
        shp = intList(dataset.attrs['Shape'])
        weightNum = int(dataset.attrs['WeightNum'])
        if debug:
            print("    varName:", varName)
            print("        shp:", shp)
            print("        weightNum:", weightNum)
        # Reject negative indices explicitly; list indexing would wrap them.
        if not 0 <= weightNum < numVars:
            raise ValueError(
                f"Layer {layerName!r}: WeightNum {weightNum} of variable "
                f"{varName!r} is outside range(0, {numVars})"
            )
        weightList[weightNum] = tf.constant(dataset[()], shape=shp)
    missing = [i for i, weight in enumerate(weightList) if weight is None]
    if missing:
        raise ValueError(
            f"Layer {layerName!r}: no weights stored for variable index(es) "
            f"{missing}"
        )
    return weightList

def layerNum(model, layerName):
    """Return the index in ``model.layers`` of the layer named ``layerName``.

    Args:
        model: Object exposing a ``layers`` sequence whose items have a
            ``name`` attribute.
        layerName: Name of the layer to look up.

    Returns:
        The index of the first layer whose ``name`` equals ``layerName``.
        If there is none, a failure message is printed and -1 is returned.
        Note that -1 is also a valid Python index, so callers must check for
        it before indexing ``model.layers``.
    """
    for i, layer in enumerate(model.layers):
        if layerName == layer.name:
            return i
    print("")
    print("WEIGHT LOADING FAILED. MODEL DOES NOT CONTAIN LAYER WITH NAME: ", layerName)
    print("")
    return -1

def intList(myList):
    """Return a new list containing ``int(item)`` for each item of ``myList``."""
    return [int(item) for item in myList]