Spaces:
Running
Running
Joshua Lochner
committed on
Commit
·
3d1c770
1
Parent(s):
36f7534
Update streamlit app to use new classifier
Browse files
app.py
CHANGED
|
@@ -12,11 +12,10 @@ from urllib.parse import quote
|
|
| 12 |
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'src')) # noqa
|
| 13 |
|
| 14 |
from preprocess import get_words
|
| 15 |
-
from predict import
|
| 16 |
-
from
|
| 17 |
-
from shared import seconds_to_time, CATGEGORY_OPTIONS
|
| 18 |
from utils import regex_search
|
| 19 |
-
from model import
|
| 20 |
from errors import TranscriptError
|
| 21 |
|
| 22 |
st.set_page_config(
|
|
@@ -104,7 +103,7 @@ for m in MODELS:
|
|
| 104 |
prediction_cache[m] = {}
|
| 105 |
|
| 106 |
|
| 107 |
-
CLASSIFIER_PATH = 'Xenova/sponsorblock-classifier'
|
| 108 |
|
| 109 |
|
| 110 |
TRANSCRIPT_TYPES = {
|
|
@@ -122,15 +121,15 @@ TRANSCRIPT_TYPES = {
|
|
| 122 |
}
|
| 123 |
|
| 124 |
|
| 125 |
-
def predict_function(model_id, model, tokenizer, segmentation_args,
|
| 126 |
cache_id = f'{video_id}_{ts_type_id}'
|
| 127 |
|
| 128 |
if cache_id not in prediction_cache[model_id]:
|
| 129 |
prediction_cache[model_id][cache_id] = pred(
|
| 130 |
video_id, model, tokenizer,
|
| 131 |
segmentation_args=segmentation_args,
|
| 132 |
-
|
| 133 |
-
|
| 134 |
)
|
| 135 |
return prediction_cache[model_id][cache_id]
|
| 136 |
|
|
@@ -140,15 +139,15 @@ def load_predict(model_id):
|
|
| 140 |
|
| 141 |
if model_id not in prediction_function_cache:
|
| 142 |
# Use default segmentation and classification arguments
|
| 143 |
-
|
|
|
|
| 144 |
segmentation_args = SegmentationArguments()
|
| 145 |
-
classifier_args = ClassifierArguments(
|
| 146 |
-
min_probability=0) # Filtering done later
|
| 147 |
|
| 148 |
-
model, tokenizer =
|
| 149 |
|
| 150 |
prediction_function_cache[model_id] = partial(
|
| 151 |
-
predict_function, model_id, model, tokenizer, segmentation_args,
|
|
|
|
| 152 |
|
| 153 |
return prediction_function_cache[model_id]
|
| 154 |
|
|
@@ -252,7 +251,8 @@ def main():
|
|
| 252 |
|
| 253 |
submit_segments = []
|
| 254 |
for index, prediction in enumerate(predictions, start=1):
|
| 255 |
-
|
|
|
|
| 256 |
continue # Skip
|
| 257 |
|
| 258 |
confidence = prediction['probability'] * 100
|
|
@@ -262,13 +262,13 @@ def main():
|
|
| 262 |
|
| 263 |
submit_segments.append({
|
| 264 |
'segment': [prediction['start'], prediction['end']],
|
| 265 |
-
'category': prediction['category']
|
| 266 |
'actionType': 'skip'
|
| 267 |
})
|
| 268 |
start_time = seconds_to_time(prediction['start'])
|
| 269 |
end_time = seconds_to_time(prediction['end'])
|
| 270 |
with st.expander(
|
| 271 |
-
f"[{
|
| 272 |
):
|
| 273 |
|
| 274 |
url = f"https://www.youtube-nocookie.com/embed/{video_id}?&start={floor(prediction['start'])}&end={ceil(prediction['end'])}"
|
|
@@ -280,7 +280,7 @@ def main():
|
|
| 280 |
text = ' '.join(w['text'] for w in prediction['words'])
|
| 281 |
st.write(f"**Times:** {start_time} \u2192 {end_time}")
|
| 282 |
st.write(
|
| 283 |
-
f"**Category:** {CATGEGORY_OPTIONS[
|
| 284 |
st.write(f"**Confidence:** {confidence:.2f}%")
|
| 285 |
st.write(f'**Text:** "{text}"')
|
| 286 |
|
|
|
|
| 12 |
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'src')) # noqa
|
| 13 |
|
| 14 |
from preprocess import get_words
|
| 15 |
+
from predict import PredictArguments, SegmentationArguments, predict as pred
|
| 16 |
+
from shared import GeneralArguments, seconds_to_time, CATGEGORY_OPTIONS
|
|
|
|
| 17 |
from utils import regex_search
|
| 18 |
+
from model import get_model_tokenizer_classifier
|
| 19 |
from errors import TranscriptError
|
| 20 |
|
| 21 |
st.set_page_config(
|
|
|
|
| 103 |
prediction_cache[m] = {}
|
| 104 |
|
| 105 |
|
| 106 |
+
CLASSIFIER_PATH = 'Xenova/sponsorblock-classifier-v2'
|
| 107 |
|
| 108 |
|
| 109 |
TRANSCRIPT_TYPES = {
|
|
|
|
| 121 |
}
|
| 122 |
|
| 123 |
|
| 124 |
+
def predict_function(model_id, model, tokenizer, segmentation_args, classifier, video_id, words, ts_type_id):
|
| 125 |
cache_id = f'{video_id}_{ts_type_id}'
|
| 126 |
|
| 127 |
if cache_id not in prediction_cache[model_id]:
|
| 128 |
prediction_cache[model_id][cache_id] = pred(
|
| 129 |
video_id, model, tokenizer,
|
| 130 |
segmentation_args=segmentation_args,
|
| 131 |
+
words=words,
|
| 132 |
+
classifier=classifier
|
| 133 |
)
|
| 134 |
return prediction_cache[model_id][cache_id]
|
| 135 |
|
|
|
|
| 139 |
|
| 140 |
if model_id not in prediction_function_cache:
|
| 141 |
# Use default segmentation and classification arguments
|
| 142 |
+
predict_args = PredictArguments(model_name_or_path=model_info['repo_id'])
|
| 143 |
+
general_args = GeneralArguments()
|
| 144 |
segmentation_args = SegmentationArguments()
|
|
|
|
|
|
|
| 145 |
|
| 146 |
+
model, tokenizer, classifier = get_model_tokenizer_classifier(predict_args, general_args)
|
| 147 |
|
| 148 |
prediction_function_cache[model_id] = partial(
|
| 149 |
+
predict_function, model_id, model, tokenizer, segmentation_args, classifier)
|
| 150 |
+
|
| 151 |
|
| 152 |
return prediction_function_cache[model_id]
|
| 153 |
|
|
|
|
| 251 |
|
| 252 |
submit_segments = []
|
| 253 |
for index, prediction in enumerate(predictions, start=1):
|
| 254 |
+
category_key = prediction['category'].upper()
|
| 255 |
+
if category_key not in categories:
|
| 256 |
continue # Skip
|
| 257 |
|
| 258 |
confidence = prediction['probability'] * 100
|
|
|
|
| 262 |
|
| 263 |
submit_segments.append({
|
| 264 |
'segment': [prediction['start'], prediction['end']],
|
| 265 |
+
'category': prediction['category'],
|
| 266 |
'actionType': 'skip'
|
| 267 |
})
|
| 268 |
start_time = seconds_to_time(prediction['start'])
|
| 269 |
end_time = seconds_to_time(prediction['end'])
|
| 270 |
with st.expander(
|
| 271 |
+
f"[{category_key}] Prediction #{index} ({start_time} \u2192 {end_time})"
|
| 272 |
):
|
| 273 |
|
| 274 |
url = f"https://www.youtube-nocookie.com/embed/{video_id}?&start={floor(prediction['start'])}&end={ceil(prediction['end'])}"
|
|
|
|
| 280 |
text = ' '.join(w['text'] for w in prediction['words'])
|
| 281 |
st.write(f"**Times:** {start_time} \u2192 {end_time}")
|
| 282 |
st.write(
|
| 283 |
+
f"**Category:** {CATGEGORY_OPTIONS[category_key]}")
|
| 284 |
st.write(f"**Confidence:** {confidence:.2f}%")
|
| 285 |
st.write(f'**Text:** "{text}"')
|
| 286 |
|