# streamlit_gmt_demo/src/streamlit_app.py
# Last updated by cwadayi: "Update src/streamlit_app.py" (commit 2fb0a80, verified)
import contextlib
import datetime
import io
import os
import re
import tempfile
import traceback

import pandas as pd
import pygmt
import streamlit as st
from huggingface_hub import InferenceClient
# --- Page Configuration ---
# Browser-tab title plus a full-width layout so the map and editor have room.
st.set_page_config(page_title="Interactive PyGMT Map", layout="wide")

# --- Title ---
st.title("🗺️ Interactive Earthquake Map Editor")
# --- Default Code & Session State ---
# Starter PyGMT script shown in the editor. When "Generate Map" is pressed it is
# exec'd with the filtered catalogue bound to `earthquake_df`, and the app then
# looks for a pygmt.Figure named `fig` to display.
default_code = """
fig = pygmt.Figure()
fig.coast(
    region=[119, 124, 21, 26],
    projection="M15c",
    frame="afg",
    shorelines="0.5p,black",
    land="lightgray",
    water="skyblue",
)

# Only try to plot if there is data after filtering
if not earthquake_df.empty:
    depth_min = earthquake_df["depth"].min()
    depth_max = earthquake_df["depth"].max()
    if depth_min == depth_max:
        # makecpt needs a non-empty range (e.g. only one earthquake selected)
        depth_max = depth_min + 1
    pygmt.makecpt(cmap="viridis", series=[depth_min, depth_max], reverse=True)
    fig.plot(
        x=earthquake_df["lon"],
        y=earthquake_df["lat"],
        style="c",
        size=earthquake_df["ML"] / 25,
        fill=earthquake_df["depth"],
        cmap=True,
        pen="0.5p,black",
        transparency=20,
    )
    # The colorbar needs the CPT created above, so it stays inside the guard.
    fig.colorbar(frame='af+l"Depth (km)"')
"""
if "code_input" not in st.session_state:
    # First run in this browser session: seed the editor with the starter script.
    st.session_state["code_input"] = default_code


def reset_code():
    """Put the starter script back into the editor (used as a button callback)."""
    st.session_state["code_input"] = default_code
# --- Data Loading Function ---
@st.cache_data
def _read_catalog(path):
    """Read and clean the earthquake catalogue CSV at *path* (cached per path).

    Converts ``date`` to datetimes and ``lon``/``lat``/``depth``/``ML`` to
    numbers, then drops rows where any of them failed to parse.

    Raises (rather than returning a sentinel) on any problem: st.cache_data
    does not cache exceptions, so a fixed file is picked up on the next rerun.

    Raises:
        FileNotFoundError: *path* does not exist.
        ValueError: a required column is missing.
    """
    numeric_cols = ["lon", "lat", "depth", "ML"]
    required = ["date", *numeric_cols]

    df = pd.read_csv(path)
    missing = [col for col in required if col not in df.columns]
    if missing:
        raise ValueError(f"missing required column(s): {', '.join(missing)}")

    df["date"] = pd.to_datetime(df["date"], errors="coerce")
    for col in numeric_cols:
        df[col] = pd.to_numeric(df[col], errors="coerce")
    return df.dropna(subset=required)


def load_data(path="GDMS_earthquake_catalog.csv"):
    """Load the cleaned earthquake catalogue, reporting problems in the UI.

    Args:
        path: CSV file to read; defaults to the bundled GDMS catalogue.

    Returns:
        The cleaned DataFrame, or ``None`` after showing an ``st.error`` when
        the file is missing, unreadable, or has no valid rows.
    """
    try:
        df = _read_catalog(path)
    except FileNotFoundError:
        st.error(f"Error: `{path}` not found. Please make sure the file is in your project directory.")
        return None
    except Exception as e:  # UI boundary: show any parse problem instead of crashing
        st.error(f"Error reading or processing file: {e}")
        return None

    if df.empty:
        # Downstream min()/max() on an empty frame give NaT/NaN and break the widgets.
        st.error(f"Error: `{path}` contains no rows with valid date, lon, lat, depth and ML values.")
        return None
    return df
# --- UI helpers ---


def _range_slider(label, values):
    """Render a sidebar range slider spanning ``values.min()`` .. ``values.max()``.

    Returns the selected ``(low, high)`` pair of floats. When every value is
    identical there is nothing to choose (and st.slider rejects a zero-width
    range), so the single value is shown as a caption instead.
    """
    low, high = float(values.min()), float(values.max())
    if low == high:
        st.sidebar.caption(f"{label}: {low:g}")
        return low, high
    return st.sidebar.slider(label, min_value=low, max_value=high, value=(low, high))


def _filter_earthquakes(df):
    """Draw the sidebar filters and return the rows of *df* they select."""
    st.sidebar.header("Filter Earthquakes")
    min_date = df["date"].min().date()
    max_date = df["date"].max().date()
    date_range = st.sidebar.date_input(
        "Date range", (min_date, max_date), min_value=min_date, max_value=max_date
    )
    mag_range = _range_slider("Magnitude (ML)", df["ML"])
    depth_range = _range_slider("Depth (km)", df["depth"])

    # While a range is being picked, date_input may return only the start
    # date; use it as the lower bound instead of ignoring it.
    start_date, end_date = datetime.date.min, datetime.date.max
    if len(date_range) >= 1:
        start_date = date_range[0]
    if len(date_range) == 2:
        end_date = date_range[1]

    event_dates = df["date"].dt.date
    mask = (
        (event_dates >= start_date)
        & (event_dates <= end_date)
        & df["ML"].between(*mag_range)
        & df["depth"].between(*depth_range)
    )
    return df[mask]


