FatimahEmadEldin commited on
Commit
18e1844
·
verified ·
1 Parent(s): 5fd748d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +180 -134
app.py CHANGED
@@ -1,146 +1,192 @@
1
- # app.py
2
  import gradio as gr
3
- import pickle
4
  import numpy as np
5
- import lightkurve as lk
6
- import matplotlib.pyplot as plt
 
7
  import os
8
  import warnings
9
 
10
- # --- CONFIGURATION & GLOBAL SETUP ---
11
- DATA_DIR = "data"
12
- MODEL_PATH = "transit_classifier.pkl"
13
- EXAMPLE_STARS = ["pi Mensae", "WASP-126", "TOI-700", "LHS 3844"]
14
- KNOWN_PERIODS = {
15
- "pi Mensae": 6.27,
16
- "WASP-126": 3.84,
17
- "TOI-700": 16.4, # for planet d
18
- "LHS 3844": 0.46
19
- }
20
-
21
- warnings.filterwarnings("ignore", category=lk.LightkurveWarning)
22
- plt.style.use('seaborn-v0_8-whitegrid')
23
-
24
- # --- LOAD THE PRE-TRAINED ML MODEL ONCE AT STARTUP ---
25
  try:
26
- with open(MODEL_PATH, "rb") as f:
27
- model_data = pickle.load(f)
28
- MODEL = model_data["model"]
29
- WINDOW_SIZE = model_data["window_size"]
30
- print(" ML model loaded successfully from 'transit_classifier.pkl'.")
31
- except FileNotFoundError:
32
- print(f"❌ FATAL ERROR: Model file not found at '{MODEL_PATH}'.")
33
- MODEL = None
34
- WINDOW_SIZE = 64 # A default to prevent crashing
35
-
36
- # --- CORE FUNCTIONS FOR THE GRADIO APP ---
37
-
38
- def load_and_plot_star(target_star):
39
- """
40
- Loads a star's data, cleans it, and generates the initial plot.
41
- """
42
- print(f"Loading and plotting data for '{target_star}'...")
43
- try:
44
- safe_filename = target_star.lower().replace(" ", "_") + ".fits"
45
- file_path = os.path.join(DATA_DIR, safe_filename)
46
- lc_raw = lk.read(file_path)
47
-
48
- flux_values = lc_raw.flux.value
49
- median_flux = np.nanmedian(flux_values)
50
- normalized_flux_values = flux_values / median_flux
51
- lc_normalized = lk.LightCurve(time=lc_raw.time, flux=normalized_flux_values)
52
- lc_clean = lc_normalized.flatten(window_length=401).remove_outliers()
53
-
54
- fig, ax = plt.subplots(figsize=(12, 6))
55
- lc_clean.plot(ax=ax, color='dodgerblue', marker='.', markersize=2, linestyle='none')
56
- ax.set_title(f"Light Curve for {target_star} - Ready for Analysis", fontsize=14)
57
- ax.set_ylabel("Normalized & Flattened Flux")
58
- ax.set_xlabel("Time [BTJD]")
59
- plt.tight_layout()
60
-
61
- # Return the plot, the cleaned light curve object, and a reset to the feedback panel
62
- return fig, lc_clean, "Plot loaded. **Click and drag on the plot** to select a potential transit."
63
- except Exception as e:
64
- print(f"[ERROR] Could not process {target_star}: {e}")
65
- fig, ax = plt.subplots(); ax.text(0.5, 0.5, f"Could not generate plot.\nError: {e}", ha='center')
66
- return fig, None, f"Error loading data for {target_star}."
67
-
68
- def check_user_selection(lc_object, star_name, selection_event: gr.SelectData):
69
- """
70
- Analyzes the region selected by the user on the plot via click-and-drag.
71
- This function is triggered by the 'select' event.
72
- """
73
- if lc_object is None or MODEL is None:
74
- return "Please select a star first. Model or data not loaded."
75
-
76
- # The selection_event contains the x-axis range of the user's drag
77
- start_time, end_time = selection_event.index
78
-
79
- # Find the data points within the user's selection
80
- selection_mask = (lc_object.time.value >= start_time) & (lc_object.time.value <= end_time)
81
- selected_flux = lc_object.flux.value[selection_mask]
82
-
83
- if len(selected_flux) < WINDOW_SIZE:
84
- return "### Selection Too Small\nYour selection is too small for the AI to analyze. Please select a wider region."
85
-
86
- # --- 1. Get the AI's Prediction ---
87
- windows = [selected_flux[i:i+WINDOW_SIZE] for i in range(len(selected_flux) - WINDOW_SIZE)]
88
- ai_predictions = MODEL.predict(windows)
89
- ai_found_transit = np.any(ai_predictions == 1)
90
-
91
- # --- 2. Get the Ground Truth ---
92
- period = KNOWN_PERIODS.get(star_name, None)
93
- if period:
94
- phase = lc_object.fold(period).phase.value[selection_mask]
95
- ground_truth_is_transit = np.any((phase > -0.05) & (phase < 0.05))
96
- else:
97
- ground_truth_is_transit = None
98
-
99
- # --- 3. Generate Feedback ---
100
- feedback_md = "### Analysis Results\n\n"
101
- if ground_truth_is_transit is not None:
102
- feedback_md += "✅ **Correct!** You have successfully identified a known transit region.\n" if ground_truth_is_transit else "❌ **Keep Looking!** This region does not contain a known transit.\n"
103
 
