COCODEDE04 committed on
Commit
6be85f7
·
verified ·
1 Parent(s): f92c118

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -35
app.py CHANGED
@@ -1,6 +1,9 @@
1
  import os, json, io, traceback
2
  from typing import Any, Dict, List, Optional
3
 
 
 
 
4
  import numpy as np
5
  import tensorflow as tf
6
  from fastapi import FastAPI, Request
@@ -546,9 +549,8 @@ async def predict(req: Request):
546
  content={"error": str(e), "trace": traceback.format_exc()},
547
  )
548
 
549
-
550
- # ============================================================
551
- # CORAL ORDINAL HELPERS (copied from your training script)
552
  # ============================================================
553
 
554
  def to_cumulative_targets_tf(y_true_int, K_):
@@ -558,6 +560,7 @@ def to_cumulative_targets_tf(y_true_int, K_):
558
  T = tf.cast(tf.greater_equal(y[:, None], thresholds[None, :]), tf.float32)
559
  return T
560
 
 
561
  def coral_loss_tf(y_true, logits):
562
  y_true = tf.reshape(y_true, [-1])
563
  y_true = tf.cast(y_true, tf.int32)
@@ -565,21 +568,32 @@ def coral_loss_tf(y_true, logits):
565
  bce = tf.nn.sigmoid_cross_entropy_with_logits(labels=T, logits=logits)
566
  return tf.reduce_mean(tf.reduce_sum(bce, axis=1))
567
 
568
- def coral_probs_from_logits(logits):
569
- sig = tf.math.sigmoid(logits)
 
 
 
570
  left = tf.concat([tf.ones_like(sig[:, :1]), sig], axis=1)
571
  right = tf.concat([sig, tf.zeros_like(sig[:, :1])], axis=1)
572
  return tf.clip_by_value(left - right, 1e-12, 1.0)
573
 
 
 
 
 
 
 
 
574
  @tf.function
575
  def ordinal_accuracy_metric(y_true, y_pred_logits):
576
  y_true = tf.reshape(y_true, [-1])
577
  y_true = tf.cast(y_true, tf.int32)
578
- probs = coral_probs_from_logits(y_pred_logits)
579
  y_pred = tf.argmax(probs, axis=1, output_type=tf.int32)
580
  return tf.reduce_mean(tf.cast(tf.equal(y_true, y_pred), tf.float32))
581
 
582
- # ============================================================
 
583
  # RECREATE MODEL FROM BEST HYPERPARAMETERS
584
  # ============================================================
585
 
@@ -608,7 +622,6 @@ def build_model_from_hparams(hp: dict):
608
  if drop > 0:
609
  x = tf.keras.layers.Dropout(drop)(x)
610
 
611
- # CORAL output
612
  outputs = tf.keras.layers.Dense(len(CLASSES) - 1, activation=None)(x)
613
 
614
  model = tf.keras.Model(inputs, outputs)
@@ -620,12 +633,12 @@ def build_model_from_hparams(hp: dict):
620
  return model
621
 
622
 
623
- # ============================================================
624
- # RETRAINING LOGIC
625
  # ============================================================
626
 
627
- FINGERPRINT_CSV = "fingerprints_db.csv" # <-- choose file name
628
- BEST_HP_JSON = "best_params_and_metrics.json"
629
 
630
 
631
  def load_best_hparams():
@@ -637,11 +650,9 @@ def load_best_hparams():
637
  def load_fingerprint_dataset():
638
  df = pd.read_csv(FINGERPRINT_CSV)
639
 
640
- # Must include: company, date, rating, and 21 features
641
- y = df["rating"].map({c:i for i,c in enumerate(CLASSES)}).astype("int32").to_numpy()
642
  X_raw = df[FEATURES].to_numpy().astype("float32")
643
 
644
- # Fit imputer + scaler from full dataset
645
  imp = SimpleImputer(strategy="median")
646
  sc = StandardScaler()
647
 
@@ -658,40 +669,57 @@ def retrain_model():
658
  model_new = build_model_from_hparams(hp)
659
 
660
  es = tf.keras.callbacks.EarlyStopping(
661
- monitor="loss", patience=15, restore_best_weights=True
 
 
 
662
  )
663
 
664
  model_new.fit(
665
  X, y,
666
  epochs=150,
667
  batch_size=128,
 
668
  verbose=1,
669
- callbacks=[es]
670
  )
671
 
