apehex
Disable the debugging mode.
2a45a7b
import gradio
import spaces
import psaiops.common.model
import psaiops.common.style
import psaiops.common.tokenizer
import psaiops.compose.contrast.app as app
import psaiops.compose.contrast.lib as lib
# META #########################################################################
# checkpoint identifier, stored on the app module so that it is shared with it
app.MODEL = 'qwen/qwen3.5-9b'
# extra keyword arguments forwarded to every model loading call
_CONFIG = dict()
# fetch the weights once on the CPU at startup, so that the GPU slot does not
# have to download them (presumably get_model caches the download — confirm)
psaiops.common.model.get_model(device='cpu', name=app.MODEL, **_CONFIG)
# the CUDA instance is only created on demand, by fetch_model
_MODEL = None
# the tokenizer is loaded right away, outside of any GPU wrapper
_TOKENIZER = psaiops.common.tokenizer.get_tokenizer(name=app.MODEL)
# LAZY #########################################################################
def fetch_model() -> object:
    """Return the CUDA model instance, creating it on the first call.

    The instance is memoized in the module global ``_MODEL``. This is meant
    to be called from inside the ``spaces.GPU`` wrapper, so that the model is
    only moved to ``'cuda'`` when a GPU slot has been granted.

    Returns:
        Whatever ``psaiops.common.model.get_model`` returns for ``app.MODEL``.
    """
    global _MODEL
    # already instantiated: reuse it
    if _MODEL is not None:
        return _MODEL
    _MODEL = psaiops.common.model.get_model(name=app.MODEL, device='cuda', **_CONFIG)
    return _MODEL
def fetch_tokenizer() -> object:
    """Return the tokenizer for ``app.MODEL``, creating it if it is missing.

    ``_TOKENIZER`` is already populated at import time, so the lazy branch is
    only a fallback. It mirrors ``fetch_model`` for symmetry.

    Returns:
        Whatever ``psaiops.common.tokenizer.get_tokenizer`` returns.
    """
    global _TOKENIZER
    # already available: reuse it
    if _TOKENIZER is not None:
        return _TOKENIZER
    _TOKENIZER = psaiops.common.tokenizer.get_tokenizer(name=app.MODEL)
    return _TOKENIZER
# EVENT HANDLERS ###############################################################
def update_table_data(positive: str, negative: str, prompt: str, output: str) -> object:
    """Refresh the table data through ``app.update_table_data``.

    This only needs the tokenizer. The model is deliberately not fetched here,
    because this handler is not wrapped by ``spaces.GPU``.

    Args:
        positive: the positive text, passed through unchanged.
        negative: the negative text, passed through unchanged.
        prompt: the prompt text, passed through unchanged.
        output: the output text, passed through unchanged.

    Returns:
        Whatever ``app.update_table_data`` returns.
    """
    # the tokenizer is bound here, since it cannot be pickled through the UI layer
    return app.update_table_data(
        positive=positive, negative=negative,
        prompt=prompt, output=output,
        tokenizer=fetch_tokenizer())
@spaces.GPU(duration=30)
def steer_model_output(
    positive_str: str,
    negative_str: str,
    prompt_str: str,
    positive_rate: float,
    negative_rate: float,
    prompt_rate: float,
    token_num: int,
    topk_num: int,
    topp_num: float,
    layer_idx: int,
) -> str:
    """Generate steered text on the GPU through ``lib.steer_model_output``.

    ``spaces.GPU(duration=30)`` wraps this handler. The model and tokenizer
    are resolved inside the wrapper rather than passed in by the caller,
    because those objects cannot be pickled.

    Returns:
        The string produced by ``lib.steer_model_output``.
    """
    # resolved inside the GPU slot: the model is moved to 'cuda' on first use
    model_obj = fetch_model()
    tokenizer_obj = fetch_tokenizer()
    return lib.steer_model_output(
        # texts and their respective rates
        prompt_str=prompt_str,
        prompt_rate=prompt_rate,
        positive_str=positive_str,
        positive_rate=positive_rate,
        negative_str=negative_str,
        negative_rate=negative_rate,
        # generation / sampling settings
        token_num=token_num,
        topk_num=topk_num,
        topp_num=topp_num,
        layer_idx=layer_idx,
        # runtime objects that cannot cross the pickling boundary
        device_str='cuda',
        model_obj=model_obj,
        tokenizer_obj=tokenizer_obj)
# MAIN #########################################################################
# build the UI: GPU-backed generation handler, plus a CPU-only table handler
demo = app.create_app(tabulate=update_table_data, compute=steer_model_output)
# enable the request queue before launching
demo.queue()
# serve the app with the soft theme and the shared stylesheet, without debug mode
demo.launch(
    css=psaiops.common.style.ALL,
    theme=gradio.themes.Soft(),
    debug=False)