"""
GEX (Gamma Exposure) Calculator Module
======================================

Computes gamma exposure profiles from options chain data.
Supports both raw CBOE format and normalized DataFrame inputs.

Functions:
- compute_gex: Basic GEX from calls/puts DataFrames
- compute_gex_from_records: Total net GEX from per-strike records
- compute_gex_plus: GEX+ with LOB disagreement correction
- compute_gex_by_strike: Per-strike GEX breakdown
- compute_net_gex: Net gamma exposure (calls - puts)
- find_zero_gamma: Find spot level where GEX+ crosses zero
- compute_crash_profile: GEX+ vs spot move percentage
- compute_gamma_flip: Identify gamma flip levels above and below spot
- compute_vanna_exposure: Vanna exposure (VEX)
- compute_charm_exposure: Charm (delta decay) exposure
- gex_summary: Key metrics from all of the above in one dict

Per-strike records are dicts. Canonical keys are ``strike``, ``oi_call``,
``oi_put``, ``gamma_call``, ``gamma_put``, ``iv_call``, ``iv_put`` plus the
optional ``dte``, ``vanna_call`` and ``vanna_put``. The alternative spellings
in ``_FIELD_ALIASES`` are accepted by every record-based function.
"""

import pandas as pd
import numpy as np
from typing import Optional, Tuple, List, Dict, Any


# Alternative key spellings accepted for each canonical record field. The
# first key present in a record wins, even if its value is None.
_FIELD_ALIASES: Dict[str, Tuple[str, ...]] = {
    'gamma_call': ('gamma_call', 'gamma_c'),
    'gamma_put': ('gamma_put', 'gamma_p'),
    'oi_call': ('oi_call', 'open_interest_call'),
    'oi_put': ('oi_put', 'open_interest_put'),
    'iv_call': ('iv_call', 'iv_c'),
    'iv_put': ('iv_put', 'iv_p'),
}

# Days to expiry assumed when a record carries no ``dte`` field.
_DEFAULT_DTE = 30.0

_BY_STRIKE_COLUMNS = ['strike', 'gex_call', 'gex_put', 'gex_net',
                      'oi_call', 'oi_put', 'iv_call', 'iv_put']


def _safe_float(v, default=0.0):
    """Safely convert to a finite float, returning default on failure."""
    try:
        if v is None:
            return default
        f = float(v)
        return f if np.isfinite(f) else default
    except (TypeError, ValueError):
        return default


def _safe_int(v, default=0):
    """Safely convert to int, returning default on failure.

    NaN raises ValueError and +/-inf raises OverflowError inside ``int()``;
    both are mapped to *default*.
    """
    try:
        if v is None:
            return default
        return int(float(v))
    except (TypeError, ValueError, OverflowError):
        return default


def _field(record: Dict[str, Any], name: str, default: Any = None) -> Any:
    """Look up canonical field *name* in *record*, honouring its aliases.

    Returns the value stored under the first alias key present in *record*
    (even if that value is None), or *default* when none is present.
    """
    for key in _FIELD_ALIASES.get(name, (name,)):
        if key in record:
            return record[key]
    return default


def _gex_dollars(gamma: float, oi: float, spot: float) -> float:
    """Per-strike exposure formula shared by every function in this module.

    Evaluates ``gamma * oi * spot**2 / 100``. NOTE(review): presumably dollar
    gamma for a 1% spot move without a contract multiplier -- confirm units.
    """
    return gamma * oi * spot * spot / 100.0


def _bs_d1_d2(spot: float, strike: float, iv: float,
              t_years: float) -> Tuple[float, float]:
    """Black-Scholes d1 and d2 with zero rates/dividends (all args must be > 0)."""
    vol_t = iv * np.sqrt(t_years)
    d1 = (np.log(spot / strike) + 0.5 * vol_t * vol_t) / vol_t
    return d1, d1 - vol_t


def _record_years(record: Dict[str, Any]) -> float:
    """Time to expiry in years from the record's ``dte`` (default 30 days)."""
    return _safe_float(record.get('dte', _DEFAULT_DTE)) / 365.0


def _grid_count(span: float, step: float) -> int:
    """Number of whole *step*s that fit in *span*, tolerant of float error."""
    return int(np.floor(span / step + 1e-9))


