| | import os |
| | from huggingface_hub import login |
| | import gradio as gr |
| | from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig |
| | from peft import PeftModel |
| | import torch |
| | import json |
| |
|
| | |
# Authenticate to the Hugging Face Hub up front: both the gated Llama-2
# base model and the fine-tuned adapter below require credentials.
if not (hf_token := os.environ.get("HF_TOKEN")):
    raise ValueError("HF_TOKEN environment variable is not set.")
login(token=hf_token)
| |
|
| | |
# Model identifiers: gated Llama-2 7B chat base model plus a PEFT (LoRA)
# adapter fine-tuned for financial allocation forecasting.
base_model_id = "meta-llama/Llama-2-7b-chat-hf"
adapter_id = "mdot77/fingpt-llama2-7b-forecaster-finetuned"

# Fast tokenizer for the base model; the adapter does not alter the vocab.
tokenizer = AutoTokenizer.from_pretrained(base_model_id, use_fast=True)
# Load the base model 8-bit quantized, sharded automatically across
# available devices.
# NOTE(review): `load_in_8bit=True` is deprecated in newer transformers in
# favor of `quantization_config=BitsAndBytesConfig(...)` — confirm the
# pinned transformers version still accepts it.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    device_map="auto",
    load_in_8bit=True,
)
# Attach the fine-tuned adapter weights on top of the frozen base model.
model = PeftModel.from_pretrained(
    base_model,
    adapter_id,
    device_map="auto",
)
model.eval()  # inference only: disables dropout and other train-time behavior
| |
|
| | |
# System prompt injected into every request (see infer()). It instructs the
# model to emit exactly one JSON object matching the schema below; this is
# enforced only by prompting — infer() falls back to raw text when the
# reply is not valid JSON.
SYSTEM = """You are a portfolio optimization assistant.

For a given stock snapshot, recommend how the allocation should be adjusted.
Your response MUST be valid JSON matching this schema:
{
"ticker": "<string>",
"snapshot": "<YYYY-MM-DD>",
"verdict": "<Increase|Decrease|Hold|Add|Remove>",
"new_alloc_pct": <number>,
"reasoning": "<short explanation>"
}

Do not include any extra keys or commentary. At the end, emit only the JSON."""
| |
|
| | |
| | def infer(data_json: str): |
| | |
| | max_new_tokens = 256 |
| | temperature = 0.0 |
| | top_p = 1.0 |
| | |
| | prompt = ( |
| | "[INST] <<SYS>>\n" |
| | f"{SYSTEM}\n" |
| | "<</SYS>>\n\n" |
| | "DATA:\n" |
| | f"{data_json}\n" |
| | "[/INST]" |
| | ) |
| | inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device) |
| | gen_cfg = GenerationConfig( |
| | max_new_tokens=max_new_tokens, |
| | do_sample=(temperature > 0), |
| | temperature=temperature, |
| | top_p=top_p, |
| | use_cache=True, |
| | ) |
| | outputs = model.generate( |
| | inputs["input_ids"], |
| | attention_mask=inputs["attention_mask"], |
| | generation_config=gen_cfg, |
| | ) |
| | new_tokens = outputs[0, inputs["input_ids"].shape[-1]:] |
| | reply = tokenizer.decode(new_tokens, skip_special_tokens=True).strip() |
| | try: |
| | parsed = json.loads(reply) |
| | return json.dumps(parsed, indent=2) |
| | except json.JSONDecodeError: |
| | return reply |
| |
|
| | |
| | def predict_api_wrapper(request_data): |
| | """Wrapper function that handles the API request format""" |
| | try: |
| | |
| | if isinstance(request_data, dict) and "data" in request_data and len(request_data["data"]) > 0: |
| | data_json = request_data["data"][0] |
| | elif isinstance(request_data, str): |
| | data_json = request_data |
| | else: |
| | return {"error": "No data provided"} |
| | |
| | |
| | result = infer(data_json) |
| | |
| | |
| | try: |
| | if isinstance(result, str): |
| | parsed_result = json.loads(result) |
| | else: |
| | parsed_result = result |
| | except json.JSONDecodeError: |
| | |
| | parsed_result = {"raw_output": result} |
| | |
| | return parsed_result |
| | except Exception as e: |
| | return {"error": str(e)} |
| |
|
| | |
# Assemble the Gradio UI: one tab for interactive inference, one for
# exercising the API-style wrapper. `iface` is launched from the
# __main__ guard below.
with gr.Blocks(title="Portfolio-Optimizer Inference") as iface:
    gr.Markdown("# Portfolio-Optimizer Inference")
    gr.Markdown("Paste your snapshot JSON and get back a single-JSON allocation verdict.")

    with gr.Tab("Inference"):
        snapshot_box = gr.Textbox(
            label="Snapshot data (JSON)",
            lines=15,
            value='{"ticker": "COIN", "snapshot": "2022-06-18", "previous_allocation_pct": 0.05}',
        )
        verdict_view = gr.JSON(label="Model output")
        run_button = gr.Button("Predict")
        run_button.click(fn=infer, inputs=snapshot_box, outputs=verdict_view)

    with gr.Tab("API Testing"):
        gr.Markdown("## API Testing Interface")
        gr.Markdown("Use this to test the API functionality. The main API endpoint is available at `/api/predict/`")

        request_box = gr.Textbox(
            label="API Request Data (JSON)",
            lines=10,
            value='{"data": ["{\\"ticker\\": \\"AAPL\\", \\"snapshot\\": \\"2025-01-01\\", \\"previous_allocation_pct\\": 0.05}"]}',
        )
        response_view = gr.JSON(label="API Response")
        test_button = gr.Button("Test API")
        test_button.click(fn=predict_api_wrapper, inputs=request_box, outputs=response_view)
| |
|
| | |
if __name__ == "__main__":
    # Start the Gradio server; blocks until the process is stopped.
    iface.launch()
| |
|