cryptoindex / index_interface_hf.py
igriv's picture
Fix Gradio app launch for HF Spaces - create interface at module level
ff560c0
import gradio as gr
import plotly.express as px
from cryptoindex import *
import pandas as pd
from updater import *
from time import sleep
from functools import partial
import argparse
import os
from datetime import datetime
import hashlib
import json
import uuid
import sqlite3
from contextlib import contextmanager
# Location of the SQLite file holding the index cache and per-session weights.
DB_PATH = "cryptoindex.db"


@contextmanager
def get_db():
    """Yield a SQLite connection to ``DB_PATH`` and always close it afterwards.

    Rows come back as :class:`sqlite3.Row`, so columns can be read by name
    (e.g. ``row["index_data"]``). This helper never commits; callers that
    write must call ``conn.commit()`` themselves before the block exits.
    """
    connection = sqlite3.connect(DB_PATH)
    # Name-based column access for every cursor created from this connection.
    connection.row_factory = sqlite3.Row
    try:
        yield connection
    finally:
        connection.close()
def initialize_database():
    """Create the app's SQLite tables and indexes if they do not exist yet.

    Tables:
      * ``index_cache`` -- serialized index data, unique per ``cache_key``.
      * ``user_weights`` -- weights snapshots per session/locale/market type.

    Every statement uses ``IF NOT EXISTS``, so repeated calls are harmless.
    """
    schema_statements = (
        # Cached historical index data, one row per cache key.
        """
        CREATE TABLE IF NOT EXISTS index_cache (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            cache_key TEXT UNIQUE NOT NULL,
            start_date TEXT NOT NULL,
            end_date TEXT NOT NULL,
            locale TEXT NOT NULL,
            market_type TEXT NOT NULL,
            index_data TEXT NOT NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
        """,
        # Per-session weights; a new row is appended on every save.
        """
        CREATE TABLE IF NOT EXISTS user_weights (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            session_id TEXT NOT NULL,
            locale TEXT NOT NULL,
            market_type TEXT NOT NULL,
            weights_data TEXT NOT NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
        """,
        "CREATE INDEX IF NOT EXISTS idx_cache_key ON index_cache(cache_key)",
        "CREATE INDEX IF NOT EXISTS idx_user_session ON user_weights(session_id)",
    )
    with get_db() as conn:
        cur = conn.cursor()
        for statement in schema_statements:
            cur.execute(statement)
        conn.commit()
    print("SQLite database initialized successfully!")
def get_cache_key(start_date: str, end_date: str, locale: str, market_type: str) -> str:
    """Return a deterministic lookup key for one historical-index query.

    The four values are joined with underscores and hashed with MD5. The
    digest is only used as a cache key, not for any security purpose.
    """
    parts = (start_date, end_date, locale, market_type)
    raw_key = "_".join(f"{part}" for part in parts)
    hasher = hashlib.md5()
    hasher.update(raw_key.encode())
    return hasher.hexdigest()
def get_or_create_session(request: gr.Request) -> str:
    """Return an identifier for the current user session.

    Uses ``request.session_hash`` when it is present and non-empty. Otherwise
    (``request`` is ``None``, lacks the attribute, or its hash is ``None``/empty)
    a fresh random UUID4 string is returned. Such a fallback ID is new on
    every call, so data saved under it is not found again by a later call.
    """
    # hasattr() alone is not enough: the attribute can exist but hold None,
    # which would otherwise be returned as the session id.
    session_hash = getattr(request, "session_hash", None)
    if session_hash:
        return session_hash
    return str(uuid.uuid4())
def fetch_from_cache(cache_key: str):
    """Load a cached index DataFrame from the ``index_cache`` table.

    Args:
        cache_key: Key produced by ``get_cache_key``.

    Returns:
        The cached DataFrame, with its index passed through ``pd.to_datetime``.
        Returns ``None`` on a cache miss or if reading/parsing fails. Errors
        are printed rather than raised, so a broken cache entry only costs a
        recomputation.
    """
    # pandas >= 2.1 deprecates passing literal JSON text to read_json; a
    # file-like wrapper works on both older and newer pandas versions.
    from io import StringIO

    try:
        with get_db() as conn:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT index_data FROM index_cache WHERE cache_key = ?",
                (cache_key,),
            )
            row = cursor.fetchone()
        if row is None:
            print(f"Cache miss for key: {cache_key}")
            return None
        v_data = pd.read_json(StringIO(row["index_data"]))
        v_data.index = pd.to_datetime(v_data.index)
        print(f"Cache hit for key: {cache_key}")
        return v_data
    except Exception as e:
        # Best-effort cache: any failure degrades to a cache miss.
        print(f"Cache fetch error: {e}")
    return None
