maxkru92's picture
Upload gex.py with huggingface_hub
9914e9b verified
Raw
History Blame Contribute Delete
17.8 kB
"""
GEX (Gamma Exposure) Calculator Module
======================================
Computes gamma exposure profiles from options chain data.
Supports both raw CBOE format and normalized DataFrame inputs.
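Functions:
- compute_gex: Basic GEX from calls/puts DataFrames
- compute_gex_from_records: Total net GEX from a list of per-strike records
- compute_gex_plus: GEX+ with LOB disagreement correction
- compute_gex_by_strike: Per-strike GEX breakdown
- compute_net_gex: Net gamma exposure (calls - puts)
- find_zero_gamma: Find spot level where GEX+ crosses zero
- compute_crash_profile: GEX+ vs spot move percentage
- compute_gamma_flip: Identify gamma flip level
- compute_vanna_exposure: Vanna exposure (VEX) from per-strike records
- compute_charm_exposure: Charm (delta decay) exposure from per-strike records
- gex_summary: Dictionary combining the key metrics above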
"""
import pandas as pd
import numpy as np
from typing import Optional, Tuple, List, Dict, Any
def _safe_float(v, default=0.0):
    """Convert ``v`` to a finite float, returning ``default`` on failure.

    ``default`` is returned when ``v`` is ``None``, when ``float(v)`` raises
    (wrong type, unparsable string, or an integer too large to represent as
    a float), or when the converted value is NaN or +/-infinity.
    """
    if v is None:
        return default
    try:
        f = float(v)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: float() of an int beyond float range, e.g. 10**400.
        return default
    return f if np.isfinite(f) else default
def _safe_int(v, default=0):
    """Convert ``v`` to an int (via float), returning ``default`` on failure.

    The value is first passed through ``float()`` so numeric strings such as
    ``"3.7"`` are accepted; the fractional part is truncated toward zero.
    ``default`` is returned when ``v`` is ``None`` or when the conversion
    raises: wrong type, unparsable string, NaN, +/-infinity, or an integer
    too large to represent as a float.
    """
    if v is None:
        return default
    try:
        return int(float(v))
    except (TypeError, ValueError, OverflowError):
        # int(nan) raises ValueError; int(+/-inf) and float(huge int)
        # raise OverflowError.
        return default
def compute_gex(
    calls_df: Optional[pd.DataFrame],
    puts_df: Optional[pd.DataFrame],
    spot: float
) -> Tuple[Optional[pd.DataFrame], float]:
    """
    Aggregate call and put gamma exposure per strike from two option tables.

    Parameters
    ----------
    calls_df : DataFrame whose rows are read via 'strike', 'gamma', 'openInterest'
    puts_df : DataFrame whose rows are read via 'strike', 'gamma', 'openInterest'
    spot : Current spot price

    Returns
    -------
    (gex_by_strike_df, total_gex)
    - DataFrame sorted by strike with columns: strike, gex_call, gex_put, gex_net
    - total_gex: sum of the gex_net column as a float

    ``(None, 0.0)`` is returned when either table is None, when spot is None
    or not positive, or when no row has both positive gamma and positive
    open interest.
    """
    if calls_df is None or puts_df is None or spot is None or spot <= 0:
        return None, 0.0

    spot = float(spot)
    per_strike: Dict[float, Dict[str, float]] = {}

    def _exposures(frame):
        """Yield (strike, exposure) for rows with positive gamma and OI."""
        for _, row in frame.iterrows():
            gamma = _safe_float(row.get('gamma'))
            contracts = _safe_int(row.get('openInterest'))
            strike = _safe_float(row.get('strike'))
            if gamma > 0 and contracts > 0:
                yield strike, gamma * contracts * spot * spot / 100.0

    # Calls add to the strike's exposure.
    for strike, exposure in _exposures(calls_df):
        bucket = per_strike.setdefault(strike, {'gex_call': 0.0, 'gex_put': 0.0})
        bucket['gex_call'] += exposure

    # Puts subtract from the strike's exposure.
    for strike, exposure in _exposures(puts_df):
        bucket = per_strike.setdefault(strike, {'gex_call': 0.0, 'gex_put': 0.0})
        bucket['gex_put'] -= exposure

    if not per_strike:
        return None, 0.0

    table = pd.DataFrame([
        {
            'strike': strike,
            'gex_call': sides['gex_call'],
            'gex_put': sides['gex_put'],
            'gex_net': sides['gex_call'] + sides['gex_put'],
        }
        for strike, sides in per_strike.items()
    ])
    table = table.sort_values('strike').reset_index(drop=True)
    return table, float(table['gex_net'].sum())
def compute_gex_from_records(records: List[Dict[str, Any]], spot: float) -> float:
    """
    Sum net GEX (call exposure minus put exposure) over per-strike records.

    Each record is read through these keys, with the alternate key used
    when the primary one is absent:
    gamma_call (or gamma_c), gamma_put (or gamma_p),
    oi_call (or open_interest_call), oi_put (or open_interest_put)

    Returns 0.0 when records is empty or spot is None or not positive.
    """
    if not records or spot is None or spot <= 0:
        return 0.0

    spot = float(spot)
    net_total = 0.0
    for rec in records:
        call_gamma = _safe_float(rec.get('gamma_call', rec.get('gamma_c')))
        put_gamma = _safe_float(rec.get('gamma_put', rec.get('gamma_p')))
        call_oi = _safe_int(rec.get('oi_call', rec.get('open_interest_call')))
        put_oi = _safe_int(rec.get('oi_put', rec.get('open_interest_put')))

        call_side = call_gamma * call_oi * spot * spot / 100.0
        put_side = put_gamma * put_oi * spot * spot / 100.0
        net_total += call_side
        net_total -= put_side
    return net_total
def compute_gex_plus(records: List[Dict[str, Any]], spot: float) -> float:
    """
    GEX+ : net-open-interest gamma exposure damped by an IV disagreement rate.

    For every record the net open interest (oi_call - oi_put) is multiplied
    by the average of gamma_call and gamma_put and by spot**2 / 100. A record
    counts as a "disagreement" when its put IV exceeds its call IV by more
    than 2% (and the call IV is positive). The summed exposure is then
    multiplied by max(0, 1 - 2 * disagreement_rate).

    Missing iv_call / iv_put keys default to 0.001.
    Returns 0.0 when records is empty or spot is None or not positive.
    """
    if not records or spot is None or spot <= 0:
        return 0.0

    spot = float(spot)
    raw_total = 0.0
    flagged = 0
    for rec in records:
        call_oi = _safe_int(rec.get('oi_call'))
        put_oi = _safe_int(rec.get('oi_put'))
        call_gamma = _safe_float(rec.get('gamma_call'))
        put_gamma = _safe_float(rec.get('gamma_put'))
        call_iv = _safe_float(rec.get('iv_call', 0.001))
        put_iv = _safe_float(rec.get('iv_put', 0.001))

        mean_gamma = (call_gamma + put_gamma) / 2.0
        raw_total += (call_oi - put_oi) * mean_gamma * spot * spot / 100.0

        # Put IV more than 2% above a positive call IV marks a disagreement.
        if call_iv > 0 and put_iv > call_iv * 1.02:
            flagged += 1

    # records is non-empty here (checked by the guard above).
    disagreement_rate = flagged / len(records)
    # A rate of 50% or more zeroes the result.
    damping = max(0.0, 1.0 - 2.0 * disagreement_rate)
    return raw_total * damping
def compute_gex_by_strike(records: List[Dict[str, Any]], spot: float) -> pd.DataFrame:
    """
    Build a per-strike GEX table from records.

    Returns a DataFrame sorted by strike with columns:
    strike, gex_call, gex_put, gex_net, oi_call, oi_put, iv_call, iv_put

    iv_call / iv_put are the record values multiplied by 100. When spot is
    falsy or not positive, the gex_* columns are all 0.0. An empty records
    list yields an empty DataFrame with the same columns.
    """
    columns = ['strike', 'gex_call', 'gex_put', 'gex_net',
               'oi_call', 'oi_put', 'iv_call', 'iv_put']
    if not records:
        return pd.DataFrame(columns=columns)

    spot = float(spot) if spot else 0.0
    has_spot = spot > 0

    rows = []
    for rec in records:
        strike = _safe_float(rec.get('strike'))
        call_oi = _safe_int(rec.get('oi_call'))
        put_oi = _safe_int(rec.get('oi_put'))
        call_gamma = _safe_float(rec.get('gamma_call'))
        put_gamma = _safe_float(rec.get('gamma_put'))
        call_iv_pct = _safe_float(rec.get('iv_call')) * 100  # fraction -> percent
        put_iv_pct = _safe_float(rec.get('iv_put')) * 100

        call_gex = 0.0
        put_gex = 0.0
        if has_spot:
            call_gex = call_gamma * call_oi * spot * spot / 100.0
            put_gex = -put_gamma * put_oi * spot * spot / 100.0

        rows.append({
            'strike': strike,
            'gex_call': call_gex,
            'gex_put': put_gex,
            'gex_net': call_gex + put_gex,
            'oi_call': call_oi,
            'oi_put': put_oi,
            'iv_call': call_iv_pct,
            'iv_put': put_iv_pct,
        })

    table = pd.DataFrame(rows)
    return table.sort_values('strike').reset_index(drop=True)
def compute_net_gex(records: List[Dict[str, Any]], spot: float) -> Dict[str, float]:
    """
    Summarise call, put and net GEX plus the strikes with extreme net GEX.

    Returns dict with:
    - net_gex: call_gex + put_gex
    - call_gex: summed call exposure
    - put_gex: summed put exposure (puts are counted with a negative sign)
    - max_positive_strike: strike with the largest positive net GEX
    - max_negative_strike: strike with the most negative net GEX
    - gamma_wall_strike: strike with the largest absolute net GEX

    Strike fields stay 0.0 when no strike qualifies; ties keep the first
    record seen. All fields are 0.0 when records is empty or spot is None
    or not positive.
    """
    if not records or spot is None or spot <= 0:
        return {
            'net_gex': 0.0, 'call_gex': 0.0, 'put_gex': 0.0,
            'max_positive_strike': 0.0, 'max_negative_strike': 0.0,
            'gamma_wall_strike': 0.0,
        }

    spot = float(spot)
    call_total = 0.0
    put_total = 0.0
    # Running extremes paired with the strike that produced them.
    best_pos, pos_strike = 0.0, 0.0
    best_neg, neg_strike = 0.0, 0.0
    best_abs, wall_strike = 0.0, 0.0

    for rec in records:
        strike = _safe_float(rec.get('strike'))
        call_oi = _safe_int(rec.get('oi_call'))
        put_oi = _safe_int(rec.get('oi_put'))
        call_gamma = _safe_float(rec.get('gamma_call'))
        put_gamma = _safe_float(rec.get('gamma_put'))

        call_gex = call_gamma * call_oi * spot * spot / 100.0
        put_gex = -put_gamma * put_oi * spot * spot / 100.0
        net = call_gex + put_gex

        call_total += call_gex
        put_total += put_gex

        if net > best_pos:
            best_pos, pos_strike = net, strike
        if net < best_neg:
            best_neg, neg_strike = net, strike
        magnitude = abs(net)
        if magnitude > best_abs:
            best_abs, wall_strike = magnitude, strike

    return {
        'net_gex': call_total + put_total,
        'call_gex': call_total,
        'put_gex': put_total,
        'max_positive_strike': pos_strike,
        'max_negative_strike': neg_strike,
        'gamma_wall_strike': wall_strike,
    }
def find_zero_gamma(
    records: List[Dict[str, Any]],
    spot: float,
    max_range_pct: float = 25.0,
    step: float = 0.1
) -> float:
    """
    Scan spot levels around ``spot`` for the first GEX+ sign change.

    The scan walks upward from 0% to ``max_range_pct`` in ``step``
    percent increments, then downward over the mirrored range. At each test
    level the records are rescaled with ``_scale_records`` and GEX+ is
    recomputed at that level.

    Returns
    - ``spot`` unchanged when records is empty, spot is None or not
      positive, or GEX+ at spot is exactly zero;
    - the first scanned level whose GEX+ sign differs from the sign at spot;
    - ``spot * 1.05`` when no scanned level changes sign.

    NOTE(review): ``_scale_records`` multiplies gamma by spot/test while
    ``compute_gex_plus`` multiplies by test**2, so the scanned value is
    proportional to test/spot and keeps its sign for every positive test
    level. With the default range the fallback therefore appears to be the
    only reachable non-trivial result -- confirm whether a strike-aware
    repricing was intended.
    """
    if not records or spot is None or spot <= 0:
        return spot

    baseline = compute_gex_plus(records, spot)
    if baseline == 0:
        return spot
    baseline_sign = np.sign(baseline)

    # (stop, step) pairs: upward sweep first, then the mirrored downward one.
    sweeps = (
        (max_range_pct + step, step),
        (-max_range_pct - step, -step),
    )
    for stop, increment in sweeps:
        for pct in np.arange(0.0, stop, increment):
            candidate = spot * (1 + pct / 100.0)
            shifted = _scale_records(records, spot, candidate)
            if np.sign(compute_gex_plus(shifted, candidate)) != baseline_sign:
                return candidate

    return spot * 1.05  # Fallback: 5% above spot
def compute_crash_profile(
    records: List[Dict[str, Any]],
    spot: float,
    range_pct: Tuple[float, float] = (-15.0, 10.0),
    step: float = 0.25
) -> List[Dict[str, Any]]:
    """
    Compute GEX+ across a range of spot moves.

    For each percentage move from ``range_pct[0]`` to ``range_pct[1]`` in
    increments of ``step``, the records are rescaled to the moved spot with
    ``_scale_records`` and GEX+ is recomputed at that level.

    Returns list of dicts with keys:
    - spot_pct: the percentage move, rounded to 2 decimals
    - spx: the moved spot level, rounded to 1 decimal
    - gex_plus: GEX+ at the moved level

    Returns an empty list when records is empty or spot is None or not
    positive.
    """
    if not records or spot is None or spot <= 0:
        return []

    # The unused baseline GEX+ computation at the current spot was removed:
    # its result was never read and it cost a full pass over the records.
    profile = []
    for pct in np.arange(range_pct[0], range_pct[1] + step, step):
        test = spot * (1 + pct / 100.0)
        scaled = _scale_records(records, spot, test)
        profile.append({
            'spot_pct': round(float(pct), 2),
            'spx': round(test, 1),
            'gex_plus': compute_gex_plus(scaled, test),
        })
    return profile
def compute_gamma_flip(
    records: List[Dict[str, Any]],
    spot: float,
    max_range_pct: float = 25.0,
    step: float = 0.1,
) -> Dict[str, float]:
    """
    Find gamma flip levels (both above and below current spot).

    Parameters
    ----------
    records : per-strike records passed to ``compute_gex_plus``
    spot : current spot price
    max_range_pct : scan distance in percent on each side of spot
        (exclusive upper bound); defaults to the previously fixed 25.0
    step : scan increment in percent, expected to be positive; defaults to
        the previously fixed 0.1

    Returns dict with:
    - up_flip: first scanned level above spot where the GEX+ sign differs
      from the sign at spot, or ``spot`` when none is found
    - down_flip: same, scanning below spot
    - current_gex: GEX+ at the current spot

    All three are 0.0 when records is empty or spot is None.

    NOTE(review): ``_scale_records`` multiplies gamma by spot/test while
    ``compute_gex_plus`` multiplies by test**2, so the scanned value is
    proportional to test/spot and keeps its sign for every positive test
    level -- both flips then stay at ``spot``. Confirm whether a
    strike-aware repricing was intended.
    """
    if not records or spot is None:
        return {'up_flip': 0.0, 'down_flip': 0.0, 'current_gex': 0.0}

    current_gex = compute_gex_plus(records, spot)

    # A zero GEX+ (which includes any spot <= 0) has no sign to flip from,
    # so neither scan could ever match: return without scanning.
    if current_gex == 0:
        return {'up_flip': spot, 'down_flip': spot, 'current_gex': current_gex}

    current_sign = np.sign(current_gex)

    def _first_flip(offsets_pct):
        """Return the first test level whose GEX+ sign differs, else spot."""
        for pct in offsets_pct:
            test = spot * (1 + pct / 100.0)
            scaled = _scale_records(records, spot, test)
            if np.sign(compute_gex_plus(scaled, test)) != current_sign:
                return test
        return spot

    up_flip = _first_flip(np.arange(step, max_range_pct, step))
    down_flip = _first_flip(np.arange(-step, -max_range_pct, -step))

    return {
        'up_flip': up_flip,
        'down_flip': down_flip,
        'current_gex': current_gex,
    }
def _scale_records(
    records: List[Dict[str, Any]],
    old_spot: float,
    new_spot: float
) -> List[Dict[str, Any]]:
    """
    Scale gamma values in records for a different spot level.

    Each record is shallow-copied with 'gamma_call' and 'gamma_put'
    multiplied by ``old_spot / new_spot``; all other keys are carried over
    unchanged and the input records are not mutated.

    When either spot is not positive the ratio is undefined (division by
    zero) or meaningless (negative), so the original ``records`` list is
    returned as-is without scaling.
    """
    if old_spot <= 0 or new_spot <= 0:
        return records

    factor = old_spot / new_spot
    return [
        {
            **s,
            'gamma_call': _safe_float(s.get('gamma_call')) * factor,
            'gamma_put': _safe_float(s.get('gamma_put')) * factor,
        }
        for s in records
    ]
def compute_vanna_exposure(records: List[Dict[str, Any]], spot: float) -> float:
    """
    Compute VEX (Vanna Exposure) from records.

    Per record the exposure is
    vanna_call * oi_call * spot + vanna_put * oi_put * spot.

    When both 'vanna_call' and 'vanna_put' are missing or zero, each side is
    estimated from gamma as ``-gamma * d1 / iv`` with
    ``d1 = (ln(spot / strike) + 0.5 * iv**2) / iv``, provided the strike and
    that side's IV are positive. Missing IVs default to 0.20.

    Returns 0.0 when records is empty or spot is None or not positive.
    """
    if not records or spot is None or spot <= 0:
        return 0.0

    spot = float(spot)

    def _vanna_from_gamma(gamma, iv, strike):
        """Gamma-based vanna estimate; caller ensures iv > 0 and strike > 0."""
        d1 = (np.log(spot / strike) + 0.5 * iv**2) / iv
        return -gamma * d1 / iv

    exposure = 0.0
    for rec in records:
        call_vanna = _safe_float(rec.get('vanna_call'))
        put_vanna = _safe_float(rec.get('vanna_put'))

        # Estimate only when neither side supplies a non-zero vanna.
        if call_vanna == 0 and put_vanna == 0:
            call_gamma = _safe_float(rec.get('gamma_call'))
            put_gamma = _safe_float(rec.get('gamma_put'))
            call_iv = _safe_float(rec.get('iv_call', 0.20))
            put_iv = _safe_float(rec.get('iv_put', 0.20))
            strike = _safe_float(rec.get('strike'))
            if strike > 0:
                if call_iv > 0:
                    call_vanna = _vanna_from_gamma(call_gamma, call_iv, strike)
                if put_iv > 0:
                    put_vanna = _vanna_from_gamma(put_gamma, put_iv, strike)

        call_oi = _safe_int(rec.get('oi_call'))
        put_oi = _safe_int(rec.get('oi_put'))
        exposure += call_vanna * call_oi * spot + put_vanna * put_oi * spot

    return exposure
def compute_charm_exposure(records: List[Dict[str, Any]], spot: float) -> float:
    """
    Compute Charm (delta decay) exposure from records.

    Per record and side, charm is approximated from gamma as
    ``-gamma * d2 / (2 * T)`` where ``T = dte / 365``,
    ``d1 = (ln(spot / strike) + 0.5 * iv**2 * T) / (iv * sqrt(T))`` and
    ``d2 = d1 - iv * sqrt(T)`` (rate term omitted). A side contributes 0
    unless strike, IV and T are all positive. Missing IVs default to 0.20
    and a missing 'dte' defaults to 30.

    The total is the sum of charm * open interest * spot over both sides.
    Returns 0.0 when records is empty or spot is None or not positive.
    """
    if not records or spot is None or spot <= 0:
        return 0.0

    spot = float(spot)

    def _charm(gamma, iv, strike, years):
        """Gamma-based charm estimate for one side of one strike."""
        if not (strike > 0 and iv > 0 and years > 0):
            return 0
        vol_sqrt_t = iv * np.sqrt(years)
        d1 = (np.log(spot / strike) + (0.5 * iv**2) * years) / vol_sqrt_t
        d2 = d1 - vol_sqrt_t
        return -gamma * d2 / (2 * years)

    exposure = 0.0
    for rec in records:
        call_iv = _safe_float(rec.get('iv_call', 0.20))
        put_iv = _safe_float(rec.get('iv_put', 0.20))
        call_gamma = _safe_float(rec.get('gamma_call'))
        put_gamma = _safe_float(rec.get('gamma_put'))
        strike = _safe_float(rec.get('strike'))
        # 30 days to expiry is assumed when the record has no 'dte' key.
        years = _safe_float(rec.get('dte', 30)) / 365.0

        call_charm = _charm(call_gamma, call_iv, strike, years)
        put_charm = _charm(put_gamma, put_iv, strike, years)

        call_oi = _safe_int(rec.get('oi_call'))
        put_oi = _safe_int(rec.get('oi_put'))
        exposure += call_charm * call_oi * spot + put_charm * put_oi * spot

    return exposure
def gex_summary(records: List[Dict[str, Any]], spot: float) -> Dict[str, Any]:
    """
    Collect the module's headline metrics into one dictionary.

    Combines the results of ``compute_net_gex``, ``compute_gamma_flip``,
    ``compute_gex_plus`` and ``compute_vanna_exposure``. Additional fields:
    - vgr: |vex| / |gex_plus|, or 0.0 when |gex_plus| is at most 1e-6
    - regime: 'Long Gamma' when gex_plus > 0, otherwise 'Short Gamma'
    """
    net_metrics = compute_net_gex(records, spot)
    flip_levels = compute_gamma_flip(records, spot)
    gex_plus = compute_gex_plus(records, spot)
    vex = compute_vanna_exposure(records, spot)

    # Guard the ratio against a (near-)zero denominator.
    if abs(gex_plus) > 1e-6:
        vanna_gamma_ratio = abs(vex) / abs(gex_plus)
    else:
        vanna_gamma_ratio = 0.0

    regime = 'Short Gamma'
    if gex_plus > 0:
        regime = 'Long Gamma'

    return {
        'spot': spot,
        'gex_plus': gex_plus,
        'net_gex': net_metrics['net_gex'],
        'call_gex': net_metrics['call_gex'],
        'put_gex': net_metrics['put_gex'],
        'gamma_wall': net_metrics['gamma_wall_strike'],
        'max_positive_strike': net_metrics['max_positive_strike'],
        'max_negative_strike': net_metrics['max_negative_strike'],
        'zero_gamma_up': flip_levels['up_flip'],
        'zero_gamma_down': flip_levels['down_flip'],
        'vex': vex,
        'vgr': vanna_gamma_ratio,
        'regime': regime,
    }
if __name__ == '__main__':
    # Smoke test with synthetic data: a strike ladder around spot with a
    # bell-shaped gamma and put IV slightly above call IV.
    rng = np.random.default_rng(42)
    test_records = []
    spot = 6632.0
    for k in np.arange(5800, 7400, 25):
        atm_dist = (k - spot) / spot
        iv_c = max(0.05, 0.18 + abs(atm_dist) * 0.3 - atm_dist * 0.05)
        iv_p = max(0.05, iv_c + 0.01 + max(0.0, -atm_dist * 0.08))
        gamma = float(0.003 * np.exp(-50 * atm_dist**2))
        # Keys must match what the compute_* functions read ('iv_call',
        # 'iv_put', 'gamma_call', 'gamma_put'); the previous short keys
        # ('iv_c', 'gamma_c', ...) were ignored by gex_summary, so every
        # exposure in the printed summary came out as zero.
        test_records.append({
            'strike': float(k),
            'oi_call': int(abs(rng.normal(8000, 3000))),
            'oi_put': int(abs(rng.normal(10000, 4000))),
            'iv_call': round(iv_c, 4),
            'iv_put': round(iv_p, 4),
            'gamma_call': round(gamma, 6),
            'gamma_put': round(gamma, 6),
        })
    summary = gex_summary(test_records, spot)
    print("GEX Summary:")
    for name, value in summary.items():
        if isinstance(value, float):
            print(f" {name}: {value:,.4f}")
        else:
            print(f" {name}: {value}")