astirn commited on
Commit
e78fda7
·
1 Parent(s): af865b9
Files changed (1) hide show
  1. tiger.py +23 -11
tiger.py CHANGED
@@ -99,17 +99,29 @@ def process_data(transcript_seq: str):
99
  return target_seq, guide_seq, model_inputs
100
 
101
 
102
- def prediction_transform(predictions: np.array, cutoff_path: str = 'cutoff.npy'):
103
- """
104
- :param predictions: in [0,1] where 0 represents most active guides
105
- :param cutoff_path: full path to cutoff.npy (a float in [0,1] above which guides are inactive)
106
- :return: predictions in [0,1] where 1 represents most active guides
107
- """
108
- cutoff = np.load(cutoff_path)
109
- predictions[predictions > cutoff] = cutoff + (predictions[predictions > cutoff] - cutoff) * 0.01
110
- predictions = predictions.max() - predictions
111
- predictions = predictions / predictions.max()
112
- return predictions
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
 
115
  def predict_on_target(transcript_seq: str, model: tf.keras.Model):
 
99
  return target_seq, guide_seq, model_inputs
100
 
101
 
102
+ def prediction_transform(predictions: np.array, **params):
103
+
104
+ # regime indices
105
+ active_saturation = predictions < params['a']
106
+ linear_regime = (params['a'] <= predictions) & (predictions <= params['c'])
107
+ inactive_saturation = params['c'] < predictions
108
+
109
+ # linear regime
110
+ slope = (params['d'] - params['b']) / (params['c'] - params['a'])
111
+ intercept = -params['a'] * slope + params['b']
112
+ predictions[linear_regime] = slope * predictions[linear_regime] + intercept
113
+
114
+ # active saturation regime
115
+ alpha = slope / params['b']
116
+ beta = alpha * params['a'] - np.log(params['b'])
117
+ predictions[active_saturation] = np.exp(alpha * predictions[active_saturation] - beta)
118
+
119
+ # inactive saturation regime
120
+ alpha = slope / (1 - params['d'])
121
+ beta = -alpha * params['c'] - np.log(1 - params['d'])
122
+ predictions[inactive_saturation] = 1 - np.exp(-alpha * predictions[inactive_saturation] - beta)
123
+
124
+ return 1 - predictions
125
 
126
 
127
  def predict_on_target(transcript_seq: str, model: tf.keras.Model):