File size: 1,773 Bytes
10de70c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""
Gradio app for FinRL PPO inference.
Deploy to HuggingFace Spaces (CPU basic, free).
"""

import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download
from stable_baselines3 import PPO

# Hugging Face Hub repository that hosts the trained agent.
# When redeploying, point this at your own repo (replace the username part).
REPO_ID = "2045max/finrl-ppo-dow30-quick"

# Download the saved agent archive from the Hub, then deserialize it as a
# Stable-Baselines3 PPO model. This runs once, at module import time.
print(f"Loading model from {REPO_ID}...")
model_path = hf_hub_download(
    repo_id=REPO_ID,
    filename="agent_ppo.zip",
)
model = PPO.load(model_path)
print("Model loaded!")


def predict(state_json: str) -> str:
    """Run the PPO agent on one observation and describe its action.

    Args:
        state_json: JSON text encoding a flat list of 301 numbers, the
            observation vector passed to the module-level ``model``.

    Returns:
        On success, a pretty-printed JSON string with two keys:
        ``"action"`` (the action values returned by ``model.predict``) and
        ``"interpretation"`` (one BUY/SELL/HOLD line per action entry).
        On any failure, a string starting with ``"Error: "``. Errors are
        returned instead of raised so the Gradio output textbox shows them.
    """
    import json

    state_dim = 301
    # Action values inside [-hold_band, hold_band] are reported as HOLD.
    hold_band = 0.1

    try:
        state = np.array(json.loads(state_json), dtype=np.float32)

        # Validate the whole shape, not just shape[0]: a JSON scalar has no
        # shape[0], and a nested list like 301 x k would otherwise slip
        # through and reach the model with the wrong shape.
        if state.ndim != 1:
            return (
                f"Error: state must be a flat {state_dim}-dim list, "
                f"got shape {state.shape}"
            )
        if state.shape[0] != state_dim:
            return f"Error: state must be {state_dim}-dim, got {state.shape[0]}"
        # json.loads accepts NaN/Infinity literals; reject them before inference.
        if not np.all(np.isfinite(state)):
            return "Error: state contains NaN or infinite values"

        action, _ = model.predict(state, deterministic=True)

        # Presumably one action value per stock (30 for the Dow 30 setup);
        # TODO confirm against the training environment's action space.
        result = {
            "action": action.tolist(),
            "interpretation": [
                f"Stock {i}: {'BUY' if a > hold_band else 'SELL' if a < -hold_band else 'HOLD'} ({a:.2f})"
                for i, a in enumerate(action)
            ],
        }
        return json.dumps(result, indent=2)
    except Exception as e:
        # UI boundary: surface any parsing/inference failure as text.
        return f"Error: {e}"


# Fixed example observation that pre-fills the input box. It is a constant
# placeholder, not random and not real market data. It has 1 + 30 + 30 + 240
# = 301 entries (presumably cash, 30 prices, 30 holdings, then per-stock
# indicator features, following FinRL's stock env layout; TODO confirm).
demo_state = "[" + ", ".join(
    ["1000000.0"] + ["100.0"] * 30 + ["0.0"] * 30 + ["0.5"] * 240
) + "]"

# Text-in / text-out web UI around ``predict``. The input box is pre-filled
# with ``demo_state`` so the demo works without pasting a state vector.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(
        label="State (301-dim JSON list)",
        value=demo_state,
        lines=5,
    ),
    outputs=gr.Textbox(label="Action (30-dim)", lines=15),
    title="FinRL PPO Agent (Quick Demo)",
    description="⚠️ Toy model trained on only 2000 steps. Educational use only.",
    # Names the programmatic API endpoint "predict". The exact URL path it is
    # served under depends on the installed Gradio version.
    # NOTE(review): confirm the path before documenting it for clients.
    api_name="predict",
)

# Start the server only when this file is executed as a script, so importing
# the module (e.g. to test ``predict``) does not launch a web server.
# Assumes the hosting platform runs this file directly; confirm for your setup.
if __name__ == "__main__":
    iface.launch()