File size: 2,959 Bytes
95f56e7
049f9d7
 
 
6f9a41b
049f9d7
3edb8f8
049f9d7
f98a297
049f9d7
 
1880559
 
13b12f6
bd52465
13b12f6
049f9d7
1880559
049f9d7
 
826eeb2
9e0d64b
049f9d7
 
 
826eeb2
 
049f9d7
826eeb2
 
049f9d7
826eeb2
049f9d7
826eeb2
 
049f9d7
826eeb2
 
049f9d7
826eeb2
049f9d7
 
 
c3773b4
 
 
 
 
 
 
826eeb2
c3773b4
4284fca
c3773b4
 
 
 
 
 
049f9d7
 
 
 
 
 
 
 
 
 
 
 
 
 
826eeb2
 
049f9d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19bf528
c3773b4
049f9d7
2a45a7b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import gradio
import spaces

import psaiops.common.model
import psaiops.common.style
import psaiops.common.tokenizer
import psaiops.compose.contrast.app as app
import psaiops.compose.contrast.lib as lib

# META #########################################################################

# model repo id used by every component; stored on the app module so the
# UI layer and this entry point agree on a single source of truth
app.MODEL = 'qwen/qwen3.5-9b'

# additional args to use when loading the model
_CONFIG = {}

# frontload the model on the CPU to avoid downloading it from the GPU slot
# (side effect: warms the local weights cache at import time, before any
# GPU-decorated handler runs)
psaiops.common.model.get_model(name=app.MODEL, device='cpu', **_CONFIG)

# but do not instantiate unless necessary
# _MODEL stays None until the first GPU request (see fetch_model below)
_MODEL = None
# the tokenizer is cheap and CPU-only, so it is created eagerly
_TOKENIZER = psaiops.common.tokenizer.get_tokenizer(name=app.MODEL)

# LAZY #########################################################################

def fetch_model() -> object:
    """Return the process-wide model, loading it onto CUDA on first use.

    Lazy on purpose: the download / device move must happen inside the
    GPU-decorated handler, never at import time.
    """
    global _MODEL
    # fast path: already instantiated
    if _MODEL is not None:
        return _MODEL
    # first call: materialize the model on the GPU
    _MODEL = psaiops.common.model.get_model(name=app.MODEL, device='cuda', **_CONFIG)
    return _MODEL

def fetch_tokenizer() -> object:
    """Return the process-wide tokenizer, creating it if needed.

    The tokenizer is built eagerly at import time; the lazy guard here is
    kept for symmetry with fetch_model.
    """
    global _TOKENIZER
    # fast path: the module-level instance already exists
    if _TOKENIZER is not None:
        return _TOKENIZER
    # fallback: recreate it from the configured model name
    _TOKENIZER = psaiops.common.tokenizer.get_tokenizer(name=app.MODEL)
    return _TOKENIZER

# EVENT HANDLERS ###############################################################

def update_table_data(
    positive: str,
    negative: str,
    prompt: str,
    output: str,
) -> object:
    """Delegate table refreshes to the app layer with the tokenizer bound.

    The tokenizer cannot cross the pickling boundary of the GPU wrapper,
    so it is resolved here (CPU side) and injected into the app callback.
    """
    return app.update_table_data(
        positive=positive,
        negative=negative,
        prompt=prompt,
        output=output,
        tokenizer=fetch_tokenizer())

@spaces.GPU(duration=30)
def steer_model_output(
    positive_str: str,
    negative_str: str,
    prompt_str: str,
    positive_rate: float,
    negative_rate: float,
    prompt_rate: float,
    token_num: int,
    topk_num: int,
    topp_num: float,
    layer_idx: int,
) -> str:
    """Run the contrastive steering pipeline inside a timed GPU slot.

    Model and tokenizer are resolved here rather than passed in, because
    they cannot be pickled across the spaces.GPU process boundary; the
    model fetch also triggers the lazy CUDA load on the first request.
    """
    # model first, then tokenizer, mirroring the original evaluation order
    return lib.steer_model_output(
        positive_str=positive_str,
        negative_str=negative_str,
        prompt_str=prompt_str,
        positive_rate=positive_rate,
        negative_rate=negative_rate,
        prompt_rate=prompt_rate,
        token_num=token_num,
        topk_num=topk_num,
        topp_num=topp_num,
        layer_idx=layer_idx,
        device_str='cuda',
        model_obj=fetch_model(),
        tokenizer_obj=fetch_tokenizer())

# MAIN #########################################################################

# wire the two handlers into the gradio UI defined by the app module
demo = app.create_app(compute=steer_model_output, tabulate=update_table_data)
# queueing is required for spaces.GPU handlers to be scheduled
demo.queue()
# NOTE(review): gradio's Blocks.launch() does not document `theme`/`css`
# parameters — those normally belong to the Blocks()/create_app constructor;
# confirm these kwargs are actually honored here and not silently dropped
demo.launch(theme=gradio.themes.Soft(), css=psaiops.common.style.ALL, debug=False)