def compute_gex(
    calls_df: Optional[pd.DataFrame],
    puts_df: Optional[pd.DataFrame],
    spot: float
) -> Tuple[Optional[pd.DataFrame], float]:
    """
    Compute GEX from separate calls and puts DataFrames.

    Parameters
    ----------
    calls_df : DataFrame with columns: strike, gamma, openInterest
    puts_df : DataFrame with columns: strike, gamma, openInterest
    spot : Current spot price

    Returns
    -------
    (gex_by_strike_df, total_gex)
    - DataFrame with columns: strike, gex_call, gex_put, gex_net
      (sorted by strike), or None when nothing contributes
    - total_gex: signed total of gex_net

    Rows with non-positive gamma or open interest are ignored.
    """
    if calls_df is None or puts_df is None or spot is None or spot <= 0:
        return None, 0.0

    spot = float(spot)
    rows: Dict[float, Dict[str, float]] = {}

    # Calls add positive gamma exposure; puts subtract it.
    legs = ((calls_df, 'gex_call', 1.0), (puts_df, 'gex_put', -1.0))
    for frame, column, sign in legs:
        for _, r in frame.iterrows():
            gamma = _safe_float(r.get('gamma'))
            oi = _safe_int(r.get('openInterest'))
            if gamma > 0 and oi > 0:
                strike = _safe_float(r.get('strike'))
                bucket = rows.setdefault(strike, {'gex_call': 0.0, 'gex_put': 0.0})
                bucket[column] += sign * _gex_dollars(gamma, oi, spot)

    if not rows:
        return None, 0.0

    df = pd.DataFrame([
        {
            'strike': strike,
            'gex_call': vals['gex_call'],
            'gex_put': vals['gex_put'],
            'gex_net': vals['gex_call'] + vals['gex_put'],
        }
        for strike, vals in rows.items()
    ]).sort_values('strike').reset_index(drop=True)
    return df, float(df['gex_net'].sum())


def compute_gex_from_records(records: List[Dict[str, Any]], spot: float) -> float:
    """
    Compute total net GEX (calls minus puts) from per-strike records (CBOE format).

    Each record should have keys: strike, oi_call, oi_put, gamma_call, gamma_put
    (aliases such as gamma_c / open_interest_call are also accepted).
    Returns 0.0 for empty records or a missing/non-positive spot.
    """
    if not records or spot is None or spot <= 0:
        return 0.0

    spot = float(spot)
    total = 0.0
    for s in records:
        gamma_c = _safe_float(_field(s, 'gamma_call'))
        gamma_p = _safe_float(_field(s, 'gamma_put'))
        oi_c = _safe_int(_field(s, 'oi_call'))
        oi_p = _safe_int(_field(s, 'oi_put'))
        total += _gex_dollars(gamma_c, oi_c, spot)
        total -= _gex_dollars(gamma_p, oi_p, spot)
    return total


def compute_gex_plus(records: List[Dict[str, Any]], spot: float) -> float:
    """
    GEX+ with LOB (Limit Order Book) disagreement correction.

    Per strike, exposure is (oi_call - oi_put) times the average of call and
    put gamma. A strike "disagrees" when put IV exceeds call IV by more than
    2%. The summed exposure is multiplied by ``max(0, 1 - 2 * rate)``, where
    rate is the fraction of disagreeing strikes, so it reaches zero once half
    of the strikes disagree. Missing IVs default to 0.001 (never disagree).

    Returns 0.0 for empty records or a missing/non-positive spot.
    """
    if not records or spot is None or spot <= 0:
        return 0.0

    spot = float(spot)
    total = 0.0
    disagreements = 0
    for s in records:
        oi_c = _safe_int(_field(s, 'oi_call'))
        oi_p = _safe_int(_field(s, 'oi_put'))
        gamma_c = _safe_float(_field(s, 'gamma_call'))
        gamma_p = _safe_float(_field(s, 'gamma_put'))
        iv_c = _safe_float(_field(s, 'iv_call', 0.001))
        iv_p = _safe_float(_field(s, 'iv_put', 0.001))

        net_oi = oi_c - oi_p
        gamma = (gamma_c + gamma_p) / 2.0
        total += _gex_dollars(gamma, net_oi, spot)

        # Detect LOB disagreement (put IV significantly > call IV)
        if iv_p > iv_c * 1.02 and iv_c > 0:
            disagreements += 1

    rate = disagreements / len(records)
    # Correction factor: reduce GEX proportionally to disagreement rate
    correction = max(0.0, 1.0 - 2.0 * rate)
    return total * correction