def save_to_cache(cache_key: str, v_data: pd.DataFrame, start_date: str, end_date: str,
                  locale: str, market_type: str):
    """Store (or overwrite) a computed index DataFrame in ``index_cache``.

    The frame is serialized with ``DataFrame.to_json()``. ``INSERT OR REPLACE``
    replaces any existing row with the same ``cache_key``. Failures are
    printed and otherwise ignored, so caching is best-effort.
    """
    upsert_sql = """
        INSERT OR REPLACE INTO index_cache
        (cache_key, start_date, end_date, locale, market_type, index_data)
        VALUES (?, ?, ?, ?, ?, ?)
    """
    try:
        # Serialization stays inside the try so its errors are swallowed too.
        params = (cache_key, start_date, end_date, locale, market_type, v_data.to_json())
        with get_db() as conn:
            conn.execute(upsert_sql, params)
            conn.commit()
        print(f"Saved to cache: {cache_key}")
    except Exception as exc:
        print(f"Cache save error: {exc}")
def get_user_weights(session_id: str, locale: str, market_type: str):
    """Return the most recently saved weights for a session, or ``None``.

    Args:
        session_id: Session identifier the weights were saved under.
        locale: Locale the weights were saved for.
        market_type: Market type the weights were saved for.

    Returns:
        The newest matching ``weights_data`` parsed with ``pd.read_json``,
        which yields a DataFrame by default. If nothing matches, or on any
        error (printed, not raised), returns ``None``.
        NOTE(review): if a Series was saved, the round-trip changes its type;
        confirm downstream consumers accept a DataFrame.
    """
    # pandas >= 2.1 deprecates passing literal JSON text to read_json.
    from io import StringIO

    try:
        with get_db() as conn:
            cursor = conn.cursor()
            # created_at (CURRENT_TIMESTAMP) only has one-second resolution,
            # so saves within the same second tie. The AUTOINCREMENT id then
            # picks the most recent insert deterministically.
            cursor.execute("""
                SELECT weights_data FROM user_weights
                WHERE session_id = ? AND locale = ? AND market_type = ?
                ORDER BY created_at DESC, id DESC
                LIMIT 1
            """, (session_id, locale, market_type))
            row = cursor.fetchone()
        if row is not None:
            return pd.read_json(StringIO(row["weights_data"]))
    except Exception as e:
        print(f"Error fetching user weights: {e}")
    return None
def save_user_weights(session_id: str, weights_df: pd.DataFrame, locale: str, market_type: str):
    """Append a weights snapshot for a session to ``user_weights``.

    Every call inserts a new row; earlier rows are kept. The weights are
    serialized with ``to_json()``. Errors are printed rather than raised.
    """
    insert_sql = """
        INSERT INTO user_weights
        (session_id, locale, market_type, weights_data)
        VALUES (?, ?, ?, ?)
    """
    try:
        # to_json() stays inside the try so a bad weights object is swallowed too.
        params = (session_id, locale, market_type, weights_df.to_json())
        with get_db() as conn:
            conn.execute(insert_sql, params)
            conn.commit()
        print(f"Saved weights for session: {session_id}")
    except Exception as exc:
        print(f"Error saving user weights: {exc}")
def plot_index_prices(start_date, end_date, request: gr.Request, **kwargs):
    """Plot the historical index for a date range, using the SQLite cache.

    Args:
        start_date: Start date string from the UI (placeholder ``YYYY-MM-DD``),
            forwarded to ``fetch_crypto_data``.
        end_date: End date string, forwarded likewise.
        request: Gradio request. It is accepted for interface compatibility
            but not used: the cache key contains only the dates, locale and
            market type, so cached results are shared across sessions.
        **kwargs: Forwarded to ``fetch_crypto_data``. ``locale`` (default
            ``'global'``) and ``market_type`` (default ``'crypto'``) are also
            part of the cache key.

    Returns:
        ``(fig, output)``: a Plotly line chart of the index ``close`` column
        with a range slider, and the fourth value returned by ``do_sharpe``.
    """
    # ``np`` is otherwise only in scope via ``from cryptoindex import *``;
    # import it explicitly so this function does not depend on that re-export.
    import numpy as np

    locale = kwargs.get('locale', 'global')
    market_type = kwargs.get('market_type', 'crypto')
    cache_key = get_cache_key(start_date, end_date, locale, market_type)
    v = fetch_from_cache(cache_key)
    if v is None:
        # Cache miss (or unreadable entry): compute the index and store it.
        cryptodf = fetch_crypto_data(start_date=start_date, end_date=end_date, **kwargs)
        v, _ = get_crypto_index(cryptodf, func=np.sqrt)
        save_to_cache(cache_key, v, start_date, end_date, locale, market_type)
    _, _, _, output = do_sharpe(v.close)
    fig = px.line(v, x=v.index, y='close', title='Index Prices')
    fig.update_xaxes(rangeslider_visible=True)
    return fig, output