672
- # Update globals used by /predict
673
- global model, IMPUTER, SCALER
674
  model = model_new
675
- IMPUTER = imp
676
- SCALER = sc
 
 
 
 
 
 
 
 
 
 
 
677
 
678
  return True
679
 
680
 
681
-
682
- # ============================================================
683
  # API ENDPOINT: APPEND + RETRAIN
684
  # ============================================================
685
 
686
  @app.post("/append_and_retrain")
687
  def append_and_retrain(payload: dict):
688
  """
689
- payload:
690
  {
691
  "company": "...",
692
  "date": "2025-Q1",
693
- "rating": "Mid-Top",
694
- "features": { autosuf_oper: ..., improductiva: ..., ... }
 
 
 
 
695
  }
696
  """
697
 
@@ -700,14 +728,20 @@ def append_and_retrain(payload: dict):
700
  rating = payload.get("rating")
701
  feats = payload.get("features", {})
702
 
703
- if not company or not rating or len(feats) != len(FEATURES):
704
- return {"ok": False, "error": "Invalid payload"}
705
 
706
- # Append to CSV
707
- df_new = pd.DataFrame([{**{"company": company,
708
- "date": date,
709
- "rating": rating},
710
- **feats}])
 
 
 
 
 
 
711
 
712
  if os.path.exists(FINGERPRINT_CSV):
713
  df = pd.read_csv(FINGERPRINT_CSV)
@@ -717,7 +751,7 @@ def append_and_retrain(payload: dict):
717
 
718
  df.to_csv(FINGERPRINT_CSV, index=False)
719
 
720
- # Retrain
721
  retrain_model()
722
 
723
- return {"ok": True, "message": "Fingerprint added and model retrained"}
 
1
  import os, json, io, traceback
2
  from typing import Any, Dict, List, Optional
3
 
4
+ import pandas as pd
5
+ from sklearn.impute import SimpleImputer
6
+ from sklearn.preprocessing import StandardScaler
7
  import numpy as np
8
  import tensorflow as tf
9
  from fastapi import FastAPI, Request
 
549
  content={"error": str(e), "trace": traceback.format_exc()},
550
  )
551
 
552
+ # ============================================================
553
+ # CORAL ORDINAL HELPERS (from training script)
 
554
  # ============================================================
555
 
556
  def to_cumulative_targets_tf(y_true_int, K_):
 
560
  T = tf.cast(tf.greater_equal(y[:, None], thresholds[None, :]), tf.float32)
561
  return T
562
 
563
+
564
  def coral_loss_tf(y_true, logits):
565
  y_true = tf.reshape(y_true, [-1])
566
  y_true = tf.cast(y_true, tf.int32)
 
568
  bce = tf.nn.sigmoid_cross_entropy_with_logits(labels=T, logits=logits)
569
  return tf.reduce_mean(tf.reduce_sum(bce, axis=1))
570
 
571
+
572
# ---------- TF helper & numpy wrapper (unified version) ----------
def _coral_probs_from_logits_tf(logits_tf: tf.Tensor) -> tf.Tensor:
    """Convert CORAL cumulative logits to per-class probabilities (pure TF).

    Given K-1 threshold logits, each class probability is the difference
    between adjacent cumulative sigmoids, padded with 1 on the left and
    0 on the right. Output is clipped away from zero for numerical safety.
    """
    cumulative = tf.math.sigmoid(logits_tf)
    # P(y >= 0) is 1 by definition; P(y >= K) is 0 by definition.
    upper = tf.pad(cumulative, [[0, 0], [1, 0]], constant_values=1.0)
    lower = tf.pad(cumulative, [[0, 0], [0, 1]], constant_values=0.0)
    return tf.clip_by_value(upper - lower, 1e-12, 1.0)
579
 
580
+
581
def coral_probs_from_logits(logits_np: np.ndarray) -> np.ndarray:
    """NumPy-facing wrapper around the TF CORAL transform.

    Used by decode_logits and the SHAP explainer, which operate on
    plain numpy arrays rather than TF tensors.
    """
    tensor_in = tf.convert_to_tensor(logits_np, dtype=tf.float32)
    tensor_out = _coral_probs_from_logits_tf(tensor_in)
    return tensor_out.numpy()
