Spaces:
Running
Running
P2SAMAPA commited on
Update models.py
Browse files
models.py
CHANGED
|
@@ -129,3 +129,11 @@ def train_tft(X_train, y_train, X_val, y_val, epochs=200,
|
|
| 129 |
callbacks=callbacks, verbose=1, shuffle=True
|
| 130 |
)
|
| 131 |
return model, history
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
callbacks=callbacks, verbose=1, shuffle=True
|
| 130 |
)
|
| 131 |
return model, history
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def predict_tft(model, X_test):
|
| 135 |
+
"""
|
| 136 |
+
Run inference with the TFT model.
|
| 137 |
+
Returns softmax probability array of shape (N, num_classes).
|
| 138 |
+
"""
|
| 139 |
+
return model.predict(X_test, verbose=0)
|