def realtime_update_weighted_prices(request: gr.Request, locale='global', market_type='crypto'):
    """Plot today's index using weights associated with the caller's session.

    Weights are chosen in this order:
      1. If ``should_update_weights()`` is true, compute them with
         ``update_weights`` and persist them for the session.
      2. If that did not produce weights, use the most recent weights saved
         for this session.
      3. If there are still no weights, compute them now and persist them.

    Freshly computed weights are used directly instead of being re-read from
    the database. As a result, a failed save no longer triggers a second
    ``update_weights`` call, and both "fresh" paths hand ``update_day`` the
    same kind of object.

    Returns:
        ``(fig, output)``: a Plotly line figure of ``update_day``'s prices and
        the fourth value returned by ``do_sharpe(prices, days=False)``.
    """
    session_id = get_or_create_session(request)
    weights = None
    if should_update_weights():
        weights = update_weights(locale=locale, market_type=market_type)
        save_user_weights(session_id, weights, locale, market_type)
    if weights is None:
        weights = get_user_weights(session_id, locale, market_type)
    if weights is None:
        weights = update_weights(locale=locale, market_type=market_type)
        save_user_weights(session_id, weights, locale, market_type)
    prices = update_day(weights)
    _, _, _, output = do_sharpe(prices, days=False)
    fig = px.line(prices, x=prices.index, y=prices.values, title='Index Today')
    return fig, output
def make_graph(choice, start_date=None, end_date=None, request: gr.Request = None, **kwargs):
    """Build the plot and statistics components for the selected graph type.

    ``choice == "Historical"`` plots the historical index for the date range.
    Any other choice plots today's real-time index, and the dates are ignored.
    Extra keyword arguments (e.g. ``locale``/``market_type``) are forwarded
    to the selected plotting function.

    Returns:
        A ``(gr.Plot, gr.Markdown)`` pair wrapping the figure and statistics.
    """
    if choice != "Historical":
        figure, summary = realtime_update_weighted_prices(request, **kwargs)
    else:
        figure, summary = plot_index_prices(start_date, end_date, request, **kwargs)
    return gr.Plot(figure), gr.Markdown(summary)
def cleanup_old_data(cache_max_age_days=30, weights_max_age_days=7):
    """Delete stale rows from ``index_cache`` and ``user_weights``.

    A row is removed when its ``created_at`` is older than the given age
    relative to SQLite's ``datetime('now')``. The defaults keep the original
    retention: 30 days for the cache and 7 days for user weights. Errors are
    printed and swallowed, so cleanup is best-effort.

    Args:
        cache_max_age_days: Maximum age, in days, of ``index_cache`` rows.
        weights_max_age_days: Maximum age, in days, of ``user_weights`` rows.
    """
    try:
        # SQLite date modifiers such as '-30 days' can be bound as query
        # parameters, so no numbers are formatted into the SQL text itself.
        cache_cutoff = f"-{int(cache_max_age_days)} days"
        weights_cutoff = f"-{int(weights_max_age_days)} days"
        with get_db() as conn:
            cursor = conn.cursor()
            cursor.execute(
                "DELETE FROM index_cache WHERE created_at < datetime('now', ?)",
                (cache_cutoff,),
            )
            cursor.execute(
                "DELETE FROM user_weights WHERE created_at < datetime('now', ?)",
                (weights_cutoff,),
            )
            conn.commit()
        print("Cleaned up old data")
    except Exception as e:
        print(f"Cleanup error: {e}")
def create_interface(locale='global', market_type='crypto'):
    """Build and return the Gradio Blocks UI for the index tracker.

    Before building the UI, ensures the SQLite schema exists and prunes old
    rows. Changing the graph-type radio and clicking "Update Graph" both call
    ``make_graph`` with ``locale`` and ``market_type`` bound via ``partial``.
    """
    initialize_database()
    cleanup_old_data()
    with gr.Blocks() as iface:
        gr.Markdown("# Crypto Index Tracker")
        gr.Markdown("Each user session has isolated data and computations are cached locally.")
        start_box = gr.Textbox(label="Start Date", placeholder="YYYY-MM-DD")
        end_box = gr.Textbox(label="End Date", placeholder="YYYY-MM-DD")
        graph_type = gr.Radio(choices=["Historical", "Real-time"], label="Graph Type", value="Historical")
        refresh_button = gr.Button("Update Graph")
        plot_output = gr.Plot()
        stats_output = gr.Markdown()
        handler = partial(make_graph, locale=locale, market_type=market_type)
        # Both triggers share one handler; fresh input/output lists per binding.
        for trigger in (graph_type.change, refresh_button.click):
            trigger(
                fn=handler,
                inputs=[graph_type, start_box, end_box],
                outputs=[plot_output, stats_output],
            )
    return iface
# Build the interface exactly once per execution mode. When run as a script it
# is built from the command-line arguments. When imported (e.g. by a hosting
# runtime such as HF Spaces), a module-level ``iface`` with defaults is created.
# Previously the script path built it twice, repeating DB init and cleanup.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--locale", default='global', help="the locale")
    parser.add_argument("--market_type", default='crypto', help="the market type")
    parser.add_argument("--share", action="store_true", help="share the interface")
    args = parser.parse_args()
    # Create interface with args
    iface = create_interface(locale=args.locale, market_type=args.market_type)
    # Detect if running on Hugging Face Spaces
    if os.getenv("SPACE_ID"):
        iface.launch()
    else:
        iface.launch(server_port=7860, server_name="0.0.0.0", share=args.share)
else:
    # Create the interface at module level for HF Spaces / importers.
    iface = create_interface()