COCODEDE04 commited on
Commit
f92c118
·
verified ·
1 Parent(s): 1122e44

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +177 -1
app.py CHANGED
@@ -544,4 +544,180 @@ async def predict(req: Request):
544
  return JSONResponse(
545
  status_code=500,
546
  content={"error": str(e), "trace": traceback.format_exc()},
547
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
544
  return JSONResponse(
545
  status_code=500,
546
  content={"error": str(e), "trace": traceback.format_exc()},
547
+ )
548
+
549
+
550
+ # ============================================================
551
+ # CORAL ORDINAL HELPERS (copied from your training script)
552
+ # ============================================================
553
+
554
+ def to_cumulative_targets_tf(y_true_int, K_):
555
+ y = tf.reshape(y_true_int, [-1])
556
+ y = tf.cast(y, tf.int32)
557
+ thresholds = tf.range(1, K_, dtype=tf.int32)
558
+ T = tf.cast(tf.greater_equal(y[:, None], thresholds[None, :]), tf.float32)
559
+ return T
560
+
561
+ def coral_loss_tf(y_true, logits):
562
+ y_true = tf.reshape(y_true, [-1])
563
+ y_true = tf.cast(y_true, tf.int32)
564
+ T = to_cumulative_targets_tf(y_true, len(CLASSES))
565
+ bce = tf.nn.sigmoid_cross_entropy_with_logits(labels=T, logits=logits)
566
+ return tf.reduce_mean(tf.reduce_sum(bce, axis=1))
567
+
568
+ def coral_probs_from_logits(logits):
569
+ sig = tf.math.sigmoid(logits)
570
+ left = tf.concat([tf.ones_like(sig[:, :1]), sig], axis=1)
571
+ right = tf.concat([sig, tf.zeros_like(sig[:, :1])], axis=1)
572
+ return tf.clip_by_value(left - right, 1e-12, 1.0)
573
+
574
+ @tf.function
575
+ def ordinal_accuracy_metric(y_true, y_pred_logits):
576
+ y_true = tf.reshape(y_true, [-1])
577
+ y_true = tf.cast(y_true, tf.int32)
578
+ probs = coral_probs_from_logits(y_pred_logits)
579
+ y_pred = tf.argmax(probs, axis=1, output_type=tf.int32)
580
+ return tf.reduce_mean(tf.cast(tf.equal(y_true, y_pred), tf.float32))
581
+
582
+ # ============================================================
583
+ # RECREATE MODEL FROM BEST HYPERPARAMETERS
584
+ # ============================================================
585
+
586
+ def build_model_from_hparams(hp: dict):
587
+ inputs = tf.keras.Input(shape=(len(FEATURES),))
588
+ x = inputs
589
+
590
+ n_hidden = hp["n_hidden"]
591
+ use_bn = hp["batchnorm"]
592
+ act = hp["activation"]
593
+ l2_reg = hp["l2"]
594
+
595
+ for i in range(1, n_hidden + 1):
596
+ units = hp[f"units_{i}"]
597
+ drop = hp[f"dropout_{i}"]
598
+
599
+ x = tf.keras.layers.Dense(
600
+ units,
601
+ activation=act,
602
+ kernel_regularizer=tf.keras.regularizers.l2(l2_reg)
603
+ )(x)
604
+
605
+ if use_bn:
606
+ x = tf.keras.layers.BatchNormalization()(x)
607
+
608
+ if drop > 0:
609
+ x = tf.keras.layers.Dropout(drop)(x)
610
+
611
+ # CORAL output
612
+ outputs = tf.keras.layers.Dense(len(CLASSES) - 1, activation=None)(x)
613
+
614
+ model = tf.keras.Model(inputs, outputs)
615
+ model.compile(
616
+ optimizer=tf.keras.optimizers.Adam(learning_rate=hp["lr"]),
617
+ loss=coral_loss_tf,
618
+ metrics=[ordinal_accuracy_metric],
619
+ )
620
+ return model
621
+
622
+
623
+ # ============================================================
624
+ # RETRAINING LOGIC
625
+ # ============================================================
626
+
627
+ FINGERPRINT_CSV = "fingerprints_db.csv" # <-- choose file name
628
+ BEST_HP_JSON = "best_params_and_metrics.json"
629
+
630
+
631
+ def load_best_hparams():
632
+ with open(BEST_HP_JSON, "r") as f:
633
+ js = json.load(f)
634
+ return js["best_hyperparams"]
635
+
636
+
637
+ def load_fingerprint_dataset():
638
+ df = pd.read_csv(FINGERPRINT_CSV)
639
+
640
+ # Must include: company, date, rating, and 21 features
641
+ y = df["rating"].map({c:i for i,c in enumerate(CLASSES)}).astype("int32").to_numpy()
642
+ X_raw = df[FEATURES].to_numpy().astype("float32")
643
+
644
+ # Fit imputer + scaler from full dataset
645
+ imp = SimpleImputer(strategy="median")
646
+ sc = StandardScaler()
647
+
648
+ X_imp = imp.fit_transform(X_raw)
649
+ X_sc = sc.fit_transform(X_imp)
650
+
651
+ return X_sc, y, imp, sc
652
+
653
+
654
+ def retrain_model():
655
+ hp = load_best_hparams()
656
+ X, y, imp, sc = load_fingerprint_dataset()
657
+
658
+ model_new = build_model_from_hparams(hp)
659
+
660
+ es = tf.keras.callbacks.EarlyStopping(
661
+ monitor="loss", patience=15, restore_best_weights=True
662
+ )
663
+
664
+ model_new.fit(
665
+ X, y,
666
+ epochs=150,
667
+ batch_size=128,
668
+ verbose=1,
669
+ callbacks=[es]
670
+ )
671
+
672
+ # Update globals used by /predict
673
+ global model, IMPUTER, SCALER
674
+ model = model_new
675
+ IMPUTER = imp
676
+ SCALER = sc
677
+
678
+ return True
679
+
680
+
681
+
682
+ # ============================================================
683
+ # API ENDPOINT: APPEND + RETRAIN
684
+ # ============================================================
685
+
686
+ @app.post("/append_and_retrain")
687
+ def append_and_retrain(payload: dict):
688
+ """
689
+ payload:
690
+ {
691
+ "company": "...",
692
+ "date": "2025-Q1",
693
+ "rating": "Mid-Top",
694
+ "features": { autosuf_oper: ..., improductiva: ..., ... }
695
+ }
696
+ """
697
+
698
+ company = payload.get("company")
699
+ date = payload.get("date")
700
+ rating = payload.get("rating")
701
+ feats = payload.get("features", {})
702
+
703
+ if not company or not rating or len(feats) != len(FEATURES):
704
+ return {"ok": False, "error": "Invalid payload"}
705
+
706
+ # Append to CSV
707
+ df_new = pd.DataFrame([{**{"company": company,
708
+ "date": date,
709
+ "rating": rating},
710
+ **feats}])
711
+
712
+ if os.path.exists(FINGERPRINT_CSV):
713
+ df = pd.read_csv(FINGERPRINT_CSV)
714
+ df = pd.concat([df, df_new], ignore_index=True)
715
+ else:
716
+ df = df_new
717
+
718
+ df.to_csv(FINGERPRINT_CSV, index=False)
719
+
720
+ # Retrain
721
+ retrain_model()
722
+
723
+ return {"ok": True, "message": "Fingerprint added and model retrained"}