Spaces:
Running
Running
jhj0517
commited on
Commit
·
cee12df
1
Parent(s):
250b9b4
Handle gradio None values
Browse files
modules/whisper/data_classes.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import torch
|
| 3 |
-
from typing import Optional, Dict, List
|
| 4 |
from pydantic import BaseModel, Field, field_validator, ConfigDict
|
| 5 |
from gradio_i18n import Translate, gettext as _
|
| 6 |
from enum import Enum
|
|
@@ -241,7 +241,7 @@ class WhisperParams(BaseParams):
|
|
| 241 |
default=True,
|
| 242 |
description="Suppress blank outputs at start of sampling"
|
| 243 |
)
|
| 244 |
-
suppress_tokens: Optional[str] = Field(default=
|
| 245 |
max_initial_timestamp: float = Field(
|
| 246 |
default=0.0,
|
| 247 |
ge=0.0,
|
|
@@ -279,6 +279,20 @@ class WhisperParams(BaseParams):
|
|
| 279 |
from modules.utils.constants import AUTOMATIC_DETECTION
|
| 280 |
return None if v == AUTOMATIC_DETECTION.unwrap() else v
|
| 281 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 282 |
@classmethod
|
| 283 |
def to_gradio_inputs(cls,
|
| 284 |
defaults: Optional[Dict] = None,
|
|
@@ -301,7 +315,7 @@ class WhisperParams(BaseParams):
|
|
| 301 |
gr.Dropdown(
|
| 302 |
label=_("Language"),
|
| 303 |
choices=available_langs,
|
| 304 |
-
value=defaults.get("lang",
|
| 305 |
),
|
| 306 |
gr.Checkbox(
|
| 307 |
label=_("Translate to English?"),
|
|
@@ -407,7 +421,7 @@ class WhisperParams(BaseParams):
|
|
| 407 |
),
|
| 408 |
gr.Textbox(
|
| 409 |
label="Suppress Tokens",
|
| 410 |
-
value=defaults.get("suppress_tokens",
|
| 411 |
info="Token IDs to suppress"
|
| 412 |
),
|
| 413 |
gr.Number(
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import torch
|
| 3 |
+
from typing import Optional, Dict, List, Union
|
| 4 |
from pydantic import BaseModel, Field, field_validator, ConfigDict
|
| 5 |
from gradio_i18n import Translate, gettext as _
|
| 6 |
from enum import Enum
|
|
|
|
| 241 |
default=True,
|
| 242 |
description="Suppress blank outputs at start of sampling"
|
| 243 |
)
|
| 244 |
+
suppress_tokens: Optional[Union[List, str]] = Field(default=[-1], description="Token IDs to suppress")
|
| 245 |
max_initial_timestamp: float = Field(
|
| 246 |
default=0.0,
|
| 247 |
ge=0.0,
|
|
|
|
| 279 |
from modules.utils.constants import AUTOMATIC_DETECTION
|
| 280 |
return None if v == AUTOMATIC_DETECTION.unwrap() else v
|
| 281 |
|
| 282 |
+
@field_validator('suppress_tokens')
|
| 283 |
+
def validate_supress_tokens(cls, v):
|
| 284 |
+
import ast
|
| 285 |
+
try:
|
| 286 |
+
if isinstance(v, str):
|
| 287 |
+
suppress_tokens = ast.literal_eval(v)
|
| 288 |
+
if not isinstance(suppress_tokens, list):
|
| 289 |
+
raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]")
|
| 290 |
+
return suppress_tokens
|
| 291 |
+
if isinstance(v, list):
|
| 292 |
+
return v
|
| 293 |
+
except Exception as e:
|
| 294 |
+
raise ValueError(f"Invalid Suppress Tokens. The value must be type of List[int]: {e}")
|
| 295 |
+
|
| 296 |
@classmethod
|
| 297 |
def to_gradio_inputs(cls,
|
| 298 |
defaults: Optional[Dict] = None,
|
|
|
|
| 315 |
gr.Dropdown(
|
| 316 |
label=_("Language"),
|
| 317 |
choices=available_langs,
|
| 318 |
+
value=defaults.get("lang", AUTOMATIC_DETECTION),
|
| 319 |
),
|
| 320 |
gr.Checkbox(
|
| 321 |
label=_("Translate to English?"),
|
|
|
|
| 421 |
),
|
| 422 |
gr.Textbox(
|
| 423 |
label="Suppress Tokens",
|
| 424 |
+
value=defaults.get("suppress_tokens", "[-1]"),
|
| 425 |
info="Token IDs to suppress"
|
| 426 |
),
|
| 427 |
gr.Number(
|