Spaces:
Sleeping
Sleeping
Add: Allow repetition mode option.
Browse files
app.py
CHANGED
|
@@ -14,15 +14,20 @@ def load_conette(*args, **kwargs) -> CoNeTTEModel:
|
|
| 14 |
return conette(*args, **kwargs)
|
| 15 |
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
def main() -> None:
|
| 18 |
st.header("Describe audio content with CoNeTTE")
|
| 19 |
|
| 20 |
model = load_conette(model_kwds=dict(device="cpu"))
|
| 21 |
|
| 22 |
task = st.selectbox("Task embedding input", model.tasks, 0)
|
|
|
|
| 23 |
beam_size: int = st.select_slider( # type: ignore
|
| 24 |
"Beam size",
|
| 25 |
-
list(range(1,
|
| 26 |
model.config.beam_size,
|
| 27 |
)
|
| 28 |
min_pred_size: int = st.select_slider( # type: ignore
|
|
@@ -36,7 +41,7 @@ def main() -> None:
|
|
| 36 |
model.config.max_pred_size,
|
| 37 |
)
|
| 38 |
|
| 39 |
-
st.
|
| 40 |
audios = st.file_uploader(
|
| 41 |
"Upload an audio file",
|
| 42 |
type=["wav", "flac", "mp3", "ogg", "avi"],
|
|
@@ -49,11 +54,22 @@ def main() -> None:
|
|
| 49 |
temp.write(audio.getvalue())
|
| 50 |
fpath = temp.name
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
kwargs: dict[str, Any] = dict(
|
| 53 |
task=task,
|
| 54 |
beam_size=beam_size,
|
| 55 |
min_pred_size=min_pred_size,
|
| 56 |
max_pred_size=max_pred_size,
|
|
|
|
| 57 |
)
|
| 58 |
cand_key = f"{audio.name}-{kwargs}"
|
| 59 |
|
|
@@ -67,8 +83,8 @@ def main() -> None:
|
|
| 67 |
cand = outputs["cands"][0]
|
| 68 |
st.session_state[cand_key] = cand
|
| 69 |
|
| 70 |
-
st.
|
| 71 |
-
st.
|
| 72 |
|
| 73 |
|
| 74 |
if __name__ == "__main__":
|
|
|
|
| 14 |
return conette(*args, **kwargs)
|
| 15 |
|
| 16 |
|
| 17 |
+
def format_cand(cand: str) -> str:
|
| 18 |
+
return f"{cand[0].title()}{cand[1:]}."
|
| 19 |
+
|
| 20 |
+
|
| 21 |
def main() -> None:
|
| 22 |
st.header("Describe audio content with CoNeTTE")
|
| 23 |
|
| 24 |
model = load_conette(model_kwds=dict(device="cpu"))
|
| 25 |
|
| 26 |
task = st.selectbox("Task embedding input", model.tasks, 0)
|
| 27 |
+
allow_rep_mode = st.selectbox("Allow repetition of words", ["stopwords", "all", "none"], 0)
|
| 28 |
beam_size: int = st.select_slider( # type: ignore
|
| 29 |
"Beam size",
|
| 30 |
+
list(range(1, 21)),
|
| 31 |
model.config.beam_size,
|
| 32 |
)
|
| 33 |
min_pred_size: int = st.select_slider( # type: ignore
|
|
|
|
| 41 |
model.config.max_pred_size,
|
| 42 |
)
|
| 43 |
|
| 44 |
+
st.markdown("Recommanded audio: lasting from **1 to 30s**, sampled at **32 kHz**.")
|
| 45 |
audios = st.file_uploader(
|
| 46 |
"Upload an audio file",
|
| 47 |
type=["wav", "flac", "mp3", "ogg", "avi"],
|
|
|
|
| 54 |
temp.write(audio.getvalue())
|
| 55 |
fpath = temp.name
|
| 56 |
|
| 57 |
+
if allow_rep_mode == "all":
|
| 58 |
+
forbid_rep_mode = "none"
|
| 59 |
+
elif allow_rep_mode == "none":
|
| 60 |
+
forbid_rep_mode = "all"
|
| 61 |
+
elif allow_rep_mode == "stopwords":
|
| 62 |
+
forbid_rep_mode = "content_words"
|
| 63 |
+
else:
|
| 64 |
+
ALLOW_REP_MODES = ("all", "none", "stopwords")
|
| 65 |
+
raise ValueError(f"Unknown option {allow_rep_mode=}. (expected one of {ALLOW_REP_MODES})")
|
| 66 |
+
|
| 67 |
kwargs: dict[str, Any] = dict(
|
| 68 |
task=task,
|
| 69 |
beam_size=beam_size,
|
| 70 |
min_pred_size=min_pred_size,
|
| 71 |
max_pred_size=max_pred_size,
|
| 72 |
+
forbid_rep_mode=forbid_rep_mode,
|
| 73 |
)
|
| 74 |
cand_key = f"{audio.name}-{kwargs}"
|
| 75 |
|
|
|
|
| 83 |
cand = outputs["cands"][0]
|
| 84 |
st.session_state[cand_key] = cand
|
| 85 |
|
| 86 |
+
st.markdown(f"Output for {audio.name}:")
|
| 87 |
+
st.markdown(f" - red[{format_cand(cand)}]")
|
| 88 |
|
| 89 |
|
| 90 |
if __name__ == "__main__":
|