"""Gradio Space entry point: contrastive steering of a causal LM.

Wires the shared `psaiops.compose.contrast` app/lib pair to the HF
ZeroGPU runtime: the model is pre-fetched on CPU at import time (so the
download happens outside the GPU slot) and only moved to CUDA lazily,
inside the `@spaces.GPU`-wrapped handler.
"""

import gradio
import spaces

import psaiops.common.model
import psaiops.common.style
import psaiops.common.tokenizer
import psaiops.compose.contrast.app as app
import psaiops.compose.contrast.lib as lib

# META #########################################################################

app.MODEL = 'qwen/qwen3.5-9b'

# additional args to use when loading the model
_CONFIG = {}

# frontload the model on the CPU to avoid downloading it from the GPU slot
psaiops.common.model.get_model(name=app.MODEL, device='cpu', **_CONFIG)

# but do not instantiate the CUDA copy unless necessary
_MODEL = None
# the tokenizer is cheap and pickles fine, so it is loaded eagerly
_TOKENIZER = psaiops.common.tokenizer.get_tokenizer(name=app.MODEL)

# LAZY #########################################################################

def fetch_model() -> object:
    """Return the process-wide CUDA model, loading it on first use.

    Must only be called from inside a `@spaces.GPU` context: moving the
    weights to 'cuda' outside the GPU slot would fail or waste quota.
    """
    global _MODEL
    # control when the model is downloaded to avoid moving it to the CPU
    if _MODEL is None:
        _MODEL = psaiops.common.model.get_model(name=app.MODEL, device='cuda', **_CONFIG)
    # a single cached model object, never reset once loaded
    return _MODEL

def fetch_tokenizer() -> object:
    """Return the process-wide tokenizer, reloading it if ever reset.

    The guard is not strictly necessary (the tokenizer is created at
    import time), but symmetry with `fetch_model` is everything.
    """
    global _TOKENIZER
    if _TOKENIZER is None:
        _TOKENIZER = psaiops.common.tokenizer.get_tokenizer(name=app.MODEL)
    # a single cached tokenizer object
    return _TOKENIZER

# EVENT HANDLERS ###############################################################

def update_table_data(
    positive: str,
    negative: str,
    prompt: str,
    output: str,
) -> object:
    """Refresh the token table shown in the UI (CPU-only path).

    Runs outside the GPU wrapper, so it must not trigger a model load;
    only the tokenizer is fetched here.
    """
    # the tokenizer alone is safe to touch without the GPU wrapper
    tokenizer = fetch_tokenizer()
    # fill in the arguments that cannot be pickled across the GPU boundary
    return app.update_table_data(
        positive=positive,
        negative=negative,
        prompt=prompt,
        output=output,
        tokenizer=tokenizer)

@spaces.GPU(duration=30)
def steer_model_output(
    positive_str: str,
    negative_str: str,
    prompt_str: str,
    positive_rate: float,
    negative_rate: float,
    prompt_rate: float,
    token_num: int,
    topk_num: int,
    topp_num: float,
    layer_idx: int,
) -> str:
    """Generate steered text on the GPU slot and return it as a string.

    All heavy objects (model, tokenizer) are resolved inside the wrapper
    because `spaces.GPU` pickles the handler's arguments — anything that
    cannot be pickled must be created here, not passed in.
    """
    # load the model and tokenizer inside the GPU wrapper
    model = fetch_model()
    tokenizer = fetch_tokenizer()
    # fill in the arguments that cannot be pickled
    return lib.steer_model_output(
        positive_str=positive_str,
        negative_str=negative_str,
        prompt_str=prompt_str,
        positive_rate=positive_rate,
        negative_rate=negative_rate,
        prompt_rate=prompt_rate,
        token_num=token_num,
        topk_num=topk_num,
        topp_num=topp_num,
        layer_idx=layer_idx,
        device_str='cuda',
        model_obj=model,
        tokenizer_obj=tokenizer)

# MAIN #########################################################################

demo = app.create_app(compute=steer_model_output, tabulate=update_table_data)
demo.queue()
# NOTE(review): in current Gradio, `theme` and `css` are `gradio.Blocks(...)`
# constructor arguments, not `launch()` kwargs — confirm `app.create_app`
# cannot take them directly, otherwise these may be silently ignored here.
demo.launch(theme=gradio.themes.Soft(), css=psaiops.common.style.ALL, debug=False)