Update fine-tune notebook securities to real StockEx symbols
Browse filesReplace hardcoded OPAP/MYTIL/ADMIE/etc. with actual StockEx securities
(ALPHA, PEIR, EXAE, QUEST, NBG, EUROB, AEG, INTKA, AAAK, ATTIK) and
current market prices for Colab training dataset generation.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
notebooks/ch_trader_finetune.ipynb
CHANGED
|
@@ -96,176 +96,7 @@
|
|
| 96 |
"id": "dataset-gen",
|
| 97 |
"metadata": {},
|
| 98 |
"outputs": [],
|
| 99 |
-
"source": [
|
| 100 |
-
"# Securities traded on StockEx\n",
|
| 101 |
-
"SECURITIES = [\n",
|
| 102 |
-
" {\"symbol\": \"ALPHA\", \"base\": 6.00},\n",
|
| 103 |
-
" {\"symbol\": \"PEIR\", \"base\": 8.20},\n",
|
| 104 |
-
" {\"symbol\": \"EXAE\", \"base\": 6.90},\n",
|
| 105 |
-
" {\"symbol\": \"OPAP\", \"base\": 14.50},\n",
|
| 106 |
-
" {\"symbol\": \"MYTIL\", \"base\": 9.80},\n",
|
| 107 |
-
" {\"symbol\": \"ADMIE\", \"base\": 2.45},\n",
|
| 108 |
-
" {\"symbol\": \"ELPE\", \"base\": 7.60},\n",
|
| 109 |
-
" {\"symbol\": \"MOTOR\", \"base\": 22.30},\n",
|
| 110 |
-
" {\"symbol\": \"OTE\", \"base\": 15.10},\n",
|
| 111 |
-
" {\"symbol\": \"TPEIR\", \"base\": 1.75},\n",
|
| 112 |
-
"]\n",
|
| 113 |
-
"\n",
|
| 114 |
-
"STARTING_CAPITAL = 100_000.0\n",
|
| 115 |
-
"DAILY_OBLIGATION = 10\n",
|
| 116 |
-
"\n",
|
| 117 |
-
"\n",
|
| 118 |
-
"def gen_bbo(base_price: float) -> dict:\n",
|
| 119 |
-
" \"\"\"Generate a realistic bid/ask spread around a base price.\"\"\"\n",
|
| 120 |
-
" drift = random.uniform(-0.05, 0.05)\n",
|
| 121 |
-
" mid = round(base_price * (1 + drift), 2)\n",
|
| 122 |
-
" spread = round(random.choice([0.05, 0.10, 0.15]), 2)\n",
|
| 123 |
-
" best_bid = round(mid - spread / 2, 2)\n",
|
| 124 |
-
" best_ask = round(mid + spread / 2, 2)\n",
|
| 125 |
-
" return {\"best_bid\": best_bid, \"best_ask\": best_ask, \"mid\": mid}\n",
|
| 126 |
-
"\n",
|
| 127 |
-
"\n",
|
| 128 |
-
"def gen_holdings(bbos: dict) -> list:\n",
|
| 129 |
-
" \"\"\"Randomly generate some holdings for a member.\"\"\"\n",
|
| 130 |
-
" holdings = []\n",
|
| 131 |
-
" n = random.randint(0, 4) # 0–4 positions\n",
|
| 132 |
-
" for sym in random.sample(list(bbos.keys()), min(n, len(bbos))):\n",
|
| 133 |
-
" qty = random.randint(50, 500)\n",
|
| 134 |
-
" mid = bbos[sym][\"mid\"]\n",
|
| 135 |
-
" avg_cost = round(mid * random.uniform(0.92, 1.08), 2)\n",
|
| 136 |
-
" holdings.append({\"symbol\": sym, \"quantity\": qty, \"avg_cost\": avg_cost})\n",
|
| 137 |
-
" return holdings\n",
|
| 138 |
-
"\n",
|
| 139 |
-
"\n",
|
| 140 |
-
"def build_prompt(member_id: str, capital: float, holdings: list,\n",
|
| 141 |
-
" obligation_remaining: int, bbos: dict) -> str:\n",
|
| 142 |
-
" market_lines = [\n",
|
| 143 |
-
" f\" {sym}: Bid {bbo['best_bid']:.2f} / Ask {bbo['best_ask']:.2f}\"\n",
|
| 144 |
-
" for sym, bbo in sorted(bbos.items())\n",
|
| 145 |
-
" ]\n",
|
| 146 |
-
" holding_lines = (\n",
|
| 147 |
-
" [f\" {h['symbol']}: {h['quantity']} shares @ avg cost {h['avg_cost']:.2f}\"\n",
|
| 148 |
-
" for h in holdings]\n",
|
| 149 |
-
" if holdings else [\" None\"]\n",
|
| 150 |
-
" )\n",
|
| 151 |
-
" return (\n",
|
| 152 |
-
" f\"You are simulating clearing house member {member_id} making ONE trading decision.\\n\\n\"\n",
|
| 153 |
-
" f\"Member state:\\n\"\n",
|
| 154 |
-
" f\" Available capital: EUR {capital:,.2f}\\n\"\n",
|
| 155 |
-
" f\" Securities obligation remaining today: {obligation_remaining} more to trade\\n\"\n",
|
| 156 |
-
" f\" Current holdings:\\n\" + \"\\n\".join(holding_lines) + \"\\n\\n\"\n",
|
| 157 |
-
" f\"Current market (Bid/Ask):\\n\" + \"\\n\".join(market_lines) + \"\\n\\n\"\n",
|
| 158 |
-
" f\"Rules:\\n\"\n",
|
| 159 |
-
" f\"- Do not spend more than your available capital\\n\"\n",
|
| 160 |
-
" f\"- Do not sell more shares than you hold\\n\"\n",
|
| 161 |
-
" f\"- If you have no holdings, you must BUY\\n\"\n",
|
| 162 |
-
" f\"- Choose a realistic price close to the BBO mid-price\\n\"\n",
|
| 163 |
-
" f\"- Quantity should be between 10 and 200\\n\\n\"\n",
|
| 164 |
-
" f\"Respond ONLY with valid JSON, no other text:\\n\"\n",
|
| 165 |
-
" f'Example: {{\"symbol\": \"ALPHA\", \"side\": \"BUY\", \"quantity\": 50, \"price\": 5.95}}'\n",
|
| 166 |
-
" )\n",
|
| 167 |
-
"\n",
|
| 168 |
-
"\n",
|
| 169 |
-
"def gen_decision(capital: float, holdings: list, bbos: dict) -> dict:\n",
|
| 170 |
-
" \"\"\"Generate a rule-valid trading decision for the given state.\"\"\"\n",
|
| 171 |
-
" has_holdings = len(holdings) > 0\n",
|
| 172 |
-
"\n",
|
| 173 |
-
" # Decide side: BUY if no holdings or randomly; SELL if heavy positions\n",
|
| 174 |
-
" holdings_value = sum(\n",
|
| 175 |
-
" h[\"quantity\"] * bbos.get(h[\"symbol\"], {}).get(\"mid\", h[\"avg_cost\"])\n",
|
| 176 |
-
" for h in holdings\n",
|
| 177 |
-
" )\n",
|
| 178 |
-
" net_worth = capital + holdings_value\n",
|
| 179 |
-
" holdings_ratio = holdings_value / net_worth if net_worth > 0 else 0\n",
|
| 180 |
-
"\n",
|
| 181 |
-
" if not has_holdings:\n",
|
| 182 |
-
" side = \"BUY\"\n",
|
| 183 |
-
" elif holdings_ratio > 0.6:\n",
|
| 184 |
-
" side = random.choices([\"SELL\", \"BUY\"], weights=[0.7, 0.3])[0]\n",
|
| 185 |
-
" else:\n",
|
| 186 |
-
" side = random.choices([\"BUY\", \"SELL\"], weights=[0.55, 0.45])[0]\n",
|
| 187 |
-
"\n",
|
| 188 |
-
" if side == \"BUY\":\n",
|
| 189 |
-
" # Pick a random affordable symbol\n",
|
| 190 |
-
" affordable = [\n",
|
| 191 |
-
" sym for sym, bbo in bbos.items()\n",
|
| 192 |
-
" if 10 * bbo[\"best_ask\"] <= capital\n",
|
| 193 |
-
" ]\n",
|
| 194 |
-
" if not affordable:\n",
|
| 195 |
-
" # Fall back to cheapest\n",
|
| 196 |
-
" sym = min(bbos, key=lambda s: bbos[s][\"best_ask\"])\n",
|
| 197 |
-
" else:\n",
|
| 198 |
-
" # Weight toward securities we already hold (adding to position)\n",
|
| 199 |
-
" held_syms = [h[\"symbol\"] for h in holdings]\n",
|
| 200 |
-
" weights = [3 if s in held_syms else 1 for s in affordable]\n",
|
| 201 |
-
" sym = random.choices(affordable, weights=weights)[0]\n",
|
| 202 |
-
" ask = bbos[sym][\"best_ask\"]\n",
|
| 203 |
-
" max_qty = min(200, int(capital / ask))\n",
|
| 204 |
-
" qty = random.randint(10, max(10, max_qty))\n",
|
| 205 |
-
" price = round(bbos[sym][\"mid\"] + random.uniform(-0.05, 0.05), 2)\n",
|
| 206 |
-
" price = max(bbos[sym][\"best_bid\"], min(price, ask))\n",
|
| 207 |
-
" return {\"symbol\": sym, \"side\": \"BUY\", \"quantity\": qty, \"price\": round(price, 2)}\n",
|
| 208 |
-
" else:\n",
|
| 209 |
-
" # Sell from existing holdings\n",
|
| 210 |
-
" h = random.choice(holdings)\n",
|
| 211 |
-
" sym = h[\"symbol\"]\n",
|
| 212 |
-
" bbo = bbos[sym]\n",
|
| 213 |
-
" qty = random.randint(10, min(200, h[\"quantity\"]))\n",
|
| 214 |
-
" price = round(bbo[\"mid\"] + random.uniform(-0.05, 0.05), 2)\n",
|
| 215 |
-
" price = max(bbo[\"best_bid\"] - 0.05, min(price, bbo[\"best_ask\"]))\n",
|
| 216 |
-
" return {\"symbol\": sym, \"side\": \"SELL\", \"quantity\": qty, \"price\": round(price, 2)}\n",
|
| 217 |
-
"\n",
|
| 218 |
-
"\n",
|
| 219 |
-
"def generate_dataset(n: int) -> list:\n",
|
| 220 |
-
" examples = []\n",
|
| 221 |
-
" member_ids = [f\"USR{i:02d}\" for i in range(1, 11)]\n",
|
| 222 |
-
"\n",
|
| 223 |
-
" scenarios = [\n",
|
| 224 |
-
" # (capital_range, obligation_range, description)\n",
|
| 225 |
-
" ((80_000, 100_000), (5, 10), \"fresh_member\"), # new, must trade a lot\n",
|
| 226 |
-
" ((50_000, 80_000), (0, 5), \"active_member\"), # mid-session, nearly done\n",
|
| 227 |
-
" ((20_000, 50_000), (0, 2), \"low_capital\"), # low cash, mostly holdings\n",
|
| 228 |
-
" ((5_000, 20_000), (0, 10), \"very_low_capital\"), # near margin, careful\n",
|
| 229 |
-
" ((90_000, 100_000), (10, 10),\"start_of_day\"), # just started\n",
|
| 230 |
-
" ]\n",
|
| 231 |
-
"\n",
|
| 232 |
-
" for _ in range(n):\n",
|
| 233 |
-
" cap_range, obl_range, _ = random.choice(scenarios)\n",
|
| 234 |
-
" capital = round(random.uniform(*cap_range), 2)\n",
|
| 235 |
-
" obligation = random.randint(*obl_range)\n",
|
| 236 |
-
" member_id = random.choice(member_ids)\n",
|
| 237 |
-
"\n",
|
| 238 |
-
" # Generate market state\n",
|
| 239 |
-
" bbos = {s[\"symbol\"]: gen_bbo(s[\"base\"]) for s in SECURITIES}\n",
|
| 240 |
-
"\n",
|
| 241 |
-
" # Generate holdings consistent with remaining capital\n",
|
| 242 |
-
" holdings = gen_holdings(bbos)\n",
|
| 243 |
-
"\n",
|
| 244 |
-
" # Ensure capital consistency: if holdings are expensive, reduce capital\n",
|
| 245 |
-
" holdings_cost = sum(h[\"quantity\"] * h[\"avg_cost\"] for h in holdings)\n",
|
| 246 |
-
" if holdings_cost > STARTING_CAPITAL - capital:\n",
|
| 247 |
-
" # Scale down holdings to fit\n",
|
| 248 |
-
" scale = (STARTING_CAPITAL - capital) / max(holdings_cost, 1)\n",
|
| 249 |
-
" for h in holdings:\n",
|
| 250 |
-
" h[\"quantity\"] = max(10, int(h[\"quantity\"] * scale))\n",
|
| 251 |
-
"\n",
|
| 252 |
-
" prompt = build_prompt(member_id, capital, holdings, obligation, bbos)\n",
|
| 253 |
-
" decision = gen_decision(capital, holdings, bbos)\n",
|
| 254 |
-
"\n",
|
| 255 |
-
" examples.append({\n",
|
| 256 |
-
" \"prompt\": prompt,\n",
|
| 257 |
-
" \"completion\": json.dumps(decision),\n",
|
| 258 |
-
" })\n",
|
| 259 |
-
"\n",
|
| 260 |
-
" return examples\n",
|
| 261 |
-
"\n",
|
| 262 |
-
"\n",
|
| 263 |
-
"print(f\"Generating {DATASET_SIZE} training examples...\")\n",
|
| 264 |
-
"raw_data = generate_dataset(DATASET_SIZE)\n",
|
| 265 |
-
"print(f\"Done. Example:\")\n",
|
| 266 |
-
"print(\"PROMPT:\\n\", raw_data[0][\"prompt\"])\n",
|
| 267 |
-
"print(\"\\nCOMPLETION:\", raw_data[0][\"completion\"])"
|
| 268 |
-
]
|
| 269 |
},
|
| 270 |
{
|
| 271 |
"cell_type": "code",
|
|
@@ -478,83 +309,7 @@
|
|
| 478 |
"id": "inference-test",
|
| 479 |
"metadata": {},
|
| 480 |
"outputs": [],
|
| 481 |
-
"source": [
|
| 482 |
-
"import re\n",
|
| 483 |
-
"from transformers import pipeline\n",
|
| 484 |
-
"\n",
|
| 485 |
-
"pipe = pipeline(\n",
|
| 486 |
-
" \"text-generation\",\n",
|
| 487 |
-
" model=merged_model,\n",
|
| 488 |
-
" tokenizer=tokenizer,\n",
|
| 489 |
-
" device_map=\"auto\",\n",
|
| 490 |
-
")\n",
|
| 491 |
-
"\n",
|
| 492 |
-
"# Test scenarios\n",
|
| 493 |
-
"test_cases = [\n",
|
| 494 |
-
" {\n",
|
| 495 |
-
" \"desc\": \"New member, no holdings, must trade\",\n",
|
| 496 |
-
" \"capital\": 100_000.0,\n",
|
| 497 |
-
" \"holdings\": [],\n",
|
| 498 |
-
" \"obligation\": 10,\n",
|
| 499 |
-
" },\n",
|
| 500 |
-
" {\n",
|
| 501 |
-
" \"desc\": \"Experienced member with holdings, low obligation\",\n",
|
| 502 |
-
" \"capital\": 65_000.0,\n",
|
| 503 |
-
" \"holdings\": [\n",
|
| 504 |
-
" {\"symbol\": \"ALPHA\", \"quantity\": 300, \"avg_cost\": 5.90},\n",
|
| 505 |
-
" {\"symbol\": \"OPAP\", \"quantity\": 150, \"avg_cost\": 14.20},\n",
|
| 506 |
-
" ],\n",
|
| 507 |
-
" \"obligation\": 2,\n",
|
| 508 |
-
" },\n",
|
| 509 |
-
" {\n",
|
| 510 |
-
" \"desc\": \"Low capital, large holdings\",\n",
|
| 511 |
-
" \"capital\": 8_000.0,\n",
|
| 512 |
-
" \"holdings\": [\n",
|
| 513 |
-
" {\"symbol\": \"PEIR\", \"quantity\": 500, \"avg_cost\": 8.10},\n",
|
| 514 |
-
" {\"symbol\": \"MYTIL\", \"quantity\": 200, \"avg_cost\": 9.50},\n",
|
| 515 |
-
" ],\n",
|
| 516 |
-
" \"obligation\": 5,\n",
|
| 517 |
-
" },\n",
|
| 518 |
-
"]\n",
|
| 519 |
-
"\n",
|
| 520 |
-
"test_bbos = {s[\"symbol\"]: gen_bbo(s[\"base\"]) for s in SECURITIES}\n",
|
| 521 |
-
"\n",
|
| 522 |
-
"print(\"=\" * 70)\n",
|
| 523 |
-
"for tc in test_cases:\n",
|
| 524 |
-
" print(f\"\\nSCENARIO: {tc['desc']}\")\n",
|
| 525 |
-
" prompt = build_prompt(\n",
|
| 526 |
-
" \"USR01\", tc[\"capital\"], tc[\"holdings\"], tc[\"obligation\"], test_bbos\n",
|
| 527 |
-
" )\n",
|
| 528 |
-
" messages = [\n",
|
| 529 |
-
" {\"role\": \"system\",\"content\": SYSTEM_PROMPT},\n",
|
| 530 |
-
" {\"role\": \"user\", \"content\": prompt},\n",
|
| 531 |
-
" ]\n",
|
| 532 |
-
" output = pipe(\n",
|
| 533 |
-
" messages,\n",
|
| 534 |
-
" max_new_tokens=60,\n",
|
| 535 |
-
" temperature=0.3,\n",
|
| 536 |
-
" do_sample=True,\n",
|
| 537 |
-
" pad_token_id=tokenizer.eos_token_id,\n",
|
| 538 |
-
" )\n",
|
| 539 |
-
" response = output[0][\"generated_text\"][-1][\"content\"].strip()\n",
|
| 540 |
-
" print(f\"RESPONSE: {response}\")\n",
|
| 541 |
-
"\n",
|
| 542 |
-
" # Validate JSON\n",
|
| 543 |
-
" try:\n",
|
| 544 |
-
" m = re.search(r\"\\{[^}]+\\}\", response)\n",
|
| 545 |
-
" if m:\n",
|
| 546 |
-
" d = json.loads(m.group())\n",
|
| 547 |
-
" assert d[\"side\"] in (\"BUY\", \"SELL\")\n",
|
| 548 |
-
" assert d[\"symbol\"] in [s[\"symbol\"] for s in SECURITIES]\n",
|
| 549 |
-
" assert d[\"quantity\"] > 0\n",
|
| 550 |
-
" assert d[\"price\"] > 0\n",
|
| 551 |
-
" print(f\"✓ Valid JSON: {d}\")\n",
|
| 552 |
-
" else:\n",
|
| 553 |
-
" print(\"✗ No JSON found in response\")\n",
|
| 554 |
-
" except Exception as e:\n",
|
| 555 |
-
" print(f\"✗ Invalid: {e}\")\n",
|
| 556 |
-
" print(\"-\" * 70)"
|
| 557 |
-
]
|
| 558 |
},
|
| 559 |
{
|
| 560 |
"cell_type": "markdown",
|
|
|
|
| 96 |
"id": "dataset-gen",
|
| 97 |
"metadata": {},
|
| 98 |
"outputs": [],
|
| 99 |
+
"source": "# Securities from shared_data/securities.txt (symbol, start_price, current_price)\nSECURITIES = [\n {\"symbol\": \"ALPHA\", \"base\": 5.65},\n {\"symbol\": \"PEIR\", \"base\": 8.35},\n {\"symbol\": \"EXAE\", \"base\": 6.90},\n {\"symbol\": \"QUEST\", \"base\": 13.35},\n {\"symbol\": \"NBG\", \"base\": 8.00},\n {\"symbol\": \"EUROB\", \"base\": 3.45},\n {\"symbol\": \"AEG\", \"base\": 4.75},\n {\"symbol\": \"INTKA\", \"base\": 7.35},\n {\"symbol\": \"AAAK\", \"base\": 2.75},\n {\"symbol\": \"ATTIK\", \"base\": 4.90},\n]\n\nSTARTING_CAPITAL = 100_000.0\nDAILY_OBLIGATION = 10\n\n\ndef gen_bbo(base_price: float) -> dict:\n \"\"\"Generate a realistic bid/ask spread around a base price.\"\"\"\n drift = random.uniform(-0.05, 0.05)\n mid = round(base_price * (1 + drift), 2)\n spread = round(random.choice([0.05, 0.10, 0.15]), 2)\n best_bid = round(mid - spread / 2, 2)\n best_ask = round(mid + spread / 2, 2)\n return {\"best_bid\": best_bid, \"best_ask\": best_ask, \"mid\": mid}\n\n\ndef gen_holdings(bbos: dict) -> list:\n \"\"\"Randomly generate some holdings for a member.\"\"\"\n holdings = []\n n = random.randint(0, 4)\n for sym in random.sample(list(bbos.keys()), min(n, len(bbos))):\n qty = random.randint(50, 500)\n mid = bbos[sym][\"mid\"]\n avg_cost = round(mid * random.uniform(0.92, 1.08), 2)\n holdings.append({\"symbol\": sym, \"quantity\": qty, \"avg_cost\": avg_cost})\n return holdings\n\n\ndef build_prompt(member_id: str, capital: float, holdings: list,\n obligation_remaining: int, bbos: dict) -> str:\n market_lines = [\n f\" {sym}: Bid {bbo['best_bid']:.2f} / Ask {bbo['best_ask']:.2f}\"\n for sym, bbo in sorted(bbos.items())\n ]\n holding_lines = (\n [f\" {h['symbol']}: {h['quantity']} shares @ avg cost {h['avg_cost']:.2f}\"\n for h in holdings]\n if holdings else [\" None\"]\n )\n return (\n f\"You are simulating clearing house member {member_id} making ONE trading decision.\\n\\n\"\n f\"Member state:\\n\"\n f\" Available capital: EUR {capital:,.2f}\\n\"\n f\" Securities obligation remaining today: {obligation_remaining} more to trade\\n\"\n f\" Current holdings:\\n\" + \"\\n\".join(holding_lines) + \"\\n\\n\"\n f\"Current market (Bid/Ask):\\n\" + \"\\n\".join(market_lines) + \"\\n\\n\"\n f\"Rules:\\n\"\n f\"- Do not spend more than your available capital\\n\"\n f\"- Do not sell more shares than you hold\\n\"\n f\"- If you have no holdings, you must BUY\\n\"\n f\"- Choose a realistic price close to the BBO mid-price\\n\"\n f\"- Quantity should be between 10 and 200\\n\\n\"\n f\"Respond ONLY with valid JSON, no other text:\\n\"\n f'Example: {{\"symbol\": \"ALPHA\", \"side\": \"BUY\", \"quantity\": 50, \"price\": 5.65}}'\n )\n\n\ndef gen_decision(capital: float, holdings: list, bbos: dict) -> dict:\n \"\"\"Generate a rule-valid trading decision for the given state.\"\"\"\n holdings_value = sum(\n h[\"quantity\"] * bbos.get(h[\"symbol\"], {}).get(\"mid\", h[\"avg_cost\"])\n for h in holdings\n )\n net_worth = capital + holdings_value\n holdings_ratio = holdings_value / net_worth if net_worth > 0 else 0\n\n if not holdings:\n side = \"BUY\"\n elif holdings_ratio > 0.6:\n side = random.choices([\"SELL\", \"BUY\"], weights=[0.7, 0.3])[0]\n else:\n side = random.choices([\"BUY\", \"SELL\"], weights=[0.55, 0.45])[0]\n\n if side == \"BUY\":\n affordable = [sym for sym, bbo in bbos.items() if 10 * bbo[\"best_ask\"] <= capital]\n if not affordable:\n sym = min(bbos, key=lambda s: bbos[s][\"best_ask\"])\n else:\n held_syms = [h[\"symbol\"] for h in holdings]\n weights = [3 if s in held_syms else 1 for s in affordable]\n sym = random.choices(affordable, weights=weights)[0]\n ask = bbos[sym][\"best_ask\"]\n max_qty = min(200, int(capital / ask))\n qty = random.randint(10, max(10, max_qty))\n price = round(bbos[sym][\"mid\"] + random.uniform(-0.05, 0.05), 2)\n price = max(bbos[sym][\"best_bid\"], min(price, ask))\n return {\"symbol\": sym, \"side\": \"BUY\", \"quantity\": qty, \"price\": round(price, 2)}\n else:\n h = random.choice(holdings)\n sym = h[\"symbol\"]\n bbo = bbos[sym]\n qty = random.randint(10, min(200, h[\"quantity\"]))\n price = round(bbo[\"mid\"] + random.uniform(-0.05, 0.05), 2)\n price = max(bbo[\"best_bid\"] - 0.05, min(price, bbo[\"best_ask\"]))\n return {\"symbol\": sym, \"side\": \"SELL\", \"quantity\": qty, \"price\": round(price, 2)}\n\n\ndef generate_dataset(n: int) -> list:\n examples = []\n member_ids = [f\"USR{i:02d}\" for i in range(1, 11)]\n scenarios = [\n ((80_000, 100_000), (5, 10), \"fresh_member\"),\n ((50_000, 80_000), (0, 5), \"active_member\"),\n ((20_000, 50_000), (0, 2), \"low_capital\"),\n ((5_000, 20_000), (0, 10), \"very_low_capital\"),\n ((90_000, 100_000), (10, 10),\"start_of_day\"),\n ]\n for _ in range(n):\n cap_range, obl_range, _ = random.choice(scenarios)\n capital = round(random.uniform(*cap_range), 2)\n obligation = random.randint(*obl_range)\n member_id = random.choice(member_ids)\n bbos = {s[\"symbol\"]: gen_bbo(s[\"base\"]) for s in SECURITIES}\n holdings = gen_holdings(bbos)\n holdings_cost = sum(h[\"quantity\"] * h[\"avg_cost\"] for h in holdings)\n if holdings_cost > STARTING_CAPITAL - capital:\n scale = (STARTING_CAPITAL - capital) / max(holdings_cost, 1)\n for h in holdings:\n h[\"quantity\"] = max(10, int(h[\"quantity\"] * scale))\n prompt = build_prompt(member_id, capital, holdings, obligation, bbos)\n decision = gen_decision(capital, holdings, bbos)\n examples.append({\"prompt\": prompt, \"completion\": json.dumps(decision)})\n return examples\n\n\nprint(f\"Generating {DATASET_SIZE} training examples...\")\nraw_data = generate_dataset(DATASET_SIZE)\nprint(f\"Done. Example:\")\nprint(\"PROMPT:\\n\", raw_data[0][\"prompt\"])\nprint(\"\\nCOMPLETION:\", raw_data[0][\"completion\"])"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
},
|
| 101 |
{
|
| 102 |
"cell_type": "code",
|
|
|
|
| 309 |
"id": "inference-test",
|
| 310 |
"metadata": {},
|
| 311 |
"outputs": [],
|
| 312 |
+
"source": "import re\nfrom transformers import pipeline\n\npipe = pipeline(\n \"text-generation\",\n model=merged_model,\n tokenizer=tokenizer,\n device_map=\"auto\",\n)\n\ntest_cases = [\n {\n \"desc\": \"New member, no holdings, must trade\",\n \"capital\": 100_000.0,\n \"holdings\": [],\n \"obligation\": 10,\n },\n {\n \"desc\": \"Experienced member with holdings, low obligation\",\n \"capital\": 65_000.0,\n \"holdings\": [\n {\"symbol\": \"ALPHA\", \"quantity\": 300, \"avg_cost\": 5.60},\n {\"symbol\": \"QUEST\", \"quantity\": 150, \"avg_cost\": 13.20},\n ],\n \"obligation\": 2,\n },\n {\n \"desc\": \"Low capital, large holdings\",\n \"capital\": 8_000.0,\n \"holdings\": [\n {\"symbol\": \"PEIR\", \"quantity\": 500, \"avg_cost\": 8.30},\n {\"symbol\": \"NBG\", \"quantity\": 200, \"avg_cost\": 7.95},\n ],\n \"obligation\": 5,\n },\n]\n\ntest_bbos = {s[\"symbol\"]: gen_bbo(s[\"base\"]) for s in SECURITIES}\n\nprint(\"=\" * 70)\nfor tc in test_cases:\n print(f\"\\nSCENARIO: {tc['desc']}\")\n prompt = build_prompt(\"USR01\", tc[\"capital\"], tc[\"holdings\"], tc[\"obligation\"], test_bbos)\n messages = [\n {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n {\"role\": \"user\", \"content\": prompt},\n ]\n output = pipe(\n messages,\n max_new_tokens=60,\n temperature=0.3,\n do_sample=True,\n pad_token_id=tokenizer.eos_token_id,\n )\n response = output[0][\"generated_text\"][-1][\"content\"].strip()\n print(f\"RESPONSE: {response}\")\n try:\n m = re.search(r\"\\{[^}]+\\}\", response)\n if m:\n d = json.loads(m.group())\n assert d[\"side\"] in (\"BUY\", \"SELL\")\n assert d[\"symbol\"] in [s[\"symbol\"] for s in SECURITIES]\n assert d[\"quantity\"] > 0\n assert d[\"price\"] > 0\n print(f\"✓ Valid JSON: {d}\")\n else:\n print(\"✗ No JSON found in response\")\n except Exception as e:\n print(f\"✗ Invalid: {e}\")\n print(\"-\" * 70)"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
},
|
| 314 |
{
|
| 315 |
"cell_type": "markdown",
|