Spaces:
Sleeping
Sleeping
| import sys | |
| import os | |
| import pickle | |
| import gzip | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| from scipy import stats | |
| from gluformer.model import Gluformer | |
| from utils.darts_processing import * | |
| from utils.darts_dataset import * | |
| import hashlib | |
| from urllib.parse import urlparse | |
| from huggingface_hub import hf_hub_download | |
| import plotly.graph_objects as go | |
| import gradio as gr | |
| from format_dexcom import * | |
| from typing import Tuple, Union, List | |
| from plotly.graph_objs._figure import Figure | |
| from gradio.components import Slider | |
| from gradio.components import Markdown | |
# Absolute directory containing this module, with symlinks in the directory path resolved.
glucose = Path(os.path.dirname(os.path.abspath(__file__))).resolve()
# "files" sub-directory next to this module: config.yaml is read from it and plots are written to it.
file_directory = glucose.joinpath("files")
def plot_forecast(forecasts: np.ndarray, filename: str,ind:int=10) -> Tuple[Path, Figure]:
    """
    Plot the forecast for one test sample and save it under ``file_directory``.

    The figure contains one filled density band per forecast step, the true
    glucose values (last 12 input points followed by the observed future
    values) and the median of the sampled forecast. It is written both as an
    HTML file (suffix replaced by ``.html``) and as an image at
    ``file_directory / filename``.

    Reads the module-level ``scalers`` and ``dataset_test_glufo``.

    :param forecasts: scaled model output; indexed as a 3-D array
        (assumes axes are sample, forecast step, draw -- TODO confirm).
    :param filename: name of the image file to write inside ``file_directory``.
    :param ind: index of the test sample to plot.
    :return: tuple of the image path and the Plotly figure.
    """
    target_scaler = scalers['target']

    # Map the scaled forecasts back to the original glucose range.
    forecasts = (forecasts - target_scaler.min_) / target_scaler.scale_

    n_items = len(dataset_test_glufo)

    # Observed future values for every test sample, unscaled and converted
    # from TimeSeries objects to a single numpy array.
    observed = [dataset_test_glufo.evalsample(i) for i in range(n_items)]
    observed = target_scaler.inverse_transform(observed)
    observed = np.array([series_item.values() for series_item in observed])

    # Input windows for every test sample, unscaled the same way as the forecasts.
    history = np.array([dataset_test_glufo[i][0] for i in range(n_items)])
    history = (history - target_scaler.min_) / target_scaler.scale_

    # Draw synthetic samples around the forecast of the selected test sample.
    n_steps, n_draws = forecasts.shape[1], forecasts.shape[2]
    samples = np.random.normal(
        loc=forecasts[ind, :],
        scale=0.1,
        size=(n_steps, n_draws),
    )

    fig = go.Figure()

    # One density band per forecast step; opacity grows with the step index.
    for step, draws in enumerate(samples):
        density_fn = stats.gaussian_kde(draws)
        upper = 1.2 * np.max(draws)
        lower = 0.8 * np.min(draws)
        y_grid = np.linspace(lower, upper, 200)
        density = density_fn(y_grid)

        band_color = f'rgba(53, 138, 217, {(step + 1) / n_steps})'

        # Closed outline: straight edge at x=step, then back along the density curve.
        band_x = np.concatenate([
            np.full_like(y_grid, step),
            np.full_like(y_grid, step - density * 15)[::-1],
        ])
        band_y = np.concatenate([y_grid, y_grid[::-1]])

        fig.add_trace(go.Scatter(
            x=band_x,
            y=band_y,
            fill='tonexty',
            fillcolor=band_color,
            line=dict(color='rgba(0,0,0,0)'),
            showlegend=False,
        ))

    # True values: last 12 input points followed by the observed future values.
    # NOTE(review): the x range -12..11 assumes 12 future points -- confirm.
    true_values = np.concatenate([history[ind, -12:], observed[ind, :]])
    fig.add_trace(go.Scatter(
        x=list(range(-12, 12)),
        y=true_values.flatten().tolist(),
        mode='lines+markers',
        line=dict(color='blue', width=2),
        marker=dict(size=6),
        name='True Values',
    ))

    # Median of the sampled draws at each forecast step.
    median = np.quantile(samples, 0.5, axis=-1)
    fig.add_trace(go.Scatter(
        x=list(range(12)),
        y=median.tolist(),
        mode='lines+markers',
        line=dict(color='red', width=2),
        marker=dict(size=8),
        name='Median Forecast',
    ))

    fig.update_layout(
        title='Gluformer Prediction with Gradient for dataset',
        xaxis_title='Time (in 5 minute intervals)',
        yaxis_title='Glucose (mg/dL)',
        font=dict(size=14),
        showlegend=True,
        width=1000,
        height=600,
    )

    # Persist an interactive HTML copy and the image itself.
    where = file_directory / filename
    fig.write_html(str(where.with_suffix('.html')))
    fig.write_image(str(where))

    return where, fig
def generate_filename_from_url(url: str, extension: str = "png") -> str:
    """
    Build a filename from the last path segment of a URL and an MD5 digest of the URL.

    :param url: URL or filesystem path. Path-like objects are accepted and
        converted with ``str`` before parsing and hashing.
    :param extension: extension appended after the digest, without the leading dot.
    :return: ``<last segment with dots replaced by underscores>_<md5 hex digest>.<extension>``
    """
    # Callers may hand over a pathlib.Path; urlparse and .encode both need a str.
    url = str(url)

    # Last segment of the URL path (empty when the path ends with a slash).
    last_segment = urlparse(url).path.split('/')[-1]

    # The digest of the full URL distinguishes inputs that share a last segment.
    url_hash = hashlib.md5(url.encode('utf-8')).hexdigest()

    return f"{last_segment.replace('.', '_')}_{url_hash}.{extension}"
