Nipun commited on
Commit
803e278
·
1 Parent(s): 2ac8f6e

done some error correction

Browse files
Files changed (3) hide show
  1. pages/Continuous.py +2 -4
  2. pages/Discrete.py +2 -4
  3. utils.py +22 -1
pages/Continuous.py CHANGED
@@ -2,7 +2,7 @@ import streamlit as st
2
  import torch
3
  import plotly.graph_objects as go
4
  from config.config_continuous import CONTINUOUS_DISTRIBUTIONS
5
- from utils import compute_pdf
6
 
7
  st.title("Continuous Distributions")
8
 
@@ -14,7 +14,5 @@ support = support(params) if callable(support) else None
14
 
15
  x_range = torch.linspace(-10, 10, 1000)
16
  pdf = compute_pdf(dist, x_range, support)
 
17
 
18
- fig = go.Figure()
19
- fig.add_trace(go.Scatter(x=x_range.numpy(), y=pdf.numpy(), mode="lines", name=selected_dist))
20
- st.plotly_chart(fig)
 
2
  import torch
3
  import plotly.graph_objects as go
4
  from config.config_continuous import CONTINUOUS_DISTRIBUTIONS
5
+ from utils import compute_pdf, plot_pdf
6
 
7
  st.title("Continuous Distributions")
8
 
 
14
 
15
  x_range = torch.linspace(-10, 10, 1000)
16
  pdf = compute_pdf(dist, x_range, support)
17
+ plot_pdf(pdf, x_range, selected_dist)
18
 
 
 
 
pages/Discrete.py CHANGED
@@ -2,6 +2,7 @@ import streamlit as st
2
  import torch
3
  import plotly.graph_objects as go
4
  from config.config_discrete import DISCRETE_DISTRIBUTIONS
 
5
 
6
  st.title("Discrete Distributions")
7
 
@@ -18,7 +19,4 @@ else:
18
  x_range = torch.arange(0, 20)
19
 
20
  pmf = dist.log_prob(x_range).exp()
21
-
22
- fig = go.Figure()
23
- fig.add_trace(go.Bar(x=x_range.numpy(), y=pmf.numpy(), name=selected_dist))
24
- st.plotly_chart(fig)
 
2
  import torch
3
  import plotly.graph_objects as go
4
  from config.config_discrete import DISCRETE_DISTRIBUTIONS
5
+ from utils import compute_pdf, plot_pmf
6
 
7
  st.title("Discrete Distributions")
8
 
 
19
  x_range = torch.arange(0, 20)
20
 
21
  pmf = dist.log_prob(x_range).exp()
22
+ plot_pmf(pmf, x_range, selected_dist)
 
 
 
utils.py CHANGED
@@ -1,4 +1,6 @@
1
  import torch
 
 
2
 
3
  def compute_pdf(dist, x_range, support):
4
  pdf = torch.zeros_like(x_range, dtype=torch.float)
@@ -12,4 +14,23 @@ def compute_pdf(dist, x_range, support):
12
  mask = torch.ones_like(x_range, dtype=torch.bool)
13
 
14
  pdf[mask] = dist.log_prob(x_range[mask]).exp()
15
- return pdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
+ import streamlit as st
3
+ import plotly.graph_objects as go
4
 
5
  def compute_pdf(dist, x_range, support):
6
  pdf = torch.zeros_like(x_range, dtype=torch.float)
 
14
  mask = torch.ones_like(x_range, dtype=torch.bool)
15
 
16
  pdf[mask] = dist.log_prob(x_range[mask]).exp()
17
+ return pdf
18
+
19
+ def plot_pdf(pdf, x_range, dist_name):
20
+ # Main plot area
21
+ st.markdown(f'#### Probability Density Function for {dist_name} Distribution')
22
+ fig = go.Figure()
23
+ fig.add_trace(go.Scatter(x=x_range, y=pdf, mode='lines',
24
+ name='', hovertemplate='x: %{x:.2f}<br>f(x): %{y:.2f}'))
25
+ fig.update_layout(xaxis_title='x', yaxis_title='f(x)', showlegend=False)
26
+
27
+ st.plotly_chart(fig)
28
+
29
+ def plot_pmf(pmf, x_range, dist_name):
30
+ # Main plot area
31
+ st.markdown(f'#### Probability Mass Function for {dist_name} Distribution')
32
+ fig = go.Figure()
33
+ fig.add_trace(go.Bar(x=x_range.tolist(), y=pmf.tolist()))
34
+ fig.update_layout(xaxis_title='x', yaxis_title='P(x)', showlegend=False)
35
+
36
+ st.plotly_chart(fig)