def compute_gex_by_strike(records: List[Dict[str, Any]], spot: float) -> pd.DataFrame:
    """
    Compute GEX breakdown by strike from records.

    Returns DataFrame (sorted by strike) with columns:
    strike, gex_call, gex_put, gex_net, oi_call, oi_put, iv_call, iv_put.
    IVs are reported in percent (record IV * 100). GEX columns are 0.0 when
    spot is missing or non-positive.
    """
    if not records:
        return pd.DataFrame(columns=list(_BY_STRIKE_COLUMNS))

    spot = float(spot) if spot else 0.0
    data = []
    for s in records:
        oi_c = _safe_int(_field(s, 'oi_call'))
        oi_p = _safe_int(_field(s, 'oi_put'))
        if spot > 0:
            gex_c = _gex_dollars(_safe_float(_field(s, 'gamma_call')), oi_c, spot)
            gex_p = -_gex_dollars(_safe_float(_field(s, 'gamma_put')), oi_p, spot)
        else:
            gex_c = 0.0
            gex_p = 0.0
        data.append({
            'strike': _safe_float(s.get('strike')),
            'gex_call': gex_c,
            'gex_put': gex_p,
            'gex_net': gex_c + gex_p,
            'oi_call': oi_c,
            'oi_put': oi_p,
            'iv_call': _safe_float(_field(s, 'iv_call')) * 100,  # to percent
            'iv_put': _safe_float(_field(s, 'iv_put')) * 100,
        })
    return pd.DataFrame(data).sort_values('strike').reset_index(drop=True)


def compute_net_gex(records: List[Dict[str, Any]], spot: float) -> Dict[str, float]:
    """
    Compute comprehensive net GEX metrics.

    Returns dict with:
    - net_gex: Total net gamma exposure
    - call_gex: Total call gamma exposure
    - put_gex: Total put gamma exposure (<= 0)
    - max_positive_strike: Strike with highest positive net GEX (0.0 if none)
    - max_negative_strike: Strike with most negative net GEX (0.0 if none)
    - gamma_wall_strike: Strike with highest absolute net GEX (gamma wall)

    All values are 0.0 for empty records or a missing/non-positive spot.
    """
    if not records or spot is None or spot <= 0:
        return {
            'net_gex': 0.0, 'call_gex': 0.0, 'put_gex': 0.0,
            'max_positive_strike': 0.0, 'max_negative_strike': 0.0,
            'gamma_wall_strike': 0.0,
        }

    spot = float(spot)
    total_call = 0.0
    total_put = 0.0
    max_pos_gex = 0.0
    max_pos_strike = 0.0
    max_neg_gex = 0.0
    max_neg_strike = 0.0
    max_abs_gex = 0.0
    gamma_wall = 0.0

    for s in records:
        strike = _safe_float(s.get('strike'))
        oi_c = _safe_int(_field(s, 'oi_call'))
        oi_p = _safe_int(_field(s, 'oi_put'))
        gex_c = _gex_dollars(_safe_float(_field(s, 'gamma_call')), oi_c, spot)
        gex_p = -_gex_dollars(_safe_float(_field(s, 'gamma_put')), oi_p, spot)
        net = gex_c + gex_p

        total_call += gex_c
        total_put += gex_p

        if net > max_pos_gex:
            max_pos_gex = net
            max_pos_strike = strike
        if net < max_neg_gex:
            max_neg_gex = net
            max_neg_strike = strike
        if abs(net) > max_abs_gex:
            max_abs_gex = abs(net)
            gamma_wall = strike

    return {
        'net_gex': total_call + total_put,
        'call_gex': total_call,
        'put_gex': total_put,
        'max_positive_strike': max_pos_strike,
        'max_negative_strike': max_neg_strike,
        'gamma_wall_strike': gamma_wall,
    }


def _gex_plus_at(records: List[Dict[str, Any]], spot: float, test_spot: float) -> float:
    """GEX+ with record gammas re-expressed at the hypothetical *test_spot*."""
    return compute_gex_plus(_scale_records(records, spot, test_spot), test_spot)


def _first_flip(records: List[Dict[str, Any]], spot: float, ref_sign: float,
                offsets_pct) -> Optional[float]:
    """Return the first test spot (spot moved by each offset, in %) whose GEX+
    sign differs from *ref_sign*, or None if no offset flips it.
    Non-positive test spots are skipped."""
    for pct in offsets_pct:
        test = spot * (1 + pct / 100.0)
        if test > 0 and np.sign(_gex_plus_at(records, spot, test)) != ref_sign:
            return test
    return None