# Shared state, unset until prep_predict_glucose_tool() fills it in:
# model, target scalers, inference dataset and output image filename.
glufo = scalers = dataset_test_glufo = filename = None
def prep_predict_glucose_tool(file: Union[str, Path], model_name: str = "gluformer_1samples_10000epochs_10heads_32batch_geluactivation_livia_mini_weights.pth") -> Tuple[Slider, Markdown]:
    """
    Load data and model for a glucose file and build the sample-selection widgets.

    Downloads ``model_name`` from the "Livia-Zaharia/gluformer_models" Hugging
    Face repository, loads ``file`` through ``load_data``, builds the Gluformer
    model and the inference dataset, and stores the results in the module-level
    globals ``formatter``, ``series``, ``scalers``, ``glufo``,
    ``dataset_test_glufo`` and ``filename``.

    :param file: path or URL of the input data; passed to ``load_data`` as a string.
    :param model_name: filename of the weights inside the Hugging Face repository.
    :return: a slider over the valid test-sample indices (defaulting to the
        last one) and a Markdown component stating the number of test samples.
    """
    global formatter, series, scalers, glufo, dataset_test_glufo, filename

    model = "Livia-Zaharia/gluformer_models"
    model_path = hf_hub_download(repo_id=model, filename=model_name)

    formatter, series, scalers = load_data(
        url=str(file),
        config_path=file_directory / "config.yaml",
        use_covs=True,
        cov_type='dual',
        use_static_covs=True
    )

    # NOTE(review): these presumably have to match the architecture the
    # downloaded weights were trained with -- confirm when changing model_name.
    formatter.params['gluformer'] = {
        'in_len': 96,  # example input length, adjust as necessary
        'd_model': 512,  # model dimension
        'n_heads': 10,  # number of attention heads
        'd_fcn': 1024,  # fully connected layer dimension
        'num_enc_layers': 2,  # number of encoder layers
        'num_dec_layers': 2,  # number of decoder layers
        'length_pred': 12  # prediction length, adjust as necessary represents 1 h
    }
    gluformer_params = formatter.params['gluformer']

    num_dynamic_features = series['train']['future'][-1].n_components
    num_static_features = series['train']['static'][-1].n_components

    # NOTE(review): len_pred is read from the top-level
    # formatter.params['length_pred'], not from the 'gluformer' entry set
    # above -- confirm that the two values agree.
    glufo = Gluformer(
        d_model=gluformer_params['d_model'],
        n_heads=gluformer_params['n_heads'],
        d_fcn=gluformer_params['d_fcn'],
        r_drop=0.2,
        activ='gelu',
        num_enc_layers=gluformer_params['num_enc_layers'],
        num_dec_layers=gluformer_params['num_dec_layers'],
        distil=True,
        len_seq=gluformer_params['in_len'],
        label_len=gluformer_params['in_len'] // 3,
        len_pred=formatter.params['length_pred'],
        num_dynamic_features=num_dynamic_features,
        num_static_features=num_static_features
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE: torch.load unpickles the checkpoint; only load weights from a trusted repository.
    glufo.load_state_dict(torch.load(str(model_path), map_location=torch.device(device)))

    dataset_test_glufo = SamplingDatasetInferenceDual(
        target_series=series['test']['target'],
        covariates=series['test']['future'],
        input_chunk_length=gluformer_params['in_len'],
        output_chunk_length=formatter.params['length_pred'],
        use_static_covariates=True,
        array_output_only=True
    )

    # generate_filename_from_url needs a string; `file` may be a pathlib.Path.
    filename = generate_filename_from_url(str(file))

    num_test_samples = len(dataset_test_glufo)
    # Last valid index; clamped so an empty dataset does not yield a negative slider bound.
    max_index = max(num_test_samples - 1, 0)
    print(f"Total number of test samples: {num_test_samples}")

    return (
        gr.Slider(
            minimum=0,
            # The default value must lie inside [minimum, maximum], so the
            # upper bound is the last valid index itself.
            maximum=max_index,
            value=max_index,
            step=1,
            label="Select Sample Index",
            visible=True
        ),
        gr.Markdown(f"Total number of test samples: {num_test_samples}", visible=True)
    )
def predict_glucose_tool(ind: int) -> Figure:
    """
    Run the loaded model on the test dataset and plot the forecast for one sample.

    Uses the module-level ``glufo``, ``dataset_test_glufo`` and ``filename``
    set by ``prep_predict_glucose_tool``. The figure is also written to disk
    by ``plot_forecast``.

    :param ind: index of the test sample to plot.
    :return: the Plotly figure produced by ``plot_forecast``.
    :raises RuntimeError: if ``prep_predict_glucose_tool`` has not been called yet.
    """
    # Fail with a clear message instead of an AttributeError on None.
    if glufo is None or dataset_test_glufo is None:
        raise RuntimeError(
            "No model or data loaded: call prep_predict_glucose_tool() before "
            "predict_glucose_tool()."
        )

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # The second value returned by predict() is not used here.
    forecasts, _ = glufo.predict(
        dataset_test_glufo,
        batch_size=16,
        num_samples=10,
        device=device
    )

    # UI sliders may deliver the index as a float; array indexing needs an int.
    # Only the figure is returned; the saved-file path is discarded.
    _, figure = plot_forecast(forecasts, filename, int(ind))
    return figure