File size: 5,882 Bytes
d31d611
 
2d3b0ea
6783dbd
 
 
 
2d3b0ea
d31d611
 
 
 
 
 
 
6783dbd
d31d611
6783dbd
 
 
 
 
d31d611
6783dbd
 
 
 
 
 
 
2d3b0ea
d31d611
 
 
 
 
 
 
 
 
 
 
 
 
 
2d3b0ea
d31d611
22bb8ba
 
 
 
 
 
d31d611
 
 
 
 
 
 
 
 
 
 
 
2d3b0ea
 
6783dbd
 
 
 
 
d31d611
6783dbd
d31d611
 
 
 
 
 
 
2d3b0ea
9dd6619
 
 
d96630d
 
9dd6619
d96630d
9dd6619
 
d96630d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d31d611
af46622
9dd6619
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af46622
2d3b0ea
af46622
 
d31d611
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import json
import os

import gradio as gr
import torch
from huggingface_hub import login
from peft import PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    GenerationConfig,
)

# ── 0) Hugging Face login ────────────────────────────────────────────────────────
# A token is required because the Llama-2 base weights are gated on the Hub.
token = os.environ.get("HF_TOKEN")
if not token:  # treat missing AND empty as "not set"
    raise ValueError("HF_TOKEN environment variable is not set.")
login(token=token)

# ── 1) Load your model ───────────────────────────────────────────────────────────
base_model_id = "meta-llama/Llama-2-7b-chat-hf"
adapter_id    = "mdot77/fingpt-llama2-7b-forecaster-finetuned"

tokenizer = AutoTokenizer.from_pretrained(base_model_id, use_fast=True)
# Passing `load_in_8bit=True` directly to from_pretrained is deprecated in
# transformers; the supported API is an explicit BitsAndBytesConfig. Same
# 8-bit quantization (VRAM savings), future-proof spelling.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)
# Stack the fine-tuned LoRA adapter on top of the quantized base model.
model = PeftModel.from_pretrained(
    base_model,
    adapter_id,
    device_map="auto",
)
model.eval()  # inference only: disable dropout / training-mode layers

# ── 2) Define your system‐instruction template ───────────────────────────────────
# System prompt injected into every request via the Llama-2 [INST]/<<SYS>>
# format (see the prompt assembly in `infer`). The JSON schema below is what
# downstream callers parse — keep it in sync with any consumer of the API.
SYSTEM = """You are a portfolio optimization assistant.

For a given stock snapshot, recommend how the allocation should be adjusted.
Your response MUST be valid JSON matching this schema:
{
  "ticker": "<string>",
  "snapshot": "<YYYY-MM-DD>",
  "verdict": "<Increase|Decrease|Hold|Add|Remove>",
  "new_alloc_pct": <number>,
  "reasoning": "<short explanation>"
}

Do not include any extra keys or commentary. At the end, emit only the JSON."""

# ── 3) Inference function ───────────────────────────────────────────────────────
def infer(data_json: str) -> str:
    """Run the forecaster on one portfolio snapshot.

    Args:
        data_json: The snapshot as a JSON string (ticker, date, allocation, ...).

    Returns:
        Pretty-printed JSON when the model's reply parses as JSON,
        otherwise the raw decoded reply text.
    """
    # Hardcoded generation parameters; temperature 0 means greedy decoding.
    max_new_tokens = 256
    temperature = 0.0
    top_p = 1.0

    # Llama-2-chat prompt format: [INST] <<SYS>> system <</SYS>> user [/INST]
    # NOTE(review): add_special_tokens=False below also omits the BOS token,
    # which Llama-2 chat normally expects before [INST] — confirm intended.
    prompt = (
        "[INST] <<SYS>>\n"
        f"{SYSTEM}\n"
        "<</SYS>>\n\n"
        "DATA:\n"
        f"{data_json}\n"
        "[/INST]"
    )
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

    do_sample = temperature > 0
    gen_cfg = GenerationConfig(
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        # Only pass sampling knobs when actually sampling: temperature=0.0
        # combined with greedy decoding trips GenerationConfig validation
        # warnings in recent transformers releases.
        **({"temperature": temperature, "top_p": top_p} if do_sample else {}),
        use_cache=True,
        # Llama-2 has no pad token; fall back to EOS so generate() doesn't warn.
        pad_token_id=tokenizer.eos_token_id,
    )
    # inference_mode(): skip autograd bookkeeping entirely — saves memory/time.
    with torch.inference_mode():
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            generation_config=gen_cfg,
        )
    # Strip the echoed prompt: keep only the tokens generated after the input.
    new_tokens = outputs[0, inputs["input_ids"].shape[-1]:]
    reply = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    try:
        # Round-trip through json to validate and pretty-print the verdict.
        parsed = json.loads(reply)
        return json.dumps(parsed, indent=2)
    except json.JSONDecodeError:
        # Model occasionally emits non-JSON text; return it verbatim.
        return reply

