sachin1801 commited on
Commit
b801ecb
·
1 Parent(s): 08cdda3

removed test script not needed anymore for testing pre-trained model

Browse files
Files changed (1) hide show
  1. test_model.py +0 -57
test_model.py DELETED
@@ -1,57 +0,0 @@
1
- """Simple script to test the pre-trained splicing model.
2
-
3
- This script uses the approach from the original notebooks:
4
- - figures/generate_csv_for_supplementary.ipynb
5
- - 2022_03_11_figures/position_specific_activations.ipynb
6
-
7
- Requires: Python 3.10 + TensorFlow 2.10 (see README for setup)
8
- """
9
-
10
- import sys
11
-
12
- # Add figures directory to path so we can import quad_model
13
- sys.path.insert(0, 'figures')
14
-
15
- # Import from quad_model - this auto-registers all custom layers
16
- # via @tf.keras.utils.register_keras_serializable() decorators
17
- from quad_model import *
18
- from tensorflow.keras.models import load_model
19
- from joblib import load as jload
20
- import numpy as np
21
-
22
- print("Loading model...")
23
- model = load_model('output/custom_adjacency_regularizer_20210731_124_step3.h5')
24
- print("Model loaded successfully!")
25
-
26
- print("\nLoading test data...")
27
- xTe = jload('data/xTe_ES7_HeLa_ABC.pkl.gz')
28
- yTe = jload('data/yTe_ES7_HeLa_ABC.pkl.gz')
29
-
30
- num_samples = len(xTe[0]) if isinstance(xTe, list) else len(xTe)
31
- print(f"Number of test samples: {num_samples}")
32
-
33
- print("\nRunning predictions...")
34
- predictions = model.predict(xTe, verbose=0)
35
-
36
- print(f"\nResults:")
37
- print(f"Predictions shape: {predictions.shape}")
38
- print(f"\nFirst 10 predictions vs actual PSI values:")
39
- print("-" * 50)
40
- print(f"{'Predicted PSI':<15} {'Actual PSI':<15} {'Diff':<10}")
41
- print("-" * 50)
42
- for i in range(min(10, len(predictions))):
43
- pred = predictions[i, 0]
44
- actual = yTe[i]
45
- diff = pred - actual
46
- print(f"{pred:<15.4f} {actual:<15.4f} {diff:<10.4f}")
47
-
48
- # Calculate overall metrics
49
- from sklearn.metrics import mean_squared_error, r2_score
50
- mse = mean_squared_error(yTe, predictions)
51
- r2 = r2_score(yTe, predictions)
52
- correlation = np.corrcoef(yTe.flatten(), predictions.flatten())[0, 1]
53
-
54
- print(f"\nOverall Metrics:")
55
- print(f" MSE: {mse:.6f}")
56
- print(f" R2 Score: {r2:.4f}")
57
- print(f" Correlation: {correlation:.4f}")