def find_zero_gamma(
    records: List[Dict[str, Any]],
    spot: float,
    max_range_pct: float = 25.0,
    step: float = 0.1
) -> float:
    """
    Find the spot level where GEX+ crosses zero.

    Gammas are re-expressed at each test spot with the sticky-strike
    Black-Scholes repricing done by ``_scale_records``. Offsets of
    +/- step, 2*step, ... up to max_range_pct (percent of spot) are tried
    nearest-first, alternating up then down, so the crossing closest to spot
    wins. Returns the first grid level whose GEX+ sign differs from the sign
    at spot.

    Returns spot unchanged for empty records, a missing/non-positive spot, or
    zero GEX+ at spot; returns ``spot * 1.05`` when no crossing is found.

    Raises
    ------
    ValueError : if step is not positive.
    """
    if not records or spot is None or spot <= 0:
        return spot

    gex_at_spot = compute_gex_plus(records, spot)
    if gex_at_spot == 0:
        return spot
    if step <= 0:
        raise ValueError("step must be positive")

    sign_at = np.sign(gex_at_spot)
    n_steps = _grid_count(max_range_pct, step)
    # Alternate +/- offsets so the crossing nearest to spot is found first.
    offsets = (pct for i in range(1, n_steps + 1) for pct in (i * step, -i * step))
    found = _first_flip(records, spot, sign_at, offsets)
    return spot * 1.05 if found is None else found  # Fallback: 5% above spot


def compute_crash_profile(
    records: List[Dict[str, Any]],
    spot: float,
    range_pct: Tuple[float, float] = (-15.0, 10.0),
    step: float = 0.25
) -> List[Dict[str, Any]]:
    """
    Compute GEX+ across a range of spot moves.

    Evaluates range_pct[0], range_pct[0] + step, ... up to and including
    range_pct[1] (never beyond it). Returns list of dicts with keys:
    spot_pct (rounded to 2 dp), spx (test spot rounded to 1 dp), gex_plus.
    Returns [] for empty records or a missing/non-positive spot.

    Raises
    ------
    ValueError : if step is not positive.
    """
    if not records or spot is None or spot <= 0:
        return []
    if step <= 0:
        raise ValueError("step must be positive")

    start, end = range_pct
    profile = []
    for i in range(_grid_count(end - start, step) + 1):
        pct = start + i * step
        test = spot * (1 + pct / 100.0)
        profile.append({
            'spot_pct': round(float(pct), 2),
            'spx': round(test, 1),
            'gex_plus': _gex_plus_at(records, spot, test),
        })
    return profile


def compute_gamma_flip(records: List[Dict[str, Any]], spot: float) -> Dict[str, float]:
    """
    Find gamma flip levels (both above and below current spot).

    Scans 0.1% .. 24.9% above and below spot in 0.1% steps.

    Returns dict with:
    - up_flip: First level above spot where GEX+ changes sign (spot if none)
    - down_flip: First level below spot where GEX+ changes sign (spot if none)
    - current_gex: GEX+ at spot

    All values are 0.0 for empty records or a missing/non-positive spot.
    """
    if not records or spot is None or spot <= 0:
        return {'up_flip': 0.0, 'down_flip': 0.0, 'current_gex': 0.0}

    current_gex = compute_gex_plus(records, spot)
    up_flip = spot
    down_flip = spot

    # A zero reference has no sign to flip from, so skip the scans.
    if current_gex != 0:
        ref_sign = np.sign(current_gex)
        offsets = [i / 10.0 for i in range(1, 250)]  # 0.1% .. 24.9%
        up = _first_flip(records, spot, ref_sign, offsets)
        down = _first_flip(records, spot, ref_sign, [-pct for pct in offsets])
        if up is not None:
            up_flip = up
        if down is not None:
            down_flip = down

    return {
        'up_flip': up_flip,
        'down_flip': down_flip,
        'current_gex': current_gex,
    }


def _gamma_spot_ratio(strike: float, iv: float, t_years: float,
                      old_spot: float, new_spot: float) -> float:
    """Black-Scholes gamma(new_spot) / gamma(old_spot) at fixed IV (r = q = 0).

    Falls back to old_spot / new_spot when strike, IV or time is not positive.
    """
    inverse = old_spot / new_spot
    if strike <= 0 or iv <= 0 or t_years <= 0:
        return inverse
    d1_old, _ = _bs_d1_d2(old_spot, strike, iv, t_years)
    d1_new, _ = _bs_d1_d2(new_spot, strike, iv, t_years)
    # gamma = phi(d1) / (S * vol_t); vol_t cancels in the ratio.
    return float(inverse * np.exp(-0.5 * (d1_new * d1_new - d1_old * d1_old)))


