"""HuggingFace Space entry point: lazy model loading and switching for the
psaiops human-scoring app.

The model weights are snapshot-downloaded to the CPU host at import time, but
the model object itself is only instantiated inside a `@spaces.GPU` slot, so
GPU seconds are not spent on downloads.  `_app.MODELS[0]` is always the
currently-selected model.
"""

import gradio
import huggingface_hub
import spaces

import psaiops.common.model
import psaiops.common.style
import psaiops.common.tokenizer
import psaiops.score.human.ux as _ux
import psaiops.score.human.app as _app

# META #########################################################################

# additional args used when loading the model
_MODEL_CFG = {}
# skip the weight formats this app never uses to shrink the snapshot download
_LOAD_CFG = {'repo_type': 'model', 'ignore_patterns': ['*.onnx', '*.tflite', '*.msgpack'],}

_app.MODELS = ['qwen/qwen3.5-9b', 'qwen/qwen3.5-27b']

# frontload the models on the CPU to avoid downloading them from the GPU slot
for __m in _app.MODELS:
    huggingface_hub.snapshot_download(repo_id=__m, **_LOAD_CFG)

# but do not instantiate unless necessary
_MODEL = None
_TOKENIZER = psaiops.common.tokenizer.get_tokenizer(name=_app.MODELS[0])

# CURRENT ######################################################################

def current_selection() -> dict:
    """Return a gradio update pinning the dropdown to the selected model."""
    # index 0 is, by convention, the currently-selected model
    return gradio.update(value=_app.MODELS[0], choices=_app.MODELS)

def switch_selection(name: str) -> None:
    """Move `name` to the front of `_app.MODELS`, keeping the rest in order."""
    # an order-preserving filter, rather than a set difference, so the
    # dropdown choices stay in a deterministic order across switches
    _app.MODELS = [name] + [__m for __m in _app.MODELS if __m != name]

# LAZY #########################################################################

def fetch_model() -> object:
    """Instantiate (once) and return the selected model, or `None` on failure.

    Must be called from inside a `@spaces.GPU` wrapper: instantiation moves
    the weights to CUDA.
    """
    global _MODEL
    # control when the model is downloaded to avoid moving it to the CPU
    if _MODEL is None:
        # the first item in the list is always the selected model
        _MODEL = psaiops.common.model.get_model(name=_app.MODELS[0], device='cuda', **_MODEL_CFG)
        # give some feedback
        if hasattr(_MODEL, 'name_or_path'):
            gradio.Info(title='Info', message='Successfully switched to `{}`.'.format(getattr(_MODEL, 'name_or_path', 'None')), duration=2)
        else:
            gradio.Warning(title='Warning', message='The GPU time slot expired before the model could be loaded.', duration=4)
    # model object or `None`
    return _MODEL

def fetch_tokenizer() -> object:
    """Instantiate (once) and return the tokenizer for the selected model."""
    global _TOKENIZER
    # not strictly necessary, but symmetry is everything
    if _TOKENIZER is None:
        _TOKENIZER = psaiops.common.tokenizer.get_tokenizer(name=_app.MODELS[0])
    # tokenizer object or `None`
    return _TOKENIZER

# SWITCH #######################################################################

@spaces.GPU(duration=30)
def switch_model(
    model_str: str
) -> dict:
    """Select `model_str`, free the previous model, and eagerly load the new one.

    Returns a gradio update for the model dropdown; the reordered list is
    returned even when loading fails so the UI stays consistent.
    """
    global _MODEL, _TOKENIZER
    # end early if the selection isn't changed
    if model_str == _app.MODELS[0]:
        return current_selection()
    # reorder the model list so that the selected model is at index 0
    switch_selection(name=model_str)
    # free the memory allocated to the previous model
    psaiops.common.model.free_memory(model=_MODEL)
    # reset the pointers
    _MODEL = None
    _TOKENIZER = None
    # load the selected model; best-effort only — the GPU slot may expire,
    # in which case fetch_model itself surfaces a warning to the user
    try:
        fetch_model()
    except Exception:
        pass
    # return the reordered list of models even if the loading failed
    return current_selection()

# TOKENS #######################################################################

def compute_tokens(
    prompt_str: str,
    export_str: str,
) -> object:
    """Tokenize the prompt for display; runs on CPU (no GPU wrapper needed)."""
    # do not download the model without the GPU wrapper
    __tokenizer = fetch_tokenizer()
    # fill all the arguments that cannot be pickled
    return _ux.update_tokens_state(prompt_str=prompt_str, export_str=export_str, tokenizer_obj=__tokenizer)

# INDICES ######################################################################

def compute_indices(
    prompt_str: str,
    export_str: str,
) -> object:
    """Convert the prompt to token indices; runs on CPU (no GPU wrapper needed)."""
    # do not download the model without the GPU wrapper
    __tokenizer = fetch_tokenizer()
    # fill all the arguments that cannot be pickled
    return _ux.update_indices_state(prompt_str=prompt_str, export_str=export_str, tokenizer_obj=__tokenizer)

# LOGITS #######################################################################

@spaces.GPU(duration=30)
def compute_logits(
    indices_arr: object,
    export_str: str,
) -> object:
    """Run the model on `indices_arr` and return the logits, or `None` on failure."""
    __logits = None
    # load the model inside the GPU wrapper (not before)
    __model = fetch_model()
    # the allocation might expire before the calculations are finished
    try:
        __logits = _ux.update_logits_state(indices_arr=indices_arr, export_str=export_str, model_obj=__model)
    except Exception:
        gradio.Warning(title='Warning', message='Calculations aborted because the GPU allocation expired.', duration=4)
    # tensor or None
    return __logits

# CREATE #######################################################################

demo = _app.create_app(
    current=current_selection,
    switch=switch_model,
    partition=compute_tokens,
    convert=compute_indices,
    compute=compute_logits,
    models=_app.MODELS,
    export='')

# LAUNCH #######################################################################

demo.queue()
# NOTE(review): `theme` and `css` are normally `gradio.Blocks()` constructor
# arguments rather than `launch()` kwargs — confirm `_app.create_app` / the
# pinned gradio version actually accepts them here
demo.launch(theme=gradio.themes.Soft(), css=psaiops.common.style.ALL, share=False, debug=False)