Spaces:
Build error
Build error
Iskaj
commited on
Commit
·
58b39ad
1
Parent(s):
0afaddb
add doc to plot.py
Browse files
plot.py
CHANGED
|
@@ -8,62 +8,33 @@ from scipy import stats as st
|
|
| 8 |
|
| 9 |
from config import FPS
|
| 10 |
|
| 11 |
-
|
| 12 |
-
def plot_comparison(lims, D, I, hash_vectors, MIN_DISTANCE = 3):
|
| 13 |
-
sns.set_theme()
|
| 14 |
-
|
| 15 |
-
x = [(lims[i+1]-lims[i]) * [i] for i in range(hash_vectors.shape[0])]
|
| 16 |
-
x = [i/FPS for j in x for i in j]
|
| 17 |
-
y = [i/FPS for i in I]
|
| 18 |
-
|
| 19 |
-
# Create figure and dataframe to plot with sns
|
| 20 |
-
fig = plt.figure()
|
| 21 |
-
# plt.tight_layout()
|
| 22 |
-
df = pd.DataFrame(zip(x, y), columns = ['X', 'Y'])
|
| 23 |
-
g = sns.scatterplot(data=df, x='X', y='Y', s=2*(1-D/(MIN_DISTANCE+1)), alpha=1-D/MIN_DISTANCE)
|
| 24 |
-
|
| 25 |
-
# Set x-labels to be more readable
|
| 26 |
-
x_locs, x_labels = plt.xticks() # Get original locations and labels for x ticks
|
| 27 |
-
x_labels = [time.strftime('%H:%M:%S', time.gmtime(x)) for x in x_locs]
|
| 28 |
-
plt.xticks(x_locs, x_labels)
|
| 29 |
-
plt.xticks(rotation=90)
|
| 30 |
-
plt.xlabel('Time in source video (H:M:S)')
|
| 31 |
-
plt.xlim(0, None)
|
| 32 |
-
|
| 33 |
-
# Set y-labels to be more readable
|
| 34 |
-
y_locs, y_labels = plt.yticks() # Get original locations and labels for x ticks
|
| 35 |
-
y_labels = [time.strftime('%H:%M:%S', time.gmtime(y)) for y in y_locs]
|
| 36 |
-
plt.yticks(y_locs, y_labels)
|
| 37 |
-
plt.ylabel('Time in target video (H:M:S)')
|
| 38 |
-
|
| 39 |
-
# Adjust padding to fit gradio
|
| 40 |
-
plt.subplots_adjust(bottom=0.25, left=0.20)
|
| 41 |
-
return fig
|
| 42 |
-
|
| 43 |
-
def plot_multi_comparison(df, change_points):
|
| 44 |
-
""" From the dataframe plot the current set of plots, where the bottom right is most indicative """
|
| 45 |
-
fig, ax_arr = plt.subplots(3, 2, figsize=(12, 6), dpi=100, sharex=True)
|
| 46 |
-
sns.scatterplot(data = df, x='time', y='SOURCE_S', ax=ax_arr[0,0])
|
| 47 |
-
sns.lineplot(data = df, x='time', y='SOURCE_LIP_S', ax=ax_arr[0,1])
|
| 48 |
-
sns.scatterplot(data = df, x='time', y='OFFSET', ax=ax_arr[1,0])
|
| 49 |
-
sns.lineplot(data = df, x='time', y='OFFSET_LIP', ax=ax_arr[1,1])
|
| 50 |
-
|
| 51 |
-
# Plot change point as lines
|
| 52 |
-
sns.lineplot(data = df, x='time', y='OFFSET_LIP', ax=ax_arr[2,1])
|
| 53 |
-
for x in change_points:
|
| 54 |
-
cp_time = x.start_time
|
| 55 |
-
plt.vlines(x=cp_time, ymin=np.min(df['OFFSET_LIP']), ymax=np.max(df['OFFSET_LIP']), colors='red', lw=2)
|
| 56 |
-
rand_y_pos = np.random.uniform(low=np.min(df['OFFSET_LIP']), high=np.max(df['OFFSET_LIP']), size=None)
|
| 57 |
-
plt.text(x=cp_time, y=rand_y_pos, s=str(np.round(x.confidence, 2)), color='r', rotation=-0.0, fontsize=14)
|
| 58 |
-
plt.xticks(rotation=90)
|
| 59 |
-
return fig
|
| 60 |
-
|
| 61 |
def change_points_to_segments(df, change_points):
|
| 62 |
-
""" Convert change points from kats detector to segment indicators
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
return [pd.to_datetime(0.0, unit='s').to_datetime64()] + [cp.start_time for cp in change_points] + [pd.to_datetime(df.iloc[-1]['TARGET_S'], unit='s').to_datetime64()]
|
| 64 |
|
| 65 |
def add_seconds_to_datetime64(datetime64, seconds, subtract=False):
|
| 66 |
-
"""Add or substract a number of seconds to a
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
s, m = divmod(seconds, 1.0)
|
| 68 |
if subtract:
|
| 69 |
return datetime64 - np.timedelta64(int(s), 's') - np.timedelta64(int(m * 1000), 'ms')
|
|
@@ -74,12 +45,18 @@ def plot_segment_comparison(df, change_points, video_mp4 = "Placeholder.mp4", vi
|
|
| 74 |
1. Make a decision on where each segment belongs in time and return that info as a list of dicts
|
| 75 |
2. Plot how this decision got made as an informative plot
|
| 76 |
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
"""
|
|
|
|
| 83 |
fig, ax_arr = plt.subplots(4, 1, figsize=(16, 6), dpi=300, sharex=True)
|
| 84 |
ax_arr[0].set_title(video_id)
|
| 85 |
sns.scatterplot(data = df, x='time', y='SOURCE_S', ax=ax_arr[0], label="SOURCE_S", color='blue', alpha=1.0)
|
|
@@ -146,11 +123,9 @@ def plot_segment_comparison(df, change_points, video_mp4 = "Placeholder.mp4", vi
|
|
| 146 |
"Source Video .mp4" : video_mp4,
|
| 147 |
"Uncertainty" : np.round(average_diff, 3),
|
| 148 |
"Average Offset in Seconds" : np.round(average_offset, 3),
|
| 149 |
-
# "Explanation" : f"{start_time_str} -> {end_time_str} comes from video with ID={video_id} from {origin_start_time_str} -> {origin_end_time_str}",
|
| 150 |
}
|
| 151 |
segment_decisions.append(decision)
|
| 152 |
seg_i += 1
|
| 153 |
-
# print(decision)
|
| 154 |
|
| 155 |
# Return figure
|
| 156 |
plt.xticks(rotation=90)
|
|
|
|
| 8 |
|
| 9 |
from config import FPS
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
def change_points_to_segments(df, change_points):
|
| 12 |
+
""" Convert change points from kats detector to segment indicators.
|
| 13 |
+
|
| 14 |
+
Args:
|
| 15 |
+
df (DataFrame): Dataframe with information regarding the match between videos.
|
| 16 |
+
change_points ([TimeSeriesChangePoint]): Array of time series change point objects.
|
| 17 |
+
|
| 18 |
+
Returns:
|
| 19 |
+
List of numpy.datetime64 objects where the first element is '0.0 in time' and the final element is the last
|
| 20 |
+
element of the video in time so the segment starts and ends in a logical place.
|
| 21 |
+
|
| 22 |
+
"""
|
| 23 |
return [pd.to_datetime(0.0, unit='s').to_datetime64()] + [cp.start_time for cp in change_points] + [pd.to_datetime(df.iloc[-1]['TARGET_S'], unit='s').to_datetime64()]
|
| 24 |
|
| 25 |
def add_seconds_to_datetime64(datetime64, seconds, subtract=False):
|
| 26 |
+
""" Add or substract a number of seconds to a numpy.datetime64 object.
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
datetime64 (numpy.datetime64): Datetime object that we want to increase or decrease by number of seconds.
|
| 30 |
+
seconds (float): Amount of seconds we want to add or subtract.
|
| 31 |
+
subtract (bool): Toggle for subtracting or adding.
|
| 32 |
+
|
| 33 |
+
Returns:
|
| 34 |
+
A numpy.datetime64 object.
|
| 35 |
+
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
s, m = divmod(seconds, 1.0)
|
| 39 |
if subtract:
|
| 40 |
return datetime64 - np.timedelta64(int(s), 's') - np.timedelta64(int(m * 1000), 'ms')
|
|
|
|
| 45 |
1. Make a decision on where each segment belongs in time and return that info as a list of dicts
|
| 46 |
2. Plot how this decision got made as an informative plot
|
| 47 |
|
| 48 |
+
Args:
|
| 49 |
+
df (DataFrame): Dataframe with information regarding the match between videos.
|
| 50 |
+
change_points ([TimeSeriesChangePoint]): Array of time series change point objects.
|
| 51 |
+
video_mp4 (str): Name of the source video to return as extra info.
|
| 52 |
+
video_id (str): The unique identifier for the video currently being compared
|
| 53 |
+
threshold_diff (float): Threshold for the average distance to plot which segments are likely bad matches.
|
| 54 |
+
|
| 55 |
+
Returns:
|
| 56 |
+
fig (Figure): Figure that shows the comparison between two videos.
|
| 57 |
+
segment_decisions (dict): JSON-style dictionary containing the decision information of the comparison between two videos.
|
| 58 |
"""
|
| 59 |
+
# Plot it with certain characteristics
|
| 60 |
fig, ax_arr = plt.subplots(4, 1, figsize=(16, 6), dpi=300, sharex=True)
|
| 61 |
ax_arr[0].set_title(video_id)
|
| 62 |
sns.scatterplot(data = df, x='time', y='SOURCE_S', ax=ax_arr[0], label="SOURCE_S", color='blue', alpha=1.0)
|
|
|
|
| 123 |
"Source Video .mp4" : video_mp4,
|
| 124 |
"Uncertainty" : np.round(average_diff, 3),
|
| 125 |
"Average Offset in Seconds" : np.round(average_offset, 3),
|
|
|
|
| 126 |
}
|
| 127 |
segment_decisions.append(decision)
|
| 128 |
seg_i += 1
|
|
|
|
| 129 |
|
| 130 |
# Return figure
|
| 131 |
plt.xticks(rotation=90)
|