def _scale_records(
    records: List[Dict[str, Any]],
    old_spot: float,
    new_spot: float
) -> List[Dict[str, Any]]:
    """
    Re-express record gammas at a hypothetical spot level (sticky-strike).

    Each strike keeps its implied vol, and each supplied gamma is multiplied
    by the Black-Scholes ratio gamma(new_spot) / gamma(old_spot). Gammas are
    therefore unchanged when new_spot == old_spot, and exposure can migrate
    between strikes as spot moves. Legs without a positive strike, IV or time
    to expiry (``dte``, default 30 days) fall back to old_spot / new_spot.

    NOTE(review): the ratio assumes supplied gammas are roughly Black-Scholes
    consistent; noisy gammas far from the money can be amplified.

    Returns new dicts with canonical gamma_call / gamma_put keys overwritten.
    The input list is returned as-is when either spot is missing or <= 0.
    """
    if old_spot is None or new_spot is None or old_spot <= 0 or new_spot <= 0:
        return records

    scaled = []
    for s in records:
        strike = _safe_float(s.get('strike'))
        t_years = _record_years(s)
        ratio_c = _gamma_spot_ratio(strike, _safe_float(_field(s, 'iv_call')),
                                    t_years, old_spot, new_spot)
        ratio_p = _gamma_spot_ratio(strike, _safe_float(_field(s, 'iv_put')),
                                    t_years, old_spot, new_spot)
        scaled.append({
            **s,
            'gamma_call': _safe_float(_field(s, 'gamma_call')) * ratio_c,
            'gamma_put': _safe_float(_field(s, 'gamma_put')) * ratio_p,
        })
    return scaled


def _vanna_from_gamma(gamma: float, iv: float, strike: float,
                      spot: float, t_years: float) -> float:
    """Black-Scholes vanna from gamma: -gamma * S * sqrt(T) * d2 (r = q = 0).

    Returns 0.0 when strike, IV or time is not positive.
    """
    if strike <= 0 or iv <= 0 or t_years <= 0:
        return 0.0
    _, d2 = _bs_d1_d2(spot, strike, iv, t_years)
    return float(-gamma * spot * np.sqrt(t_years) * d2)


def _charm_from_gamma(gamma: float, iv: float, strike: float,
                      spot: float, t_years: float) -> float:
    """Black-Scholes dDelta/dT from gamma: -gamma * S * iv * d2 / (2 sqrt(T)).

    T is time to expiry in years (r = q = 0). Returns 0.0 when strike, IV or
    time is not positive.
    """
    if strike <= 0 or iv <= 0 or t_years <= 0:
        return 0.0
    _, d2 = _bs_d1_d2(spot, strike, iv, t_years)
    return float(-gamma * spot * iv * d2 / (2.0 * np.sqrt(t_years)))


def compute_vanna_exposure(records: List[Dict[str, Any]], spot: float) -> float:
    """
    Compute VEX (Vanna Exposure) from records.

    Vanna = d(Delta)/d(IV) = d(Vega)/d(Spot)

    Uses explicit vanna_call / vanna_put when either is non-zero. Otherwise it
    estimates both legs from gamma with the Black-Scholes identity
    vanna = -gamma * S * sqrt(T) * d2 (r = q = 0). For the estimate, IV
    defaults to 0.20 and dte to 30 days when absent.
    Returns sum(vanna * oi * spot) over both legs; 0.0 for empty records or a
    missing/non-positive spot.
    """
    if not records or spot is None or spot <= 0:
        return 0.0

    spot = float(spot)
    total = 0.0
    for s in records:
        vanna_c = _safe_float(s.get('vanna_call'))
        vanna_p = _safe_float(s.get('vanna_put'))

        if vanna_c == 0 and vanna_p == 0:
            strike = _safe_float(s.get('strike'))
            t_years = _record_years(s)
            vanna_c = _vanna_from_gamma(_safe_float(_field(s, 'gamma_call')),
                                        _safe_float(_field(s, 'iv_call', 0.20)),
                                        strike, spot, t_years)
            vanna_p = _vanna_from_gamma(_safe_float(_field(s, 'gamma_put')),
                                        _safe_float(_field(s, 'iv_put', 0.20)),
                                        strike, spot, t_years)

        oi_c = _safe_int(_field(s, 'oi_call'))
        oi_p = _safe_int(_field(s, 'oi_put'))
        total += vanna_c * oi_c * spot + vanna_p * oi_p * spot

    return float(total)