def _clean_generated_code(text):
    """Return *text* with reasoning blocks and Markdown code fences removed."""
    text = text or ""  # the API may return None content
    # Qwen3 models can emit their reasoning wrapped in <think>...</think>.
    text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
    # Drop ``` fences with or without a language tag (```python, ```py, ...).
    text = re.sub(r"```[\w+-]*[ \t]*\r?\n?", "", text)
    return text.strip()


def _render_ai_section():
    """Expander that asks an LLM for a PyGMT snippet and appends it to the editor.

    Must run before the editor's text_area (key ``code_input``) is created in
    this script run: Streamlit does not allow a widget's session-state value
    to be modified after the widget has been instantiated.
    """
    with st.expander("🤖 Generate Code with AI"):
        st.info("This feature uses an LLM to generate code. You must have an `HF_TOKEN` secret set in your Space settings.")
        hf_token = os.environ.get("HF_TOKEN")
        goal_input = st.text_input(
            "Describe a map feature to add (e.g., 'add a map scale' or 'draw country borders for Taiwan')",
            key="goal",
        )
        if not st.button("Generate Code Snippet"):
            return
        if not goal_input:
            st.warning("Please describe your goal.")
            return
        if not hf_token:
            st.error("Hugging Face token not found. Please add it to your Space secrets as `HF_TOKEN`.")
            return

        with st.spinner("Generating code with AI..."):
            try:
                client = InferenceClient(token=hf_token)
                prompt = (
                    f"Based on the pygmt Figure object 'fig', provide ONLY the Python code segment to {goal_input}. "
                    "Do not add explanation or markdown formatting. The code should start with 'fig.'."
                )
                completion = client.chat.completions.create(
                    model="Qwen/Qwen3-8B",
                    messages=[{"role": "user", "content": prompt}],
                )
                generated_code = completion.choices[0].message.content
            except Exception as e:  # network/auth/model errors are shown, not raised
                st.error(f"Could not generate code: {e}")
                return

        cleaned_code = _clean_generated_code(generated_code)
        if not cleaned_code:
            st.warning("The model returned no code. Please try rephrasing your request.")
            return
        # Append the new code to the existing code in the editor.
        st.session_state.code_input += f"\n\n# Code generated by AI for: {goal_input}\n{cleaned_code}"
        st.success("Code added to the editor below!")


def _run_user_code(code, earthquake_df):
    """Execute the editor's code and show its figure, printed output, or error.

    SECURITY: this exec()s arbitrary Python typed into the page with the app's
    full privileges (filesystem, environment variables such as HF_TOKEN). That
    is the purpose of the editor; only deploy it where that is acceptable.
    """
    st.markdown("---")
    st.subheader("Execution Results")

    # One dict serves as both globals and locals. With separate dicts (the old
    # exec(code, globals(), local_vars)), functions, lambdas and generator
    # expressions inside the snippet could not see earthquake_df or any name
    # the snippet assigned at top level. Copying globals() keeps st/pygmt/pd
    # available without letting the snippet rebind the app's module globals.
    namespace = {**globals(), "earthquake_df": earthquake_df}
    stdout_capture = io.StringIO()
    try:
        with contextlib.redirect_stdout(stdout_capture):
            exec(code, namespace)
        fig = namespace.get("fig")
        if isinstance(fig, pygmt.Figure):
            # Private directory per run: a fixed file name in the working
            # directory is shared by all concurrent sessions, which could
            # overwrite or delete each other's image before it is displayed.
            with tempfile.TemporaryDirectory() as tmp_dir:
                image_path = os.path.join(tmp_dir, "map.png")
                fig.savefig(image_path, dpi=150)
                st.image(image_path)
        else:
            st.success("Code executed, but no 'fig' object was found to display.")
    except Exception:
        st.error("Error while running the code:")
        # st.code shows the traceback verbatim; passing it to st.error renders
        # it as Markdown, which mangles names such as __init__.
        st.code(traceback.format_exc())

    printed = stdout_capture.getvalue()
    if printed:
        st.caption("Printed output")
        st.code(printed)


# --- Main App Logic ---
original_df = load_data()
if original_df is not None:
    filtered_df = _filter_earthquakes(original_df)

    # --- Display Filtered Data Table ---
    with st.expander(f"Show Filtered Data ({len(filtered_df)} of {len(original_df)} earthquakes)"):
        st.dataframe(filtered_df)

    # --- Learning Tips Section ---
    with st.expander("💡 Click here for PyGMT Learning Tips & Tricks"):
        st.subheader("📚 Use the Official Documentation First")
        st.markdown("- **[Gallery](https://www.pygmt.org/latest/gallery/index.html)**: Find a map you like and adapt the code.")

    # --- AI Code Generation Section (must come before the editor widget) ---
    _render_ai_section()

    # --- Live Code Editor ---
    st.subheader("PyGMT Code Editor")
    st.text_area(label="PyGMT Code", key="code_input", height=300)

    # --- Action Buttons ---
    col1, col2, _ = st.columns([2, 1, 3])
    run_button = col1.button("Generate Map", type="primary")
    col2.button("Reset Code", on_click=reset_code)
    if run_button:
        _run_user_code(st.session_state.code_input, filtered_df)