Spaces:
Sleeping
Sleeping
Update my_model/results/demo.py
Browse files- my_model/results/demo.py +8 -8
my_model/results/demo.py
CHANGED
|
@@ -96,7 +96,7 @@ class ResultDemonstrator:
|
|
| 96 |
|
| 97 |
def plot_token_count_vs_scores(self, conf: str, model_name: str, score_name: str = 'VQA Score') -> None:
|
| 98 |
"""
|
| 99 |
-
Plots an interactive scatter plot comparing token
|
| 100 |
|
| 101 |
Args:
|
| 102 |
conf (str): The configuration name.
|
|
@@ -125,13 +125,13 @@ class ResultDemonstrator:
|
|
| 125 |
legend_map = ['Correct' if score == 1 else 'Incorrect' for score in scores]
|
| 126 |
color_scale = alt.Scale(domain=['Correct', 'Incorrect'], range=['green', 'red'])
|
| 127 |
|
| 128 |
-
# Retrieve token
|
| 129 |
-
|
| 130 |
|
| 131 |
# Create a DataFrame for the scatter plot
|
| 132 |
scatter_data = pd.DataFrame({
|
| 133 |
-
'Index': range(len(
|
| 134 |
-
'Token Count':
|
| 135 |
score_name: legend_map
|
| 136 |
})
|
| 137 |
|
|
@@ -143,14 +143,14 @@ class ResultDemonstrator:
|
|
| 143 |
stroke='black' # Sets the border color to black
|
| 144 |
).encode(
|
| 145 |
x=alt.X('Index', scale=alt.Scale(domain=[0, 1020])),
|
| 146 |
-
y=alt.Y('Token
|
| 147 |
color=alt.Color(score_name, scale=color_scale, legend=alt.Legend(title=score_name)),
|
| 148 |
-
tooltip=['Index', 'Token
|
| 149 |
).interactive() # Enables zoom & pan
|
| 150 |
|
| 151 |
chart = chart.properties(
|
| 152 |
title={
|
| 153 |
-
"text": f"Token
|
| 154 |
"color": "black", # Optional color
|
| 155 |
"fontSize": 20, # Optional font size
|
| 156 |
"anchor": "middle", # Optional anchor position
|
|
|
|
| 96 |
|
| 97 |
def plot_token_count_vs_scores(self, conf: str, model_name: str, score_name: str = 'VQA Score') -> None:
|
| 98 |
"""
|
| 99 |
+
Plots an interactive scatter plot comparing token count to VQA or EM scores using Altair.
|
| 100 |
|
| 101 |
Args:
|
| 102 |
conf (str): The configuration name.
|
|
|
|
| 125 |
legend_map = ['Correct' if score == 1 else 'Incorrect' for score in scores]
|
| 126 |
color_scale = alt.Scale(domain=['Correct', 'Incorrect'], range=['green', 'red'])
|
| 127 |
|
| 128 |
+
# Retrieve token count from the data
|
| 129 |
+
token_count = self.main_data[f'tokens_count_{conf}']
|
| 130 |
|
| 131 |
# Create a DataFrame for the scatter plot
|
| 132 |
scatter_data = pd.DataFrame({
|
| 133 |
+
'Index': range(len(token_count)),
|
| 134 |
+
'Token Count': token_count,
|
| 135 |
score_name: legend_map
|
| 136 |
})
|
| 137 |
|
|
|
|
| 143 |
stroke='black' # Sets the border color to black
|
| 144 |
).encode(
|
| 145 |
x=alt.X('Index', scale=alt.Scale(domain=[0, 1020])),
|
| 146 |
+
y=alt.Y('Token Count', scale=alt.Scale(domain=[token_count.min()-200, token_count.max()+200])),
|
| 147 |
color=alt.Color(score_name, scale=color_scale, legend=alt.Legend(title=score_name)),
|
| 148 |
+
tooltip=['Index', 'Token Count', score_name]
|
| 149 |
).interactive() # Enables zoom & pan
|
| 150 |
|
| 151 |
chart = chart.properties(
|
| 152 |
title={
|
| 153 |
+
"text": f"Token Count vs {score_name} ({model_configuration})",
|
| 154 |
"color": "black", # Optional color
|
| 155 |
"fontSize": 20, # Optional font size
|
| 156 |
"anchor": "middle", # Optional anchor position
|