Spaces:
Running
Running
Joshua Lochner
committed on
Commit
·
c415610
1
Parent(s):
85661b3
Use custom caching system for loading models
Browse files
app.py
CHANGED
|
@@ -40,11 +40,17 @@ st.set_page_config(
|
|
| 40 |
|
| 41 |
# Faster caching system for predictions (No need to hash)
|
| 42 |
@st.cache(persist=True, allow_output_mutation=True)
|
| 43 |
-
def
|
| 44 |
return {}
|
| 45 |
|
| 46 |
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
MODELS = {
|
| 50 |
'Small (293 MB)': {
|
|
@@ -100,23 +106,27 @@ def predict_function(model_id, model, tokenizer, segmentation_args, classifier_a
|
|
| 100 |
return prediction_cache[model_id][video_id]
|
| 101 |
|
| 102 |
|
| 103 |
-
@st.cache(persist=True, allow_output_mutation=True)
|
| 104 |
def load_predict(model_id):
|
| 105 |
model_info = MODELS[model_id]
|
| 106 |
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
-
|
| 113 |
-
model.to(device())
|
| 114 |
|
| 115 |
-
|
| 116 |
|
| 117 |
-
|
|
|
|
| 118 |
|
| 119 |
-
return
|
| 120 |
|
| 121 |
|
| 122 |
def main():
|
|
|
|
| 40 |
|
| 41 |
# Faster caching system for predictions (No need to hash)
@st.cache(persist=True, allow_output_mutation=True)
def create_prediction_cache():
    """Return a single shared dict used to memoise per-video predictions."""
    return {}
|
| 45 |
|
| 46 |
|
| 47 |
+
@st.cache(persist=True, allow_output_mutation=True)
def create_function_cache():
    """Return a single shared dict used to memoise per-model predict callables."""
    return {}
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# Module-level caches, shared across Streamlit reruns via st.cache above.
prediction_cache = create_prediction_cache()
prediction_function_cache = create_function_cache()
|
| 54 |
|
| 55 |
MODELS = {
|
| 56 |
'Small (293 MB)': {
|
|
|
|
| 106 |
return prediction_cache[model_id][video_id]
|
| 107 |
|
| 108 |
|
|
|
|
| 109 |
def load_predict(model_id):
    """Return a prediction callable for *model_id*, loading the model lazily.

    The expensive work (downloading/instantiating the model, tokenizer and
    classifier) runs only on the first request for a given model id; the
    resulting ``partial`` is stored in ``prediction_function_cache`` and
    reused on every later call.
    """
    info = MODELS[model_id]

    # Fast path: the predict function for this model was already built.
    if model_id in prediction_function_cache:
        return prediction_function_cache[model_id]

    # Use default segmentation and classification arguments
    eval_args = EvaluationArguments(model_path=info['repo_id'])
    seg_args = SegmentationArguments()
    cls_args = ClassifierArguments()

    model = AutoModelForSeq2SeqLM.from_pretrained(eval_args.model_path)
    model.to(device())

    tokenizer = AutoTokenizer.from_pretrained(eval_args.model_path)

    download_classifier(cls_args)

    predict = partial(
        predict_function, model_id, model, tokenizer, seg_args, cls_args)
    prediction_function_cache[model_id] = predict
    return predict
|
| 130 |
|
| 131 |
|
| 132 |
def main():
|