104
- feedback_md += "🤖 **The AI agrees with you** and also predicts a transit here.\n" if ai_found_transit else "🤖 **The AI disagrees** and does not predict a transit here.\n"
105
- return feedback_md
106
-
107
- # --- BUILD THE GRADIO INTERFACE using the corrected event listener syntax ---
108
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
109
- gr.Markdown("# TESS Transit Hunter 🔭\nSelect a star, then **click and drag** on the plot's x-axis to analyze for exoplanet transits.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
- # Hidden component to store the lightkurve object between interactions
112
- lc_state = gr.State(None)
113
-
114
- with gr.Row():
115
- # Left column for controls and feedback
116
- with gr.Column(scale=1):
117
- star_selector = gr.Dropdown(
118
- choices=EXAMPLE_STARS,
119
- value="pi Mensae",
120
- label="1. Select a Target Star"
121
- )
122
- feedback_panel = gr.Markdown("### 2. Feedback Panel\n\nWelcome! Select a star to begin. The plot will appear on the right.")
123
 
124
- # Right column for the interactive plot
125
- with gr.Column(scale=3):
126
- lc_plot = gr.Plot(label="TESS Light Curve")
127
-
128
- # --- Define the interactive behaviors ---
129
- # When the dropdown value changes, run load_and_plot_star.
130
- # It updates the plot, the hidden state, and the feedback panel.
131
- star_selector.change(
132
- fn=load_and_plot_star,
133
- inputs=star_selector,
134
- outputs=[lc_plot, lc_state, feedback_panel]
135
- )
 
 
 
 
136
 
137
- # THIS IS THE KEY FIX:
138
- # When the user makes a selection on the plot, trigger the check_user_selection function.
139
- lc_plot.select(
140
- fn=check_user_selection,
141
- inputs=[lc_state, star_selector],
142
- outputs=feedback_panel
143
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
  if __name__ == "__main__":
146
- demo.launch()
 
 
1
  import gradio as gr
2
+ import pandas as pd
3
  import numpy as np
4
+ import plotly.express as px
5
+ import plotly.graph_objects as go
6
+ import joblib
7
  import os
8
  import warnings
9
 
10
+ warnings.filterwarnings('ignore')
11
+
12
+ # --- 1. ROBUST FILE LOADING ---
 
 
 
 
 
 
 
 
 
 
 
 
13
  try:
14
+ def find_file(filename, search_paths=['./', './data/']):
15
+ for path in search_paths:
16
+ filepath = os.path.join(path, filename)
17
+ if os.path.exists(filepath):
18
+ print(f"Found '{filename}' at: {filepath}")
19
+ return filepath
20
+ return None
21
+
22
+ scaler_path = find_file('scaler.joblib')
23
+ kmeans_path = find_file('kmeans_model.joblib')
24
+ forecasting_path = find_file('forecasting_models.joblib')
25
+ data_path = find_file('consolidated_farm_data.csv')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ if not all([scaler_path, kmeans_path, forecasting_path, data_path]):
28
+ raise FileNotFoundError("Could not find all required model (.joblib) and data (.csv) files.")
29
+
30
+ scaler = joblib.load(scaler_path)
31
+ kmeans_model = joblib.load(kmeans_path)
32
+ forecasting_models = joblib.load(forecasting_path)
33
+ df_historical = pd.read_csv(data_path)
34
+ df_historical['timestamp'] = pd.to_datetime(df_historical['timestamp'])
35
+
36
+ ALL_FARMS = sorted(df_historical['farm_name'].unique())
37
+ FARM_COORDINATES = {
38
+ 'alia': [24.434117, 39.624376], 'Abdula altazi': [24.499210, 39.661664],
39
+ 'albadr': [24.499454, 39.666633], 'alhabibah': [24.499002, 39.667079],
40
+ 'alia almadinah': [24.450111, 39.627500], 'almarbad': [24.442014, 39.628323],
41
+ 'alosba': [24.431591, 39.605149], 'abuonoq': [24.494620, 39.623123],
42
+ 'wahaa nakeel': [24.442692, 39.623028], 'wahaa 2': [24.442388, 39.621116]
43
+ }
44
+ farm_coords_df = pd.DataFrame.from_dict(FARM_COORDINATES, orient='index', columns=['lat', 'lon']).reset_index().rename(columns={'index':'farm_name'})
45
+
46
+ except FileNotFoundError as e:
47
+ raise FileNotFoundError(f"CRITICAL ERROR: {e}")
48
+
49
+
50
+ # --- 2. DEFINE CORE FUNCTIONS ---
51
+ def get_performance_report():
52
+ kpi_df = df_historical.groupby('farm_name').agg(
53
+ mean_ndvi=('NDVI', 'mean'), mean_evi=('EVI', 'mean'), std_ndvi=('NDVI', 'std')
54
+ ).reset_index().dropna()
55
+ features = kpi_df[['mean_ndvi', 'mean_evi', 'std_ndvi']]
56
+ scaled_features = scaler.transform(features)
57
+ kpi_df['cluster'] = kmeans_model.predict(scaled_features)
58
+ cluster_centers = pd.DataFrame(scaler.inverse_transform(kmeans_model.cluster_centers_), columns=['mean_ndvi', 'mean_evi', 'std_ndvi'])
59
+ sorted_clusters = cluster_centers.sort_values(by='mean_ndvi', ascending=False).index
60
+ tier_map = {sorted_clusters[0]: 'Tier 1 (High)', sorted_clusters[1]: 'Tier 2 (Medium)', sorted_clusters[2]: 'Tier 3 (Low)'}
61
+ kpi_df['Performance Tier'] = kpi_df['cluster'].map(tier_map)
62
+ return kpi_df[['farm_name', 'Performance Tier', 'mean_ndvi', 'mean_evi']].sort_values('Performance Tier')
63
+
64
+ def detect_and_classify_anomalies(farm_name):
65
+ farm_data = df_historical[df_historical['farm_name'] == farm_name].set_index('timestamp').sort_index()
66
+ df_resampled = farm_data[['NDVI', 'NDWI', 'SAR_VV']].resample('W').mean().interpolate(method='linear')
67
+ df_change = df_resampled.diff().dropna()
68
+ rolling_std = df_change.rolling(window=12, min_periods=4).std()
69
+ thresholds = {'NDVI': rolling_std['NDVI'] * 1.5, 'NDWI': rolling_std['NDWI'] * 1.5, 'SAR_VV': rolling_std['SAR_VV'] * 1.5}
70
+ anomalies_found = []
71
+ for date, row in df_change.iterrows():
72
+ ndvi_change, ndwi_change, sar_vv_change = row['NDVI'], row['NDWI'], row['SAR_VV']
73
+ ndvi_thresh, ndwi_thresh, sar_thresh = thresholds['NDVI'].get(date, 0.07), thresholds['NDWI'].get(date, 0.07), thresholds['SAR_VV'].get(date, 1.0)
74
+ classification = "Normal"
75
+ if ndvi_change < -ndvi_thresh and sar_vv_change < -sar_thresh:
76
+ classification = 'Harvest Event'
77
+ elif ndvi_change < -ndvi_thresh and ndwi_change < -ndwi_thresh:
78
+ classification = 'Potential Drought Stress'
79
+ elif ndvi_change < -ndvi_thresh:
80
+ classification = 'General Stress Event'
81
+ if classification != "Normal":
82
+ anomalies_found.append({'Date': date, 'Classification': classification, 'NDVI Change': f"{ndvi_change:.3f}"})
83
+
84
+ fig = go.Figure()
85
+ fig.add_trace(go.Scatter(x=farm_data.index, y=farm_data['NDVI'], mode='lines', name='NDVI', line=dict(color='green')))
86
+ colors = {'Harvest Event': 'red', 'Potential Drought Stress': 'orange', 'General Stress Event': 'purple'}
87
 
88
+ # FINAL FIX: Manually add shapes and annotations instead of using fig.add_vline()
89
+ for anomaly in anomalies_found:
90
+ anomaly_date = anomaly['Date'].to_pydatetime()
91
+ line_color = colors.get(anomaly['Classification'])
 
 
 
 
 
 
 
 
92
 
93
+ # Add the vertical line shape
94
+ fig.add_shape(
95
+ type='line',
96
+ x0=anomaly_date, y0=0, x1=anomaly_date, y1=1,
97
+ yref='paper', # This makes the line span the full height of the plot
98
+ line=dict(color=line_color, width=2, dash='dash')
99
+ )
100
+
101
+ # Add the annotation text
102
+ fig.add_annotation(
103
+ x=anomaly_date, y=1.0, yref='paper', # Position text at the top
104
+ text=anomaly['Classification'],
105
+ showarrow=False,
106
+ yshift=10, # Shift text slightly above the top line
107
+ font=dict(color=line_color)
108
+ )
109
 
110
+ fig.update_layout(title=f'NDVI Timeline & Detected Anomalies for {farm_name}', xaxis_title='Date', yaxis_title='NDVI')
111
+
112
+ display_anomalies = [{'Date': a['Date'].strftime('%Y-%m-%d'), 'Classification': a['Classification'], 'NDVI Change': a['NDVI Change']} for a in anomalies_found]
113
+ return pd.DataFrame(display_anomalies), fig
114
+
115
+ def run_forecast(farm_name):
116
+ model = forecasting_models.get(farm_name)
117
+ last_date = df_historical['timestamp'].max()
118
+ future_dates = pd.to_datetime(pd.date_range(start=last_date, periods=12, freq='W'))
119
+ future_df = pd.DataFrame(index=future_dates)
120
+ future_df['day_of_year'] = future_df.index.dayofyear
121
+ farm_data = df_historical[df_historical['farm_name'] == farm_name]
122
+ future_df['EVI'] = farm_data['EVI'].iloc[-1]
123
+ future_df['NDWI'] = farm_data['NDWI'].iloc[-1]
124
+ predictions = model.predict(future_df[['day_of_year', 'EVI', 'NDWI']])
125
+
126
+ fig = go.Figure()
127
+ fig.add_trace(go.Scatter(x=farm_data['timestamp'], y=farm_data['NDVI'], mode='lines', name='Historical NDVI'))
128
+ fig.add_trace(go.Scatter(x=future_dates, y=predictions, mode='lines', name='Forecasted NDVI', line=dict(color='red', dash='dash')))
129
+ fig.update_layout(title=f'3-Month NDVI Forecast for {farm_name}')
130
+ return fig, pd.DataFrame({'Forecast Date': future_dates.strftime('%Y-%m-%d'), 'Predicted NDVI': np.round(predictions, 3)})
131
+
132
+ def plot_tier_distribution(report_df):
133
+ tier_counts = report_df['Performance Tier'].value_counts().reset_index()
134
+ tier_counts.columns = ['Performance Tier', 'Count']
135
+ fig = px.bar(tier_counts, x='Performance Tier', y='Count', title='Farm Distribution by Performance Tier',
136
+ color='Performance Tier', text_auto=True,
137
+ color_discrete_map={'Tier 1 (High)': 'green', 'Tier 2 (Medium)': 'orange', 'Tier 3 (Low)': 'red'})
138
+ fig.update_layout(showlegend=False)
139
+ return fig
140
+
141
+ # --- 3. BUILD GRADIO INTERFACE ---
142
+ df_performance_report = get_performance_report()
143
+
144
+ with gr.Blocks(theme=gr.themes.Soft(), title="Palm Farm Intelligence") as demo:
145
+ gr.Markdown("# Palm Farm Intelligence Platform")
146
+
147
+ with gr.Tabs():
148
+ with gr.TabItem("Performance Overview"):
149
+ with gr.Row():
150
+ with gr.Column(scale=1):
151
+ gr.Markdown("### All Farms Performance Tiers")
152
+ gr.DataFrame(df_performance_report)
153
+ gr.Markdown("### Tier Distribution")
154
+ tier_plot = gr.Plot()
155
+ with gr.Column(scale=2):
156
+ gr.Markdown("### Farm Locations")
157
+ map_plot = gr.Plot()
158
+
159
+ with gr.TabItem(" Anomaly Detection"):
160
+ gr.Markdown("### Intelligent Anomaly Detection")
161
+ anomaly_farm_selector = gr.Dropdown(ALL_FARMS, label="Select a Farm", value=ALL_FARMS[0])
162
+ with gr.Row():
163
+ anomaly_table = gr.DataFrame(headers=["Date", "Classification", "NDVI Change"])
164
+ anomaly_plot = gr.Plot()
165
+
166
+ with gr.TabItem(" NDVI Forecasting"):
167
+ gr.Markdown("### 3-Month Vegetation Health Forecast")
168
+ forecast_farm_selector = gr.Dropdown(ALL_FARMS, label="Select Farm to Forecast", value=ALL_FARMS[0])
169
+ forecast_plot = gr.Plot()
170
+ forecast_data = gr.DataFrame()
171
+
172
+ def update_anomaly_view(farm_name):
173
+ return detect_and_classify_anomalies(farm_name)
174
+ anomaly_farm_selector.change(fn=update_anomaly_view, inputs=anomaly_farm_selector, outputs=[anomaly_table, anomaly_plot])
175
+
176
+ def update_forecast_view(farm_name):
177
+ return run_forecast(farm_name)
178
+ forecast_farm_selector.change(fn=update_forecast_view, inputs=forecast_farm_selector, outputs=[forecast_plot, forecast_data])
179
+
180
+ def initial_load():
181
+ fig_map = px.scatter_mapbox(farm_coords_df, lat="lat", lon="lon", hover_name="farm_name",
182
+ color_discrete_sequence=["green"], zoom=8, height=500)
183
+ fig_map.update_layout(mapbox_style="open-street-map", margin={"r":0,"t":0,"l":0,"b":0})
184
+ fig_tier = plot_tier_distribution(df_performance_report)
185
+ an_table, an_plot = detect_and_classify_anomalies(ALL_FARMS[0])
186
+ fc_plot, fc_data = run_forecast(ALL_FARMS[0])
187
+ return fig_map, fig_tier, an_table, an_plot, fc_plot, fc_data
188
+
189
+ demo.load(fn=initial_load, outputs=[map_plot, tier_plot, anomaly_table, anomaly_plot, forecast_plot, forecast_data])
190
 
191
  if __name__ == "__main__":
192
+ demo.launch(debug=True)