Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import pandas as pd | |
| import plotly.graph_objects as go | |
| import numpy as np | |
| import torch | |
| from typing import Dict, List, Tuple | |
| import re | |
| from typing import Callable, Union, Dict | |
class TimeSeriesEditor:
    """Interactive editor that conditions a time-series generator on user constraints.

    UI inputs (anchor points, point groups, trending functions, sum/peak targets)
    are converted into observation tensors plus a gradient-control dict, passed
    to ``trainer.predict_weighted_points``, and rendered as one Plotly figure per
    feature. The latest model output is cached so scaling, range-shift and
    frequency-band edits can be re-rendered without re-running the model.
    """

    def __init__(self, seq_length: int, feature_dim: int, trainer):
        """
        Args:
            seq_length: Number of time steps in a generated sample.
            feature_dim: Number of features in a generated sample. Display
                metadata (names, scales, units) below only covers indices 0-2.
            trainer: Object exposing ``predict_weighted_points``.
        """
        self.seq_length = seq_length
        self.feature_dim = feature_dim
        self.trainer = trainer
        # Sampling hyper-parameters forwarded unchanged to the trainer.
        self.coef = None
        self.stepsize = None
        self.sampling_steps = None
        self.feature_names = ["revenue", "download", "daily active user"]
        # Cache of the latest model run, reused by the re-render callbacks.
        self.latest_sample = None
        self.latest_observed_points = None
        self.latest_observed_mask = None
        self.latest_gradient_control_signal = None
        self.latest_model_control_signal = None
        # Real-world units per 0.1 of normalized model output.
        self.feature_scales = {
            0: 1000000,  # Revenue: $1M per 0.1
            1: 100000,   # Download: 100K downloads per 0.1
            2: 10000,    # AU: 10K active users per 0.1
        }
        self.feature_units = {
            0: "$",          # Revenue
            1: "downloads",  # Download
            2: "users",      # AU
        }
        self.show_normalized = True
        # Multipliers for 5 frequency bands: index 0 = highest band,
        # index 4 = lowest band (also applied to the DC component).
        self.freq_bands = np.ones(5)
        self.function_parser = FunctionParser()
        # Parsed (start, end, feature, func, confidence) trending constraints.
        self.trending_controls = []
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _scale_factor(self, feat_idx: int) -> float:
        """Return the multiplier from normalized model units to displayed units."""
        # feature_scales is expressed per 0.1 of model output, hence the * 10.
        return 1 if self.show_normalized else self.feature_scales[feat_idx] * 10

    def _compute_metrics(self, sample: np.ndarray, **extra) -> Dict:
        """Build the metrics dict: display mode, ``extra`` entries, then per-feature totals."""
        metrics = {'show_normalized': self.show_normalized, **extra}
        for feat_idx in range(self.feature_dim):
            total = np.sum(sample[:, feat_idx]) * self._scale_factor(feat_idx)
            metrics[f'total_{self.feature_names[feat_idx]}'] = self.format_value(total, feat_idx)
        return metrics

    def format_value(self, value: float, feature_idx: int) -> str:
        """Format value with appropriate units and notation"""
        if self.show_normalized:
            return f"{value:.4f}"
        if feature_idx == 0:  # Revenue
            return f"{self.feature_units[feature_idx]}{value:,.2f}"
        # Downloads and AU
        return f"{value:,.0f} {self.feature_units[feature_idx]}"

    def create_plot(self, sample: np.ndarray, observed_points: torch.Tensor,
                    observed_mask: torch.Tensor,
                    gradient_control_signal: Dict, metrics: Dict) -> List[go.Figure]:
        """Build one Plotly figure per feature.

        Args:
            sample: Generated series indexed as ``sample[time, feature]`` in
                normalized units.
            observed_points: CPU tensor of anchor values indexed ``[time, feature]``.
            observed_mask: CPU tensor of anchor weights with the same layout; a
                cell whose weight is > 0 is drawn as an observed point.
            gradient_control_signal: Control dict; its optional ``"peak_points"``
                entry is drawn as dashed vertical lines. ``None`` is accepted.
            metrics: Unused; kept for interface compatibility.

        Returns:
            ``feature_dim`` figures, in feature order.
        """
        figures = []
        time_axis = np.arange(self.seq_length)
        peak_points = (gradient_control_signal or {}).get("peak_points") or []
        for feat_idx in range(self.feature_dim):
            feature_name = self.feature_names[feat_idx]
            scale_factor = self._scale_factor(feat_idx)
            fig = go.Figure()

            # Predicted line
            predicted_values = sample[:, feat_idx] * scale_factor
            fig.add_trace(go.Scatter(
                x=time_axis,
                y=predicted_values,
                mode='lines',
                name='Predicted',
                line=dict(color='green', width=2),
                showlegend=True
            ))

            # Observed points. Select on weight rather than value so that an
            # anchor whose value is exactly 0 is still drawn.
            feat_weights = observed_mask[:, feat_idx].numpy()
            anchored = feat_weights > 0
            ox = time_axis[anchored]
            oy = observed_points[:, feat_idx].numpy()[anchored] * scale_factor
            # Error-bar half-height in normalized units is (1 - weight) / 5:
            # weight 1.0 -> 0, weight 0.1 -> 0.18.
            error_y = (1 - feat_weights[anchored]) / 5
            fig.add_trace(go.Scatter(
                x=ox,
                y=oy,
                mode='markers',
                name='Observed',
                marker=dict(
                    color='rgba(255, 0, 0, 0.5)',
                    symbol='x',
                ),
                error_y=dict(
                    type='data',
                    array=error_y * scale_factor,
                    visible=True,
                    thickness=0.5,
                    width=2,
                    color='blue'
                ),
                showlegend=True
            ))

            # Fixed-width band around the prediction. Fill is disabled, so the
            # trace is transparent and only contributes a legend entry.
            uncertainty = 0.05  # half-width in normalized units
            upper_bound = predicted_values + uncertainty * scale_factor
            lower_bound = predicted_values - uncertainty * scale_factor
            fig.add_trace(go.Scatter(
                x=np.concatenate([time_axis, time_axis[::-1]]),
                y=np.concatenate([upper_bound, lower_bound[::-1]]),
                line=dict(color='rgba(255,255,255,0)'),
                name='Prediction Interval',
                showlegend=True
            ))

            for peak_point in peak_points:
                fig.add_vline(x=peak_point, line_dash="dash", line_color="red")

            total_value = np.sum(sample[:, feat_idx]) * scale_factor
            annotations = [dict(
                x=0.02,
                y=1.1,
                xref="paper",
                yref="paper",
                text=f"Total {feature_name}: {self.format_value(total_value, feat_idx)}",
                showarrow=False
            )]

            if self.show_normalized:
                y_title = f'{feature_name} (Normalized)'
            else:
                y_title = f'{feature_name} ({self.feature_units[feat_idx]})'

            fig.update_layout(
                title=dict(
                    text=f'Feature: {feature_name}',
                    x=0.5,
                    y=0.95
                ),
                xaxis_title='Time',
                yaxis_title=y_title,
                height=400,
                showlegend=True,
                dragmode='select',
                annotations=annotations,
                margin=dict(r=200)  # right margin for the legend
            )
            figures.append(fig)
        return figures

    def update_scaling(self,
                       revenue_scale: float,
                       download_scale: float,
                       au_scale: float,
                       show_normalized: bool) -> Tuple[List[go.Figure], Dict]:
        """Update the scaling parameters and redraw plots.

        Returns ``([], {})`` when the model has not produced a sample yet.
        """
        if self.latest_sample is None:
            return [], {}
        self.feature_scales = {
            0: revenue_scale,
            1: download_scale,
            2: au_scale
        }
        self.show_normalized = show_normalized
        metrics = self._compute_metrics(self.latest_sample)
        figures = self.create_plot(
            self.latest_sample,
            self.latest_observed_points,
            self.latest_observed_mask,
            self.latest_gradient_control_signal,
            metrics
        )
        return figures, metrics

    def parse_data_points(self, df) -> Dict:
        """Parse data points from a DataFrame with columns: time, feature, value.

        Rows with missing or non-numeric cells, or with a time/feature index
        outside the series, are skipped.

        Returns:
            ``{time_idx: {feature_idx: (value, 1.0)}}``.
        """
        data_dict = {}
        if df is None or df.empty:
            return data_dict
        for _, row in df.iterrows():
            if pd.isna(row['time']) or pd.isna(row['feature']) or pd.isna(row['value']):
                continue
            try:
                time_idx = int(row['time'])
                feature_idx = int(row['feature'])
                value = float(row['value'])
            except (ValueError, TypeError, OverflowError):
                continue
            # Out-of-range indices would crash to_tensor; negative ones would
            # silently wrap around to the end of the series.
            if not (0 <= time_idx < self.seq_length and 0 <= feature_idx < self.feature_dim):
                continue
            data_dict.setdefault(time_idx, {})[feature_idx] = (value, 1.0)
        return data_dict

    def parse_point_groups(self, df) -> Dict:
        """Parse point groups from a DataFrame with columns: start, end, interval, feature, value, weight.

        Each row anchors ``value`` at times ``start, start+interval, ...`` up to
        and including ``end``; times outside the series are dropped. A missing
        weight defaults to 1.0. Rows with invalid cells, a non-positive interval
        or an out-of-range feature are skipped.

        Returns:
            ``{time_idx: {feature_idx: (value, weight)}}``.
        """
        data_dict = {}
        if df is None or df.empty:
            return data_dict
        required = ('start', 'end', 'interval', 'feature', 'value')
        for _, row in df.iterrows():
            if any(pd.isna(row[col]) for col in required):
                continue
            try:
                start = int(row['start'])
                end = int(row['end'])
                interval = int(row['interval'])
                feature = int(row['feature'])
                value = float(row['value'])
                raw_weight = row.get('weight')
                weight = 1.0 if pd.isna(raw_weight) else float(raw_weight)
            except (ValueError, TypeError, OverflowError):
                continue
            if interval <= 0 or not 0 <= feature < self.feature_dim:
                continue
            for t in range(start, end + 1, interval):
                if 0 <= t < self.seq_length:
                    data_dict.setdefault(t, {})[feature] = (value, weight)
        return data_dict

    def to_tensor(self, observed_points_dict, seq_length, feature_dim):
        """Densify ``{t: {f: (value, weight)}}`` into two zero-filled (seq_length, feature_dim) tensors."""
        observed_points = torch.zeros((seq_length, feature_dim))
        observed_weights = torch.zeros((seq_length, feature_dim))
        for seq, feature_dict in observed_points_dict.items():
            for feature, (value, weight) in feature_dict.items():
                observed_points[seq, feature] = value
                observed_weights[seq, feature] = weight
        return observed_points, observed_weights

    def apply_direct_edits(self, sample: np.ndarray, edit_params: Dict) -> np.ndarray:
        """Apply range-shift edits to a copy of ``sample`` and clip the result to [0, 1].

        Areas with an out-of-range feature are ignored; start/end are clamped to
        the series so negative indices do not wrap around.
        """
        edited_sample = sample.copy()
        if edit_params.get("enable_direct_area"):
            n_steps, n_features = edited_sample.shape[0], edited_sample.shape[1]
            for start, end, feat_idx, delta in self.parse_area_selections(edit_params["direct_areas"]):
                if not 0 <= feat_idx < n_features:
                    continue
                edited_sample[max(start, 0):min(end, n_steps), feat_idx] += delta
            edited_sample = np.clip(edited_sample, 0, 1)
        return edited_sample

    def parse_area_selections(self, area_text: str) -> List[Tuple]:
        """Parse area selection text into (start, end, feature, target) tuples.

        Entries are separated by ``;`` or newlines; malformed entries are skipped.
        """
        areas = []
        if not area_text or not area_text.strip():
            return areas
        for line in area_text.replace('\n', ';').strip().split(';'):
            if not line.strip():
                continue
            try:
                start, end, feat, target = map(float, line.strip().split(','))
            except (ValueError, IndexError):
                continue
            areas.append((int(start), int(end), int(feat), target))
        return areas

    def apply_trending_mask(self, points: torch.Tensor, mask: torch.Tensor,
                            consider_last_generated=False) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply trending functions as soft constraints through the observation mask.

        For each stored trend, the function is evaluated on x in [0, 1] across
        ``[start, end)``, min-max normalized to [0, 1] (a flat curve is kept at
        its level clipped to [0, 1]), and written into cells that have no
        explicit observation (mask == 0) with weight ``confidence``.

        Args:
            points: (seq_length, feature_dim) observation values; modified in place.
            mask: (seq_length, feature_dim) observation weights.
            consider_last_generated: Currently unused.

        Returns:
            ``(points, mask)``. Nothing is applied until a sample has been
            generated at least once.
        """
        if not self.trending_controls or self.latest_sample is None:
            return points, mask
        for start, end, feat_idx, func, confidence in self.trending_controls:
            if start < 0 or end > self.seq_length or start >= end:
                continue
            if not 0 <= feat_idx < self.feature_dim:
                continue
            x = np.linspace(0, 1, end - start)
            try:
                # Broadcast so a constant expression (scalar result) still
                # yields one value per time step.
                y = np.broadcast_to(np.asarray(func(x), dtype=float), x.shape).copy()
            except Exception as e:
                print(f"Error applying function: {e}")
                continue
            if not np.all(np.isfinite(y)):
                print("Error applying function: non-finite values in trend segment")
                continue
            span = np.max(y) - np.min(y)
            if span > 0:
                y = (y - np.min(y)) / span
            else:
                # Min-max normalization of a flat curve would divide by zero.
                y = np.clip(y, 0, 1)
            unset = mask[start:end, feat_idx] == 0
            points[start:end, feat_idx][unset] = torch.as_tensor(y, dtype=points.dtype)[unset]
            mask[start:end, feat_idx][unset] = confidence
            mask = mask.clamp(0, 1)
        return points, mask

    def update_model(self,
                     figures: List[go.Figure],
                     data_points: str,
                     point_groups: str,
                     enable_area_control: bool,
                     area_selections: str,
                     enable_auc: bool,
                     auc_value: float,
                     enable_peaks: bool,
                     peak_points: str,
                     peak_alpha: float,
                     auc_weight: float,
                     peak_weight: float,
                     enable_trending: bool = True,
                     enable_trending_with_diff: bool = False,
                     trending_params: str = ""
                     ) -> Tuple[List[go.Figure], str, str, Dict]:
        """Run the model on the current constraints and return fresh figures.

        ``figures``, ``enable_area_control`` and ``area_selections`` are accepted
        for interface compatibility but not used.

        Returns:
            ``(figures, data_points, point_groups, metrics)``.
        """
        # Individual points take precedence over point groups at the same cell.
        combined_points_dict = self.parse_point_groups(point_groups)
        for t, feat_dict in self.parse_data_points(data_points).items():
            combined_points_dict.setdefault(t, {}).update(feat_dict)
        observed_points, observed_mask = self.to_tensor(
            combined_points_dict,
            self.seq_length,
            self.feature_dim
        )

        peak_points_list = []
        if enable_peaks and peak_points:
            try:
                peak_points_list = [int(x.strip()) for x in peak_points.split(',') if x.strip()]
            except ValueError:
                peak_points_list = []

        if enable_trending and trending_params:
            self.parse_trending_parameters(trending_params)
            observed_points, observed_mask = self.apply_trending_mask(
                observed_points, observed_mask,
                consider_last_generated=enable_trending_with_diff
            )

        # Gradient control signal forwarded to the trainer.
        gradient_control_signal = {}
        if enable_auc:
            gradient_control_signal["auc"] = auc_value
            gradient_control_signal["auc_weight"] = auc_weight
        if enable_peaks:
            gradient_control_signal.update({
                "peak_points": peak_points_list,
                "peak_alpha": peak_alpha,
                "peak_weight": peak_weight
            })

        # Area control is not forwarded to the model; cached for completeness.
        model_control_signal = {}

        with torch.no_grad():
            observed_points = observed_points.to(self.device)
            observed_mask = observed_mask.to(self.device)
            sample = self.trainer.predict_weighted_points(
                observed_points,  # (seq_length, feature_dim)
                observed_mask,    # (seq_length, feature_dim)
                self.coef,
                self.stepsize,
                self.sampling_steps,
                gradient_control_signal=gradient_control_signal
            )
            # Plotting converts these to numpy, which requires CPU tensors.
            observed_points = observed_points.cpu()
            observed_mask = observed_mask.cpu()

        self.latest_sample = sample
        self.latest_observed_points = observed_points
        self.latest_observed_mask = observed_mask
        self.latest_gradient_control_signal = gradient_control_signal
        self.latest_model_control_signal = model_control_signal

        metrics = self._compute_metrics(sample)
        figures = self.create_plot(sample, observed_points, observed_mask, gradient_control_signal, metrics)
        return figures, data_points, point_groups, metrics

    def update_additional_edit(
            self,
            enable_direct_area: bool,
            direct_areas: str):
        """Re-render the cached sample with optional range-shift edits.

        Returns ``([], {})`` when the model has not produced a sample yet.
        """
        if self.latest_sample is None:
            return [], {}
        sample = self.latest_sample
        if enable_direct_area:
            sample = self.apply_direct_edits(sample, {
                "enable_direct_area": enable_direct_area,
                "direct_areas": direct_areas
            })
        metrics = self._compute_metrics(sample)
        figures = self.create_plot(
            sample,
            self.latest_observed_points,
            self.latest_observed_mask,
            self.latest_gradient_control_signal,
            metrics
        )
        return figures, metrics

    def apply_frequency_filter(self, signal: np.ndarray) -> np.ndarray:
        """Scale the spectrum of a real 1-D signal band by band and return the result.

        The DC component is scaled by ``freq_bands[4]``. The remaining real-FFT
        bins (1 .. n//2, lowest to highest frequency) are split into 5
        contiguous bands; band ``i`` is scaled by ``freq_bands[4 - i]``.
        Working on the half spectrum keeps positive and negative frequencies
        consistent, so the output is exactly real and has the input's length.
        """
        signal = np.asarray(signal, dtype=float)
        n = signal.shape[0]
        if n == 0:
            return signal.copy()
        spectrum = np.fft.rfft(signal)
        spectrum[0] *= self.freq_bands[4]
        for i, band in enumerate(np.array_split(np.arange(1, spectrum.shape[0]), 5)):
            # Short signals yield empty bands, which are simply no-ops here.
            spectrum[band] *= self.freq_bands[4 - i]
        return np.fft.irfft(spectrum, n=n)

    def update_frequency_bands(self, band_idx: int, value: float) -> Tuple[List[go.Figure], Dict]:
        """Update a frequency band multiplier and recompute the filtered signal.

        Returns ``([], {})`` when the model has not produced a sample yet.
        """
        if self.latest_sample is None:
            return [], {}
        # UI number inputs may arrive as floats; numpy rejects float indices.
        self.freq_bands[int(band_idx)] = value
        filtered_sample = self.latest_sample.copy()
        for feat_idx in range(self.feature_dim):
            filtered_sample[:, feat_idx] = self.apply_frequency_filter(
                self.latest_sample[:, feat_idx]
            )
        filtered_sample = np.clip(filtered_sample, 0, 1)
        metrics = self._compute_metrics(filtered_sample, frequency_bands=self.freq_bands.tolist())
        figures = self.create_plot(
            filtered_sample,
            self.latest_observed_points,
            self.latest_observed_mask,
            self.latest_gradient_control_signal,
            metrics
        )
        return figures, metrics

    def parse_trending_parameters(self, trending_text: str) -> List[Tuple]:
        """Parse trending controls into (start, end, feature, function, confidence) tuples.

        Entries look like ``start,end,feature,function,confidence`` and are
        separated by ``;`` or newlines. The function may itself contain commas
        (e.g. ``pow(x,2)``). Malformed entries are skipped. The result replaces
        ``self.trending_controls``, so empty text clears any previous trends.
        """
        trending_params = []
        for entry in (trending_text or "").replace('\n', ';').split(';'):
            entry = entry.strip()
            if not entry:
                continue
            try:
                # Split the confidence off the right first so commas inside the
                # function expression are preserved.
                head, confidence_str = entry.rsplit(',', 1)
                start_str, end_str, feat_str, function_str = head.split(',', 3)
                start, end, feat = int(start_str), int(end_str), int(feat_str)
                confidence = float(confidence_str)
            except ValueError:
                continue
            function_str = function_str.strip()
            try:
                func = self.function_parser.string_to_function(function_str)
            except ValueError as e:
                print(f"Error parsing function '{function_str}': {e}")
                continue
            trending_params.append((start, end, feat, func, confidence))
        self.trending_controls = trending_params
        return trending_params
def create_gradio_interface(editor: TimeSeriesEditor):
    """Build the Gradio Blocks UI wired to ``editor``.

    Layout: a control column (scaling, anchor points, point groups, trending,
    statistics control, plus hidden area/peak/range-shift/frequency controls)
    beside a plot column with one ``gr.Plot`` per feature. The model runs once
    on page load with the default inputs.

    Returns:
        The ``gr.Blocks`` app (not launched).
    """
    with gr.Blocks() as app:
        gr.Markdown("# Time Series Editor")
        gr.Markdown("## Instruction: Scroll Down + Click [Update Figure] [~10s]")
        metrics_display = gr.JSON(label="Metrics", value={})
        with gr.Row():
            with gr.Column(scale=1):
                # Scaling Parameters Section
                gr.Markdown("## Scaling Parameters")
                with gr.Accordion("Open for More Detail", open=False):
                    revenue_scale = gr.Number(
                        label="Revenue Scale ($ per 0.1 in model)",
                        value=1000000
                    )
                    download_scale = gr.Number(
                        label="Download Scale (downloads per 0.1 in model)",
                        value=100000
                    )
                    au_scale = gr.Number(
                        label="Active Users Scale (users per 0.1 in model)",
                        value=10000
                    )
                    show_normalized = gr.Checkbox(
                        label="Show Normalized Values (0-1 scale)",
                        value=True
                    )
                    update_scaling_btn = gr.Button("Update Scaling")

                # Time Series Section
                gr.Markdown("## Time Series Control Panel")
                with gr.Group():
                    gr.Markdown("### Fixed Point Control")
                    data_points_df = gr.Dataframe(
                        headers=["time", "feature", "value"],
                        datatype=["number", "number", "number"],
                        value=[[0, 0, 0.04], [2, 0, 0.58], [6, 0, 0.27], [58, 0, 0.8], [60, 0, 0.5]],
                        col_count=(3, "fixed"),
                        interactive=True
                    )
                    add_data_point_btn = gr.Button("Add Data Point")

                    def add_data_point(df):
                        """Append an empty anchor-point row (feature defaults to 0)."""
                        new_row = pd.DataFrame([[None, 0, None]],
                                               columns=["time", "feature", "value"])
                        return pd.concat([df, new_row], ignore_index=True)

                    add_data_point_btn.click(
                        fn=add_data_point,
                        inputs=[data_points_df],
                        outputs=[data_points_df]
                    )

                with gr.Group():
                    gr.Markdown("### Group of Anchor Point Control with Confidence")
                    point_groups_df = gr.Dataframe(
                        headers=["start", "end", "interval", "feature", "value", "weight"],
                        datatype=["number", "number", "number", "number", "number", "number"],
                        value=[[0, 50, 10, 0, 0.5, 0.1], [100, 150, 50, 0, 0.1, 0.5]],
                        col_count=(6, "fixed"),
                        interactive=True
                    )
                    add_point_group_btn = gr.Button("Add Point Group")

                    def add_point_group(df):
                        """Append an empty point-group row (feature defaults to 0)."""
                        new_row = pd.DataFrame([[None, None, None, 0, None, None]],
                                               columns=["start", "end", "interval", "feature", "value", "weight"])
                        return pd.concat([df, new_row], ignore_index=True)

                    add_point_group_btn.click(
                        fn=add_point_group,
                        inputs=[point_groups_df],
                        outputs=[point_groups_df]
                    )

                with gr.Group():
                    gr.Markdown("### Trending Control")
                    # The confidence field is required by the parser; entries
                    # without it are skipped, so every example includes one.
                    gr.Markdown("""
                    Enter trending control parameters in the format:
                    ```
                    start_time,end_time,feature,function,confidence
                    ```
                    The function is evaluated on x in [0, 1] across the segment and rescaled to the 0-1 range.
                    Examples:
                    - Linear trend: `0,100,0,x,0.05`
                    - Sine wave: `0,100,0,sin(2*pi*x),0.05`
                    - Exponential: `0,100,0,exp(-x),0.05`
                    Separate multiple trends with semicolons.
                    """)
                    enable_trending_control = gr.Checkbox(label="Enable Trending Control", value=False)
                    enable_trending_control_with_diff = gr.Checkbox(label="Consider Last Generated", value=False)
                    trending_control = gr.Textbox(
                        label="Trending Control Parameters",
                        lines=2,
                        placeholder="Enter parameters: start_time,end_time,feature,function,confidence; separated by semicolons",
                        value="200,250,0,sin(2*pi*x),0.05"
                    )

                # Area Control Parameters (hidden; not forwarded to the model)
                with gr.Group(visible=False):
                    gr.Markdown("### Area Control")
                    enable_area_control = gr.Checkbox(label="Enable Area Control", value=False)
                    area_selections = gr.Textbox(
                        label="Area Selections (format: start_time,end_time,feature,target_value)",
                        lines=2,
                        placeholder="Enter areas: start,end,feature,target; separated by semicolons",
                    )

                # AUC Parameters
                gr.Markdown("### Statistics Control")
                enable_auc = gr.Checkbox(label="Enable Total Sum Control", value=True)
                auc_input = gr.Number(label="Target Sum Value", value=-150)
                auc_weight_input = gr.Number(label="Sum Weight", value=10.0)

                # Peak Parameters (hidden)
                with gr.Group(visible=False):
                    gr.Markdown("### Peak Control")
                    enable_peaks = gr.Checkbox(label="Enable Peak Control", value=False)
                    peak_points_input = gr.Textbox(label="Peak Points (comma-separated)", value="100,200")
                    peak_alpha_input = gr.Number(label="Peak Alpha", value=10)
                    peak_weight_input = gr.Number(label="Peak Weight", value=1.0)

                update_model_btn = gr.Button("Update Figure")

                gr.Markdown("## Extend Edit", visible=False)
                with gr.Tab("Range Shift", visible=False):
                    enable_direct_area = gr.Checkbox(label="Enable Direct Edits", value=False)
                    direct_areas = gr.Textbox(
                        label="Direct Edit Areas (format: start_time,end_time,feature,delta)",
                        lines=2,
                        placeholder="Enter areas: start,end,feature,delta; separated by semicolons",
                        value="150,200,0,-0.1"
                    )
                    update_additional_btn = gr.Button("Update Additional Edit")

                # Frequency controls (hidden). Slider i maps to editor.freq_bands[i].
                with gr.Group(visible=False):
                    gr.Markdown("Adjust multipliers for different frequency bands (0-2)")
                    band_labels = ("Very High", "High", "Mid", "Low", "Very Low")
                    freq_bands = [
                        gr.Slider(
                            minimum=0, maximum=2, step=0.1, value=1.0,
                            label=f"Band {i + 1}: {band_label} Freq",
                        )
                        for i, band_label in enumerate(band_labels)
                    ]

                gr.Markdown("### Feature Index Reference:")
                for idx, name in enumerate(editor.feature_names):
                    gr.Markdown(f"- {idx}: {name}")

            with gr.Column(scale=1.2):
                gr.Markdown("""
                ### Plot Legend
                - **Points with Error Bars**: Observed values where:
                  - Point position = observed value
                  - Error bar size = uncertainty (larger for lower weight)
                - **Green Line**: Model prediction
                - **Vertical Red Lines**: Peak points (if enabled)
                """)
                plots = [gr.Plot() for _ in range(editor.feature_dim)]

        outputs = [*plots, metrics_display]

        def _pack(figs, metrics):
            """Return one value per output component."""
            # Editor methods return no figures before the first model run;
            # gr.update() leaves those plots unchanged instead of failing on a
            # wrong output count.
            if not figs:
                figs = [gr.update() for _ in plots]
            return [*figs, metrics]

        def update_scaling_callback(revenue_scale, download_scale, au_scale, show_normalized):
            figs, metrics = editor.update_scaling(
                revenue_scale,
                download_scale,
                au_scale,
                show_normalized
            )
            return _pack(figs, metrics)

        def update_model_callback(
                data_points,
                point_groups,
                enable_area,
                areas,
                enable_sum,
                auc,
                auc_weight,
                enable_peak,
                peak_points,
                peak_alpha,
                peak_weight,
                enable_trending,
                enable_trending_with_diff,
                trending_params
        ):
            # Keywords make the mapping explicit: the UI orders auc_weight
            # before the peak inputs, unlike update_model's signature.
            figs, _, _, metrics = editor.update_model(
                plots,
                data_points,
                point_groups,
                enable_area,
                areas,
                enable_auc=enable_sum,
                auc_value=auc,
                enable_peaks=enable_peak,
                peak_points=peak_points,
                peak_alpha=peak_alpha,
                auc_weight=auc_weight,
                peak_weight=peak_weight,
                enable_trending=enable_trending,
                enable_trending_with_diff=enable_trending_with_diff,
                trending_params=trending_params
            )
            return _pack(figs, metrics)

        def update_additional_callback(enable_direct_area, direct_areas):
            figs, metrics = editor.update_additional_edit(
                enable_direct_area,
                direct_areas
            )
            return _pack(figs, metrics)

        def update_freq_band(band_idx, value):
            figs, metrics = editor.update_frequency_bands(band_idx, value)
            return _pack(figs, metrics)

        def _make_band_handler(band_idx: int):
            """Bind an int band index so each slider only passes its value."""
            def handler(value):
                return update_freq_band(band_idx, value)
            return handler

        model_inputs = [
            data_points_df,
            point_groups_df,
            enable_area_control,
            area_selections,
            enable_auc,
            auc_input,
            auc_weight_input,
            enable_peaks,
            peak_points_input,
            peak_alpha_input,
            peak_weight_input,
            enable_trending_control,
            enable_trending_control_with_diff,
            trending_control
        ]

        update_model_btn.click(
            fn=update_model_callback,
            inputs=model_inputs,
            outputs=outputs
        )
        update_scaling_btn.click(
            fn=update_scaling_callback,
            inputs=[
                revenue_scale,
                download_scale,
                au_scale,
                show_normalized
            ],
            outputs=outputs
        )
        update_additional_btn.click(
            fn=update_additional_callback,
            inputs=[enable_direct_area, direct_areas],
            outputs=outputs
        )
        for band_idx, slider in enumerate(freq_bands):
            slider.change(
                fn=_make_band_handler(band_idx),
                inputs=[slider],
                outputs=outputs
            )
        app.load(
            fn=update_model_callback,
            inputs=model_inputs,
            outputs=outputs
        )
    return app
| class FunctionParser: | |
def __init__(self):
    # Every name an expression may reference: numpy ufuncs plus the constants
    # pi and e. A copy of this mapping also serves as the eval namespace when
    # an expression is compiled.
    self.math_functions = dict(
        sin=np.sin,
        cos=np.cos,
        tan=np.tan,
        exp=np.exp,
        log=np.log,
        sqrt=np.sqrt,
        abs=np.abs,
        pow=np.power,
        pi=np.pi,
        e=np.e,
        asin=np.arcsin,
        acos=np.arccos,
        atan=np.arctan,
        sinh=np.sinh,
        cosh=np.cosh,
        tanh=np.tanh,
    )
def validate_expression(self, expression: str) -> bool:
    """
    Validate a user-supplied mathematical expression before it is compiled.

    Checks, in order: balanced parentheses, an allowed character set, and that
    every identifier is either ``x`` or a key of ``self.math_functions``.
    Matching is case-insensitive.

    Args:
        expression: Raw expression text, e.g. ``"sin(2*pi*x)"``.

    Returns:
        True when all checks pass.

    Raises:
        ValueError: On unbalanced parentheses, an invalid character, or an
            unknown identifier.
    """
    if expression.count('(') != expression.count(')'):
        raise ValueError("Unbalanced parentheses in expression")
    lowered = expression.lower()
    valid_chars = set('0123456789.+-*/()^ xXepi,')
    valid_chars.update(''.join(self.math_functions.keys()))
    if not all(c in valid_chars or c.isspace() for c in lowered):
        raise ValueError("Expression contains invalid characters")
    # SECURITY: the expression is later passed to eval() with Python builtins
    # reachable, and a per-character whitelist still spells out builtins such
    # as exec/chr (e.g. "exec(chr(112)+...)"). Require every identifier to be
    # a known math name or the variable x.
    allowed_names = set(self.math_functions) | {'x'}
    unknown = sorted(set(re.findall(r'[a-z_][a-z0-9_]*', lowered)) - allowed_names)
    if unknown:
        raise ValueError(f"Expression contains unknown identifiers: {', '.join(unknown)}")
    return True
| def preprocess_expression(self, expression: str) -> str: | |
| """ | |
| Preprocess the expression to handle various input formats. | |
| """ | |
| # Remove whitespace | |
| expression = expression.replace(' ', '') | |
| # Convert ^ to ** for exponentiation | |
| expression = expression.replace('^', '**') | |
| # Ensure multiplication is explicit | |
| expression = re.sub(r'(\d+)([a-zA-Z])', r'\1*\2', expression) | |
| expression = re.sub(r'(\))([\w])', r'\1*\2', expression) | |
| # Replace X with x for consistency | |
| expression = expression.lower() | |
| return expression | |
| def string_to_function(self, expression: str) -> Callable[[Union[float, np.ndarray]], Union[float, np.ndarray]]: | |
| """ | |
| Convert a string mathematical expression to a callable function. | |
| Args: | |
| expression (str): Mathematical expression (e.g., "sin(x) + x^2") | |
| Returns: | |
| Callable: A function that takes x as input and returns the evaluated result | |
| Example: | |
| >>> f = string_to_function("sin(x) + x^2") | |
| >>> f(0.5) | |
| 0.729321... | |
| """ | |
| # Validate and preprocess the expression | |
| self.validate_expression(expression) | |
| processed_expr = self.preprocess_expression(expression) | |
| # Create the function namespace | |
| namespace = self.math_functions.copy() | |
| try: | |
| # Create the lambda function | |
| func = eval(f"lambda x: {processed_expr}", namespace) | |
| # Test the function with a simple input | |
| test_value = 1.0 | |
| try: | |
| func(test_value) | |
| except Exception as e: | |
| raise ValueError(f"Invalid function: {str(e)}") | |
| return func | |
| except SyntaxError as e: | |
| raise ValueError(f"Invalid expression syntax: {str(e)}") | |
| except Exception as e: | |
| raise ValueError(f"Error creating function: {str(e)}") | |
def demonstrate_usage():
    """
    Print sample evaluations of a handful of expressions via FunctionParser.

    Each expression is parsed, evaluated at the scalar 0.5 and then on a
    5-point grid spanning [0, 1]. Any failure is printed and the next
    expression is tried.
    """
    fp = FunctionParser()
    samples = (
        "x^2 + 2*x + 1",
        "sin(x) + cos(x)",
        "exp(-x^2)",
        "log(x + 1)",
        "sqrt(1 - x^2)",
    )
    scalar_input = 0.5
    grid = np.linspace(0, 1, 5)

    print("Testing various mathematical expressions:")
    for text in samples:
        try:
            print(f"\nExpression: {text}")
            fn = fp.string_to_function(text)
            print(f"f({scalar_input}) = {fn(scalar_input)}")
            # The same callable must also accept a numpy array.
            print(f"f(array) = {fn(grid)}")
        except Exception as exc:
            print(f"Error: {str(exc)}")
# Example usage: launch the editor UI on top of a pretrained Tiffusion checkpoint.
if __name__ == "__main__":
    # torch is already imported at the top of the file; only os is new here.
    import os

    # Set before the project modules are imported. Presumably this disables
    # Weights & Biases logging in project code -- TODO confirm.
    os.environ["WANDB_ENABLED"] = "false"
    # The checkpoint path below is relative, so show where it resolves from.
    print(os.getcwd())

    # Geometry shared by the network and the editor. It is defined once so the
    # two cannot drift apart.
    seq_length = 365
    feature_dim = 3

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    from models.Tiffusion import tiffusion

    model = tiffusion.Tiffusion(
        seq_length=seq_length,
        feature_size=feature_dim,
        n_layer_enc=6,
        n_layer_dec=4,
        d_model=128,
        timesteps=500,
        sampling_timesteps=200,
        loss_type='l1',
        beta_schedule='cosine',
        n_heads=8,
        mlp_hidden_times=4,
        attn_pd=0.0,
        resid_pd=0.0,
        kernel_size=1,
        padding_size=0,
        control_signal=[]
    ).to(device)
    # The checkpoint file is a dict whose "model" entry holds the state dict.
    model.load_state_dict(
        torch.load("./weight/checkpoint-10.pt", map_location=device, weights_only=True)["model"]
    )

    coef = 1.0e-2
    stepsize = 5.0e-2
    sampling_steps = 100  # Adjustable between 100-500 for speed/accuracy tradeoff

    print(f"seq_length: {seq_length}, feature_dim: {feature_dim}")
    editor = TimeSeriesEditor(seq_length, feature_dim, model)
    editor.coef = coef
    editor.stepsize = stepsize
    editor.sampling_steps = sampling_steps

    app = create_gradio_interface(editor)
    app.launch(show_api=False)