# de-generate / app.py — apehex (commit 8513e01: "Fix syntax bug.")
import gradio
import huggingface_hub
import spaces
import psaiops.common.model
import psaiops.common.style
import psaiops.common.tokenizer
import psaiops.score.human.ux as _ux
import psaiops.score.human.app as _app
# META #########################################################################
# extra keyword arguments forwarded to the model constructor
_MODEL_CFG = {}
# snapshot_download options: skip weight formats this app never loads
_LOAD_CFG = {'repo_type': 'model', 'ignore_patterns': ['*.onnx', '*.tflite', '*.msgpack'],}
# index 0 is always the active selection; the list is reordered on switch
_app.MODELS = ['qwen/qwen3.5-9b', 'qwen/qwen3.5-27b']
# frontload the models on the CPU to avoid downloading them from the GPU slot
for __m in _app.MODELS:
    huggingface_hub.snapshot_download(repo_id=__m, **_LOAD_CFG)
# but do not instantiate unless necessary
_MODEL = None
# the tokenizer is CPU-only, so it can be built eagerly at import time
_TOKENIZER = psaiops.common.tokenizer.get_tokenizer(name=_app.MODELS[0])
# CURRENT ######################################################################
def current_selection() -> dict:
    """Build a dropdown update reflecting the model list.

    The entry at index 0 of `_app.MODELS` is the active selection.
    """
    __choices = _app.MODELS
    return gradio.update(choices=__choices, value=__choices[0])
def switch_selection(name: str) -> None:
    """Move `name` to the front of `_app.MODELS`, making it the active selection.

    The remaining entries keep their original relative order. (The previous
    implementation routed them through a `set`, whose iteration order is
    hash-dependent, so the dropdown ordering was non-deterministic.)
    """
    _app.MODELS = [name] + [__m for __m in _app.MODELS if __m != name]
# LAZY #########################################################################
def fetch_model() -> object:
    """Lazily instantiate the selected model on the GPU and return it.

    Reuses the cached `_MODEL` when present; otherwise loads the model at
    index 0 of `_app.MODELS` and notifies the user about the outcome.
    Returns the model object, or `None` when loading did not complete.
    """
    global _MODEL
    # control when the model is downloaded to avoid moving it to the CPU
    if _MODEL is None:
        # index 0 always holds the currently selected model
        _MODEL = psaiops.common.model.get_model(name=_app.MODELS[0], device='cuda', **_MODEL_CFG)
    # a loaded model exposes `name_or_path`; anything else (incl. `None`) means failure
    if not hasattr(_MODEL, 'name_or_path'):
        gradio.Warning(title='Warning', message='The GPU time slot expired before the model could be loaded.', duration=4)
    else:
        gradio.Info(title='Info', message='Successfully switched to `{}`.'.format(getattr(_MODEL, 'name_or_path', 'None')), duration=2)
    return _MODEL
def fetch_tokenizer() -> object:
    """Return the cached tokenizer, rebuilding it for the selected model if needed."""
    global _TOKENIZER
    # fast path: the cache is only reset when the model selection changes
    if _TOKENIZER is not None:
        return _TOKENIZER
    _TOKENIZER = psaiops.common.tokenizer.get_tokenizer(name=_app.MODELS[0])
    return _TOKENIZER
# SWITCH #######################################################################
@spaces.GPU(duration=30)
def switch_model(
        model_str: str,
) -> dict:
    """Switch the active model to `model_str` and return the dropdown update.

    Frees the previous model, resets the lazy caches, and attempts to load
    the new model inside the GPU allocation. The reordered model list is
    returned even when loading fails, so the UI stays consistent.
    """
    global _MODEL, _TOKENIZER
    # end early if the selection isn't changed
    if model_str == _app.MODELS[0]:
        return current_selection()
    # reorder the model list so that the selected model is at index 0
    switch_selection(name=model_str)
    # free the memory allocated to the previous model
    psaiops.common.model.free_memory(model=_MODEL)
    # reset the pointers so the fetchers reload lazily
    _MODEL = None
    _TOKENIZER = None
    # best-effort load: fetch_model surfaces failures to the user itself
    try:
        fetch_model()
    except Exception:
        # narrowed from bare `except`: let SystemExit / KeyboardInterrupt propagate
        pass
    # return the reordered list of models even if the loading failed
    return current_selection()
# TOKENS #######################################################################
def compute_tokens(
        prompt_str: str,
        export_str: str,
) -> object:
    """Tokenize the prompt and return the updated tokens state.

    The tokenizer is fetched here (CPU only, no GPU wrapper needed) so that
    unpicklable objects never cross the Spaces process boundary.
    """
    return _ux.update_tokens_state(
        prompt_str=prompt_str,
        export_str=export_str,
        tokenizer_obj=fetch_tokenizer())
# INDICES ######################################################################
def compute_indices(
        prompt_str: str,
        export_str: str,
) -> object:
    """Convert the prompt into token indices and return the updated state.

    Mirrors `compute_tokens`: the tokenizer is resolved here, outside any
    GPU wrapper, so unpicklable objects stay in this process.
    """
    return _ux.update_indices_state(
        prompt_str=prompt_str,
        export_str=export_str,
        tokenizer_obj=fetch_tokenizer())
# LOGITS #######################################################################
@spaces.GPU(duration=30)
def compute_logits(
        indices_arr: object,
        export_str: str,
) -> object:
    """Run the model on the token indices and return the logits state.

    Returns the logits object on success, or `None` when the computation
    could not finish (e.g. the GPU allocation expired mid-run).
    """
    __logits = None
    # load the model inside the GPU wrapper (not before)
    __model = fetch_model()
    # the allocation might expire before the calculations are finished
    try:
        __logits = _ux.update_logits_state(
            indices_arr=indices_arr,
            export_str=export_str,
            model_obj=__model)
    except Exception:
        # narrowed from bare `except`: let SystemExit / KeyboardInterrupt propagate
        gradio.Warning(title='Warning', message='Calculations aborted because the GPU allocation expired.', duration=4)
    # tensor on success, None otherwise
    return __logits
# CREATE #######################################################################
# wire the callbacks into the UI; presumably returns a `gradio.Blocks` — TODO confirm
demo = _app.create_app(
    current=current_selection,
    switch=switch_model,
    partition=compute_tokens,
    convert=compute_indices,
    compute=compute_logits,
    models=_app.MODELS,
    export='')
# LAUNCH #######################################################################
# enable request queuing (needed for the @spaces.GPU-decorated callbacks)
demo.queue()
# NOTE(review): in recent gradio versions `theme` and `css` belong to the
# `Blocks(...)` constructor, not to `launch(...)` — verify against the
# installed gradio version that this call does not raise a TypeError
demo.launch(theme=gradio.themes.Soft(), css=psaiops.common.style.ALL, share=False, debug=False)