TheAlienSeb commited on
Commit
76b187f
·
verified ·
1 Parent(s): 2c3d3e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -0
app.py CHANGED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import numpy as np
3
+ import gradio as gr
4
+ import keras
5
+ from huggingface_hub import hf_hub_download
6
+
7
+ REPO_ID = "DrMarcus24/stock-predictor-dev"
8
+
9
+ # Load architecture from config.json then load weights (.h5)
10
+ cfg_path = hf_hub_download(REPO_ID, "config.json") # downloads & caches; returns local path
11
+ with open(cfg_path, "r") as f:
12
+ cfg = json.load(f)
13
+
14
+ # Keras 3: rebuild from dict, then load weights
15
+ model = keras.saving.deserialize_keras_object(cfg) # alias: keras.utils.deserialize_keras_object
16
+ w_path = hf_hub_download(REPO_ID, "model.weights.h5")
17
+ model.load_weights(w_path)
18
+
19
+ def predict_one(csv_line: str):
20
+ # Expect 9 comma-separated features
21
+ try:
22
+ feats = [float(v) for v in csv_line.strip().split(",")]
23
+ except Exception:
24
+ raise gr.Error("Parse error. Provide 9 comma-separated numbers.")
25
+ if len(feats) != 9:
26
+ raise gr.Error(f"Expected 9 features, got {len(feats)}.")
27
+ x = np.array(feats, dtype="float32").reshape(1, 1, 9) # (batch, 1, 9)
28
+ y = model.predict(x, verbose=0)[0].tolist() # 7 outputs
29
+ return {"predictions": y}
30
+
31
+ demo = gr.Interface(
32
+ fn=predict_one,
33
+ inputs=gr.Textbox(placeholder="f1,f2,f3,f4,f5,f6,f7,f8,f9"),
34
+ outputs="json",
35
+ title="Stock Predictor (Keras)",
36
+ description="Enter 9 features; returns 7 predictions."
37
+ )
38
+
39
+ if __name__ == "__main__":
40
+ demo.launch()
41
+