585
+
586
+
587
@tf.function
def ordinal_accuracy_metric(y_true, y_pred_logits):
    """Exact-match accuracy for the CORAL ordinal head.

    Decodes logits to class probabilities, takes the argmax class, and
    returns the mean agreement with the integer labels as a float32 scalar.
    """
    labels = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
    class_probs = _coral_probs_from_logits_tf(y_pred_logits)
    predicted = tf.argmax(class_probs, axis=1, output_type=tf.int32)
    hits = tf.cast(tf.equal(labels, predicted), tf.float32)
    return tf.reduce_mean(hits)
 
595
+
596
+ # ============================================================
597
  # RECREATE MODEL FROM BEST HYPERPARAMETERS
598
  # ============================================================
599
 
 
622
  if drop > 0:
623
  x = tf.keras.layers.Dropout(drop)(x)
624
 
 
625
  outputs = tf.keras.layers.Dense(len(CLASSES) - 1, activation=None)(x)
626
 
627
  model = tf.keras.Model(inputs, outputs)
 
633
  return model
634
 
635
 
636
+ # ============================================================
637
+ # RETRAINING LOGIC + DATASET MGMT
638
  # ============================================================
639
 
640
+ FINGERPRINT_CSV = "fingerprints_db.csv"
641
+ BEST_HP_JSON = "best_params_and_metrics.json"
642
 
643
 
644
  def load_best_hparams():
 
650
  def load_fingerprint_dataset():
651
  df = pd.read_csv(FINGERPRINT_CSV)
652
 
653
+ y = df["rating"].map({c: i for i, c in enumerate(CLASSES)}).astype("int32").to_numpy()
 
654
  X_raw = df[FEATURES].to_numpy().astype("float32")
655
 
 
656
  imp = SimpleImputer(strategy="median")
657
  sc = StandardScaler()
658
 
 
669
  model_new = build_model_from_hparams(hp)
670
 
671
  es = tf.keras.callbacks.EarlyStopping(
672
+ monitor="loss",
673
+ patience=15,
674
+ restore_best_weights=True,
675
+ verbose=1
676
  )
677
 
678
  model_new.fit(
679
  X, y,
680
  epochs=150,
681
  batch_size=128,
682
+ callbacks=[es],
683
  verbose=1,
 
684
  )
685
 
686
+ # Update global model + preprocessors
687
+ global model, imputer, scaler
688
  model = model_new
689
+ imputer = imp
690
+ scaler = sc
691
+
692
+ # Rebuild SHAP explainer to match new model
693
+ global EXPLAINER
694
+ if SHAP_AVAILABLE:
695
+ try:
696
+ BACKGROUND_Z = np.zeros((50, len(FEATURES)), dtype=np.float32)
697
+ EXPLAINER = shap.KernelExplainer(model_proba_from_z, BACKGROUND_Z)
698
+ print("SHAP explainer rebuilt after retrain.")
699
+ except Exception as e:
700
+ EXPLAINER = None
701
+ print("⚠️ Failed to rebuild SHAP explainer:", repr(e))
702
 
703
  return True
704
 
705
 
706
+ # ============================================================
 
707
  # API ENDPOINT: APPEND + RETRAIN
708
  # ============================================================
709
 
710
  @app.post("/append_and_retrain")
711
  def append_and_retrain(payload: dict):
712
  """
713
+ payload format:
714
  {
715
  "company": "...",
716
  "date": "2025-Q1",
717
+ "rating": "Mid",
718
+ "features": {
719
+ "autosuf_oper": ...,
720
+ "improductiva": ...,
721
+ ...
722
+ }
723
  }
724
  """
725
 
 
728
  rating = payload.get("rating")
729
  feats = payload.get("features", {})
730
 
731
+ if not company or not date or not rating:
732
+ return {"ok": False, "error": "Missing company/date/rating"}
733
 
734
+ if set(feats.keys()) != set(FEATURES):
735
+ return {"ok": False, "error": "Features missing or incorrect"}
736
+
737
+ # Append row
738
+ new_row = {
739
+ "company": company,
740
+ "date": date,
741
+ "rating": rating,
742
+ **feats
743
+ }
744
+ df_new = pd.DataFrame([new_row])
745
 
746
  if os.path.exists(FINGERPRINT_CSV):
747
  df = pd.read_csv(FINGERPRINT_CSV)
 
751
 
752
  df.to_csv(FINGERPRINT_CSV, index=False)
753
 
754
+ # Retrain model
755
  retrain_model()
756
 
757
+ return {"ok": True, "message": "Fingerprint appended + model retrained"}