def compute_charm_exposure(records: List[Dict[str, Any]], spot: float) -> float:
    """
    Compute Charm (Delta decay) exposure.

    Charm is taken here as d(Delta)/d(T), with T the time to expiry, and is
    estimated from gamma via -gamma * S * iv * d2 / (2 * sqrt(T)) (r = q = 0).
    IV defaults to 0.20 and dte to 30 days when absent.
    Returns sum(charm * oi * spot) over both legs; 0.0 for empty records or a
    missing/non-positive spot.
    """
    if not records or spot is None or spot <= 0:
        return 0.0

    spot = float(spot)
    total = 0.0
    for s in records:
        strike = _safe_float(s.get('strike'))
        t_years = _record_years(s)
        charm_c = _charm_from_gamma(_safe_float(_field(s, 'gamma_call')),
                                    _safe_float(_field(s, 'iv_call', 0.20)),
                                    strike, spot, t_years)
        charm_p = _charm_from_gamma(_safe_float(_field(s, 'gamma_put')),
                                    _safe_float(_field(s, 'iv_put', 0.20)),
                                    strike, spot, t_years)

        oi_c = _safe_int(_field(s, 'oi_call'))
        oi_p = _safe_int(_field(s, 'oi_put'))
        total += charm_c * oi_c * spot + charm_p * oi_p * spot

    return float(total)


def gex_summary(records: List[Dict[str, Any]], spot: float) -> Dict[str, Any]:
    """
    Compute a comprehensive GEX summary.

    Returns dict with spot, gex_plus, net/call/put GEX, gamma wall, max
    positive/negative strikes, zero-gamma levels above/below spot, vex, the
    vanna/gamma ratio (vgr; 0.0 when |gex_plus| <= 1e-6) and the regime
    ('Long Gamma' when gex_plus > 0, else 'Short Gamma').
    """
    net = compute_net_gex(records, spot)
    flip = compute_gamma_flip(records, spot)
    # compute_gamma_flip already evaluated compute_gex_plus(records, spot).
    gex_plus = flip['current_gex']
    vex = compute_vanna_exposure(records, spot)

    return {
        'spot': spot,
        'gex_plus': gex_plus,
        'net_gex': net['net_gex'],
        'call_gex': net['call_gex'],
        'put_gex': net['put_gex'],
        'gamma_wall': net['gamma_wall_strike'],
        'max_positive_strike': net['max_positive_strike'],
        'max_negative_strike': net['max_negative_strike'],
        'zero_gamma_up': flip['up_flip'],
        'zero_gamma_down': flip['down_flip'],
        'vex': vex,
        'vgr': abs(vex) / abs(gex_plus) if abs(gex_plus) > 1e-6 else 0.0,
        'regime': 'Long Gamma' if gex_plus > 0 else 'Short Gamma',
    }


if __name__ == '__main__':
    # Smoke test with synthetic data (uses the short iv_c / gamma_c aliases)
    rng = np.random.default_rng(42)
    test_records = []
    spot = 6632.0
    for k in np.arange(5800, 7400, 25):
        atm_dist = (k - spot) / spot
        iv_c = max(0.05, 0.18 + abs(atm_dist) * 0.3 - atm_dist * 0.05)
        iv_p = max(0.05, iv_c + 0.01 + max(0.0, -atm_dist * 0.08))
        gamma = float(0.003 * np.exp(-50 * atm_dist**2))
        test_records.append({
            'strike': float(k),
            'oi_call': int(abs(rng.normal(8000, 3000))),
            'oi_put': int(abs(rng.normal(10000, 4000))),
            'iv_c': round(iv_c, 4),
            'iv_p': round(iv_p, 4),
            'gamma_c': round(gamma, 6),
            'gamma_p': round(gamma, 6),
        })

    summary = gex_summary(test_records, spot)
    print("GEX Summary:")
    for key, value in summary.items():
        if isinstance(value, float):
            print(f"  {key}: {value:,.4f}")
        else:
            print(f"  {key}: {value}")