cryogenic22 commited on
Commit
5251eeb
·
verified ·
1 Parent(s): 3616223

Create utils/visualization.py

Browse files
Files changed (1) hide show
  1. utils/visualization.py +448 -0
utils/visualization.py ADDED
@@ -0,0 +1,448 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Visualization Utility Functions
3
+
4
+ This module provides utility functions for creating common visualizations
5
+ used in pharmaceutical analytics dashboards.
6
+ """
7
+
8
+ import pandas as pd
9
+ import numpy as np
10
+ import matplotlib.pyplot as plt
11
+ import seaborn as sns
12
+ import plotly.express as px
13
+ import plotly.graph_objects as go
14
+ from typing import List, Dict, Any, Optional, Tuple, Union
15
+
16
+ def create_trend_chart(
17
+ df: pd.DataFrame,
18
+ date_column: str,
19
+ value_columns: List[str],
20
+ title: str = "Trend Analysis",
21
+ colors: Optional[List[str]] = None,
22
+ markers: bool = True,
23
+ annotations: Optional[List[Dict[str, Any]]] = None,
24
+ height: int = 400
25
+ ) -> go.Figure:
26
+ """
27
+ Create a time series trend chart with Plotly
28
+
29
+ Parameters:
30
+ -----------
31
+ df : DataFrame
32
+ Pandas DataFrame containing the data
33
+ date_column : str
34
+ Name of the column containing dates
35
+ value_columns : List[str]
36
+ List of column names to plot as lines
37
+ title : str
38
+ Chart title
39
+ colors : List[str], optional
40
+ List of colors for each line
41
+ markers : bool
42
+ Whether to show markers on lines
43
+ annotations : List[Dict], optional
44
+ List of annotation dictionaries
45
+ height : int
46
+ Height of the chart in pixels
47
+
48
+ Returns:
49
+ --------
50
+ go.Figure
51
+ Plotly figure object
52
+ """
53
+ # Create figure
54
+ fig = go.Figure()
55
+
56
+ # Default colors if not provided
57
+ if not colors:
58
+ colors = ['blue', 'green', 'red', 'orange', 'purple']
59
+
60
+ # Convert date column to datetime if not already
61
+ if not pd.api.types.is_datetime64_any_dtype(df[date_column]):
62
+ df = df.copy()
63
+ df[date_column] = pd.to_datetime(df[date_column])
64
+
65
+ # Add each value column as a line
66
+ for i, column in enumerate(value_columns):
67
+ color = colors[i % len(colors)]
68
+ mode = 'lines+markers' if markers else 'lines'
69
+
70
+ fig.add_trace(go.Scatter(
71
+ x=df[date_column],
72
+ y=df[column],
73
+ mode=mode,
74
+ name=column,
75
+ line=dict(color=color, width=2)
76
+ ))
77
+
78
+ # Add annotations if provided
79
+ if annotations:
80
+ for annotation in annotations:
81
+ if 'x' in annotation and 'text' in annotation:
82
+ # Convert annotation date to datetime if it's a string
83
+ if isinstance(annotation['x'], str):
84
+ annotation['x'] = pd.to_datetime(annotation['x'])
85
+
86
+ fig.add_vline(
87
+ x=annotation['x'],
88
+ line_dash="dash",
89
+ line_color=annotation.get('color', 'red'),
90
+ annotation_text=annotation['text'],
91
+ annotation_position=annotation.get('position', 'top right')
92
+ )
93
+
94
+ # Update layout
95
+ fig.update_layout(
96
+ title=title,
97
+ xaxis_title=date_column,
98
+ yaxis_title="Value",
99
+ height=height,
100
+ legend=dict(
101
+ orientation="h",
102
+ yanchor="bottom",
103
+ y=1.02,
104
+ xanchor="right",
105
+ x=1
106
+ ),
107
+ margin=dict(l=20, r=20, t=40, b=20)
108
+ )
109
+
110
+ return fig
111
+
112
+ def create_comparison_chart(
113
+ df: pd.DataFrame,
114
+ category_column: str,
115
+ value_columns: List[str],
116
+ title: str = "Comparison Analysis",
117
+ chart_type: str = "bar",
118
+ stacked: bool = False,
119
+ colors: Optional[List[str]] = None,
120
+ height: int = 400,
121
+ horizontal: bool = False
122
+ ) -> go.Figure:
123
+ """
124
+ Create a comparison chart (bar, line, area) with Plotly
125
+
126
+ Parameters:
127
+ -----------
128
+ df : DataFrame
129
+ Pandas DataFrame containing the data
130
+ category_column : str
131
+ Name of the column containing categories
132
+ value_columns : List[str]
133
+ List of column names to plot
134
+ title : str
135
+ Chart title
136
+ chart_type : str
137
+ Type of chart ('bar', 'line', 'area')
138
+ stacked : bool
139
+ Whether to stack the bars/areas
140
+ colors : List[str], optional
141
+ List of colors for each series
142
+ height : int
143
+ Height of the chart in pixels
144
+ horizontal : bool
145
+ If True, create horizontal bar chart
146
+
147
+ Returns:
148
+ --------
149
+ go.Figure
150
+ Plotly figure object
151
+ """
152
+ # Default colors if not provided
153
+ if not colors:
154
+ colors = ['blue', 'green', 'red', 'orange', 'purple']
155
+
156
+ fig = go.Figure()
157
+
158
+ # Determine barmode based on stacked parameter
159
+ barmode = 'stack' if stacked else 'group'
160
+
161
+ # Add each value column as a series
162
+ for i, column in enumerate(value_columns):
163
+ color = colors[i % len(colors)]
164
+
165
+ if chart_type == 'bar':
166
+ if horizontal:
167
+ fig.add_trace(go.Bar(
168
+ y=df[category_column],
169
+ x=df[column],
170
+ name=column,
171
+ marker_color=color,
172
+ orientation='h'
173
+ ))
174
+ else:
175
+ fig.add_trace(go.Bar(
176
+ x=df[category_column],
177
+ y=df[column],
178
+ name=column,
179
+ marker_color=color
180
+ ))
181
+ elif chart_type == 'line':
182
+ fig.add_trace(go.Scatter(
183
+ x=df[category_column],
184
+ y=df[column],
185
+ mode='lines+markers',
186
+ name=column,
187
+ line=dict(color=color)
188
+ ))
189
+ elif chart_type == 'area':
190
+ fig.add_trace(go.Scatter(
191
+ x=df[category_column],
192
+ y=df[column],
193
+ mode='lines',
194
+ name=column,
195
+ fill='tonexty' if stacked else 'none',
196
+ line=dict(color=color)
197
+ ))
198
+
199
+ # Update layout
200
+ x_title = None if horizontal else category_column
201
+ y_title = category_column if horizontal else None
202
+
203
+ fig.update_layout(
204
+ title=title,
205
+ xaxis_title=x_title,
206
+ yaxis_title=y_title,
207
+ barmode=barmode,
208
+ height=height,
209
+ legend=dict(
210
+ orientation="h",
211
+ yanchor="bottom",
212
+ y=1.02,
213
+ xanchor="right",
214
+ x=1
215
+ )
216
+ )
217
+
218
+ return fig
219
+
220
+ def create_heatmap(
221
+ df: pd.DataFrame,
222
+ x_column: str,
223
+ y_column: str,
224
+ value_column: str,
225
+ title: str = "Heatmap Analysis",
226
+ colorscale: str = "Blues",
227
+ height: int = 500,
228
+ width: int = 700,
229
+ text_format: Optional[str] = None
230
+ ) -> go.Figure:
231
+ """
232
+ Create a heatmap with Plotly
233
+
234
+ Parameters:
235
+ -----------
236
+ df : DataFrame
237
+ Pandas DataFrame containing the data
238
+ x_column : str
239
+ Name of the column for x-axis categories
240
+ y_column : str
241
+ Name of the column for y-axis categories
242
+ value_column : str
243
+ Name of the column containing values to plot
244
+ title : str
245
+ Chart title
246
+ colorscale : str
247
+ Colorscale for the heatmap
248
+ height : int
249
+ Height of the chart in pixels
250
+ width : int
251
+ Width of the chart in pixels
252
+ text_format : str, optional
253
+ Format string for text values (e.g., ".1f" for float with 1 decimal)
254
+
255
+ Returns:
256
+ --------
257
+ go.Figure
258
+ Plotly figure object
259
+ """
260
+ # Pivot the data for the heatmap
261
+ pivot_df = df.pivot_table(
262
+ index=y_column,
263
+ columns=x_column,
264
+ values=value_column,
265
+ aggfunc='mean'
266
+ )
267
+
268
+ # Format text values if specified
269
+ text_values = None
270
+ if text_format:
271
+ text_values = pivot_df.applymap(lambda x: f"{x:{text_format}}")
272
+
273
+ # Create heatmap
274
+ fig = px.imshow(
275
+ pivot_df,
276
+ labels=dict(x=x_column, y=y_column, color=value_column),
277
+ x=pivot_df.columns,
278
+ y=pivot_df.index,
279
+ color_continuous_scale=colorscale,
280
+ text_auto=text_format is None, # Auto text if format not specified
281
+ aspect="auto"
282
+ )
283
+
284
+ # Add custom text if format specified
285
+ if text_values is not None:
286
+ fig.update_traces(text=text_values.values, texttemplate="%{text}")
287
+
288
+ # Update layout
289
+ fig.update_layout(
290
+ title=title,
291
+ height=height,
292
+ width=width,
293
+ xaxis=dict(side="bottom"),
294
+ margin=dict(l=20, r=20, t=40, b=20)
295
+ )
296
+
297
+ return fig
298
+
299
+ def create_pie_chart(
300
+ df: pd.DataFrame,
301
+ names_column: str,
302
+ values_column: str,
303
+ title: str = "Distribution Analysis",
304
+ colors: Optional[List[str]] = None,
305
+ hole: float = 0.0,
306
+ height: int = 400
307
+ ) -> go.Figure:
308
+ """
309
+ Create a pie or donut chart with Plotly
310
+
311
+ Parameters:
312
+ -----------
313
+ df : DataFrame
314
+ Pandas DataFrame containing the data
315
+ names_column : str
316
+ Name of the column containing category names
317
+ values_column : str
318
+ Name of the column containing values
319
+ title : str
320
+ Chart title
321
+ colors : List[str], optional
322
+ List of colors for pie slices
323
+ hole : float
324
+ Size of hole for donut chart (0.0 for pie chart)
325
+ height : int
326
+ Height of the chart in pixels
327
+
328
+ Returns:
329
+ --------
330
+ go.Figure
331
+ Plotly figure object
332
+ """
333
+ # Create pie chart
334
+ fig = px.pie(
335
+ df,
336
+ names=names_column,
337
+ values=values_column,
338
+ title=title,
339
+ color_discrete_sequence=colors,
340
+ hole=hole,
341
+ height=height
342
+ )
343
+
344
+ # Update layout
345
+ fig.update_layout(
346
+ margin=dict(l=20, r=20, t=40, b=20),
347
+ legend=dict(
348
+ orientation="h",
349
+ yanchor="bottom",
350
+ y=-0.2,
351
+ xanchor="center",
352
+ x=0.5
353
+ )
354
+ )
355
+
356
+ # Update traces
357
+ fig.update_traces(
358
+ textposition='inside',
359
+ textinfo='percent+label'
360
+ )
361
+
362
+ return fig
363
+
364
+ def create_scatter_plot(
365
+ df: pd.DataFrame,
366
+ x_column: str,
367
+ y_column: str,
368
+ size_column: Optional[str] = None,
369
+ color_column: Optional[str] = None,
370
+ title: str = "Correlation Analysis",
371
+ height: int = 500,
372
+ trendline: bool = False,
373
+ hover_data: Optional[List[str]] = None
374
+ ) -> go.Figure:
375
+ """
376
+ Create a scatter plot with Plotly
377
+
378
+ Parameters:
379
+ -----------
380
+ df : DataFrame
381
+ Pandas DataFrame containing the data
382
+ x_column : str
383
+ Name of the column for x-axis values
384
+ y_column : str
385
+ Name of the column for y-axis values
386
+ size_column : str, optional
387
+ Name of the column for point sizes
388
+ color_column : str, optional
389
+ Name of the column for point colors
390
+ title : str
391
+ Chart title
392
+ height : int
393
+ Height of the chart in pixels
394
+ trendline : bool
395
+ Whether to add a trendline
396
+ hover_data : List[str], optional
397
+ List of column names to include in hover data
398
+
399
+ Returns:
400
+ --------
401
+ go.Figure
402
+ Plotly figure object
403
+ """
404
+ # Create scatter plot
405
+ fig = px.scatter(
406
+ df,
407
+ x=x_column,
408
+ y=y_column,
409
+ size=size_column,
410
+ color=color_column,
411
+ title=title,
412
+ height=height,
413
+ hover_data=hover_data,
414
+ trendline='ols' if trendline else None
415
+ )
416
+
417
+ # Update layout
418
+ fig.update_layout(
419
+ xaxis_title=x_column,
420
+ yaxis_title=y_column,
421
+ margin=dict(l=20, r=20, t=40, b=20)
422
+ )
423
+
424
+ return fig
425
+
426
+ # Example usage
427
+ if __name__ == "__main__":
428
+ # Create sample data
429
+ dates = pd.date_range(start='2023-01-01', periods=12, freq='M')
430
+ data = {
431
+ 'date': dates,
432
+ 'sales': [100, 110, 120, 115, 130, 140, 135, 150, 145, 160, 155, 170],
433
+ 'target': [105, 110, 115, 120, 125, 130, 135, 140, 145, 150, 155, 160],
434
+ 'region': ['Northeast'] * 12
435
+ }
436
+ df = pd.DataFrame(data)
437
+
438
+ # Create trend chart
439
+ fig = create_trend_chart(
440
+ df,
441
+ date_column='date',
442
+ value_columns=['sales', 'target'],
443
+ title='Sales vs Target',
444
+ annotations=[{'x': '2023-06-01', 'text': 'Campaign Launch'}]
445
+ )
446
+
447
+ # Display the chart (in a notebook or Streamlit app)
448
+ print("Trend chart created successfully!")