File size: 4,969 Bytes
3e8a402
2699bf3
3e8a402
899575c
3e8a402
cbb58dd
3e8a402
4a0d54e
 
2431602
3e8a402
 
 
1ce5a7b
2699bf3
 
0958d5d
1ce5a7b
0958d5d
 
2699bf3
3e8a402
 
165f637
0958d5d
 
 
 
 
 
 
 
 
3e8a402
 
 
165f637
 
3e8a402
165f637
cfe9e70
2699bf3
cfe9e70
 
8513e01
cfe9e70
 
 
165f637
3e8a402
165f637
 
3e8a402
165f637
0958d5d
cfe9e70
165f637
3e8a402
0958d5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c1767d1
0958d5d
 
 
2431602
3e8a402
2431602
 
4a0d54e
3e8a402
 
165f637
3e8a402
4a0d54e
2431602
4a0d54e
2431602
 
 
 
 
 
4a0d54e
3e8a402
 
165f637
3e8a402
4a0d54e
2431602
4a0d54e
2431602
 
 
 
75344c6
2431602
 
4a0d54e
3e8a402
75344c6
 
165f637
75344c6
 
4a0d54e
75344c6
4a0d54e
75344c6
 
 
 
 
3e8a402
4a0d54e
 
 
 
 
 
 
 
 
 
3e8a402
4a0d54e
3e8a402
 
4a0d54e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import gradio
import huggingface_hub
import spaces

import psaiops.common.model
import psaiops.common.style
import psaiops.common.tokenizer

import psaiops.score.human.ux as _ux
import psaiops.score.human.app as _app

# META #########################################################################

# additional args used when loading the model
_MODEL_CFG = {}
# snapshot_download kwargs: skip the non-PyTorch weight formats to save disk/bandwidth
_LOAD_CFG = {'repo_type': 'model', 'ignore_patterns': ['*.onnx', '*.tflite', '*.msgpack'],}
# shared, mutable list of model repo ids; index 0 is the active selection by convention
_app.MODELS = ['qwen/qwen3.5-9b', 'qwen/qwen3.5-27b']

# frontload the models on the CPU to avoid downloading them from the GPU slot
for __m in _app.MODELS:
    huggingface_hub.snapshot_download(repo_id=__m, **_LOAD_CFG)

# but do not instantiate unless necessary
# NOTE: the tokenizer is cheap, so it IS created eagerly; the model stays None
# until fetch_model() is called inside a GPU-wrapped handler
_MODEL = None
_TOKENIZER = psaiops.common.tokenizer.get_tokenizer(name=_app.MODELS[0])

# CURRENT ######################################################################

def current_selection() -> dict:
    """Return a Gradio update that selects the currently active model.

    By convention, index 0 of the shared model list holds the active
    selection, so it becomes the dropdown's value.
    """
    __active = _app.MODELS[0]
    return gradio.update(value=__active, choices=_app.MODELS)

def switch_selection(name: str) -> None:
    """Move ``name`` to the front of the shared model list.

    The remaining entries keep their original relative order. The previous
    implementation rebuilt the tail from a ``set``, whose iteration order is
    nondeterministic, so the dropdown choices could reshuffle on every switch.

    Args:
        name: repo id of the model to make active (placed at index 0).
    """
    _app.MODELS = [name] + [__m for __m in _app.MODELS if __m != name]

# LAZY #########################################################################

def fetch_model() -> object:
    """Lazily instantiate the active model on the GPU and return it.

    Returns:
        The cached model object, or ``None`` if instantiation failed
        (e.g. the GPU time slot expired before loading finished).
    """
    global _MODEL
    # already instantiated: return the cached object untouched
    if _MODEL is not None:
        return _MODEL
    # index 0 of the shared list always holds the active selection
    _MODEL = psaiops.common.model.get_model(name=_app.MODELS[0], device='cuda', **_MODEL_CFG)
    # give some feedback: a real model object exposes `name_or_path`
    if hasattr(_MODEL, 'name_or_path'):
        gradio.Info(title='Info', message='Successfully switched to `{}`.'.format(getattr(_MODEL, 'name_or_path', 'None')), duration=2)
    else:
        gradio.Warning(title='Warning', message='The GPU time slot expired before the model could be loaded.', duration=4)
    # model object or `None`
    return _MODEL

