P2SAMAPA commited on
Commit
c91df12
·
unverified ·
1 Parent(s): 24d4524

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +8 -0
models.py CHANGED
@@ -129,3 +129,11 @@ def train_tft(X_train, y_train, X_val, y_val, epochs=200,
129
  callbacks=callbacks, verbose=1, shuffle=True
130
  )
131
  return model, history
 
 
 
 
 
 
 
 
 
129
  callbacks=callbacks, verbose=1, shuffle=True
130
  )
131
  return model, history
132
+
133
+
134
+ def predict_tft(model, X_test):
135
+ """
136
+ Run inference with the TFT model.
137
+ Returns softmax probability array of shape (N, num_classes).
138
+ """
139
+ return model.predict(X_test, verbose=0)