# ── 4) API wrapper function for the main prediction ──────────────────────────────
def predict_api_wrapper(request_data):
    """Handle an API-style request and return a JSON-serializable result.

    Accepts either a raw JSON string or a Gradio-style envelope
    ``{"data": [<json string>, ...]}`` and delegates inference to ``infer``.
    Any failure is reported as ``{"error": <message>}`` rather than raised.
    """
    try:
        # Pull the snapshot payload out of whichever request shape we got.
        if isinstance(request_data, dict) and "data" in request_data and len(request_data["data"]) > 0:
            payload = request_data["data"][0]
        elif isinstance(request_data, str):
            payload = request_data
        else:
            return {"error": "No data provided"}

        raw = infer(payload)

        # infer() may hand back a JSON string or an already-parsed object.
        if isinstance(raw, str):
            try:
                return json.loads(raw)
            except json.JSONDecodeError:
                # Not valid JSON — surface the text under a stable key.
                return {"raw_output": raw}
        return raw
    except Exception as e:
        # Boundary handler: report any failure to the API caller as JSON.
        return {"error": str(e)}

# ── 5) Gradio interface with API endpoints ───────────────────────────────────────
# Two tabs: "Inference" wires the textbox straight to `infer`, while
# "API Testing" exercises `predict_api_wrapper` with the request-envelope
# format ({"data": [...]}) that external API clients send.
with gr.Blocks(title="Portfolio-Optimizer Inference") as iface:
    gr.Markdown("# Portfolio-Optimizer Inference")
    gr.Markdown("Paste your snapshot JSON and get back a single-JSON allocation verdict.")
    
    with gr.Tab("Inference"):
        # Human-facing form: raw snapshot JSON in, model verdict out.
        input_text = gr.Textbox(
            label="Snapshot data (JSON)",
            lines=15,
            value='{"ticker": "COIN", "snapshot": "2022-06-18", "previous_allocation_pct": 0.05}'
        )
        output_json = gr.JSON(label="Model output")
        predict_btn = gr.Button("Predict")
        predict_btn.click(fn=infer, inputs=input_text, outputs=output_json)
    
    with gr.Tab("API Testing"):
        gr.Markdown("## API Testing Interface")
        gr.Markdown("Use this to test the API functionality. The main API endpoint is available at `/api/predict/`")
        
        # Default value is the double-escaped envelope form: a JSON request
        # whose "data" list holds the snapshot as a JSON *string*.
        api_input = gr.Textbox(
            label="API Request Data (JSON)",
            lines=10,
            value='{"data": ["{\\"ticker\\": \\"AAPL\\", \\"snapshot\\": \\"2025-01-01\\", \\"previous_allocation_pct\\": 0.05}"]}'
        )
        api_output = gr.JSON(label="API Response")
        api_btn = gr.Button("Test API")
        api_btn.click(fn=predict_api_wrapper, inputs=api_input, outputs=api_output)

# ── 6) Launch for Hugging Face Spaces ───────────────────────────────────────────
if __name__ == "__main__":
    # For Hugging Face Spaces, use the default launch
    # The API endpoints will be available at /api/predict/ automatically
    # (Spaces supplies host/port; no share= or server_name= overrides needed).
    iface.launch()