def fetch_tokenizer() -> object:
    """Return the shared tokenizer, recreating it on demand.

    The tokenizer is built eagerly at import time, so the lazy branch only
    triggers after a model switch resets the pointer; it mirrors the
    fetch_model() pattern for symmetry.
    """
    global _TOKENIZER
    # fast path: the tokenizer already exists
    if _TOKENIZER is not None:
        return _TOKENIZER
    # rebuild it for the currently active model (index 0)
    _TOKENIZER = psaiops.common.tokenizer.get_tokenizer(name=_app.MODELS[0])
    return _TOKENIZER

# SWITCH #######################################################################

@spaces.GPU(duration=30)
def switch_model(
    model_str: str
) -> dict:
    """Make ``model_str`` the active model and return the dropdown update.

    Frees the previous model's memory, resets the lazy pointers, and
    best-effort loads the new model inside the GPU wrapper.

    Args:
        model_str: repo id of the model selected in the UI.

    Returns:
        A Gradio update reflecting the (possibly reordered) model list,
        even when loading the new model failed.
    """
    global _MODEL, _TOKENIZER
    # end early if the selection isn't changed
    if model_str == _app.MODELS[0]:
        return current_selection()
    # reorder the model list so that the selected model is at index 0
    switch_selection(name=model_str)
    # free the memory allocated to the previous model
    psaiops.common.model.free_memory(model=_MODEL)
    # reset the pointers
    _MODEL = None
    _TOKENIZER = None
    # load the selected model; best-effort on purpose, but a bare `except:`
    # would also swallow KeyboardInterrupt/SystemExit — narrow to Exception
    try:
        fetch_model()
    except Exception:
        # deliberately swallowed: fetch_model already surfaces a UI warning
        pass
    # return the reordered list of models even if the loading failed
    return current_selection()

# TOKENS #######################################################################

def compute_tokens(
    prompt_str: str,
    export_str: str,
) -> object:
    """Tokenize the prompt and return the UX token state.

    Only needs the tokenizer, so it runs outside any GPU wrapper — the
    model is never downloaded here. The tokenizer is injected because it
    cannot be pickled across the Spaces boundary.
    """
    return _ux.update_tokens_state(
        prompt_str=prompt_str,
        export_str=export_str,
        tokenizer_obj=fetch_tokenizer())

# INDICES ######################################################################

def compute_indices(
    prompt_str: str,
    export_str: str,
) -> object:
    """Convert the prompt to token indices and return the UX index state.

    Runs tokenizer-only work outside the GPU wrapper; the tokenizer is
    injected here because it cannot be pickled across the Spaces boundary.
    """
    return _ux.update_indices_state(
        prompt_str=prompt_str,
        export_str=export_str,
        tokenizer_obj=fetch_tokenizer())

# LOGITS #######################################################################

@spaces.GPU(duration=30)
def compute_logits(
    indices_arr: object,
    export_str: str,
) -> object:
    """Run the model on the token indices and return the logits state.

    Args:
        indices_arr: token indices produced by compute_indices.
        export_str: export format forwarded to the UX layer.

    Returns:
        The logits tensor, or ``None`` if the GPU allocation expired
        mid-computation.
    """
    __logits = None
    # load the model inside the GPU wrapper (not before)
    __model = fetch_model()
    # the allocation might expire before the calculations are finished;
    # a bare `except:` would also trap KeyboardInterrupt/SystemExit, so
    # narrow the handler to Exception
    try:
        __logits = _ux.update_logits_state(
            indices_arr=indices_arr,
            export_str=export_str,
            model_obj=__model)
    except Exception:
        gradio.Warning(title='Warning', message='Calculations aborted because the GPU allocation expired.', duration=4)
    # tensor or None
    return __logits

# CREATE #######################################################################

# wire the handlers into the UI factory; `models` seeds the dropdown choices
# and `export` is the initial export-format value
demo = _app.create_app(
    current=current_selection,
    switch=switch_model,
    partition=compute_tokens,
    convert=compute_indices,
    compute=compute_logits,
    models=_app.MODELS,
    export='')

# LAUNCH #######################################################################

# enable queuing so GPU-wrapped handlers are serialized, then start the server
demo.queue()
demo.launch(theme=gradio.themes.Soft(), css=psaiops.common.style.ALL, share=False, debug=False)