File size: 6,165 Bytes
8863499
 
34ec432
8863499
9bfe19a
8863499
 
 
 
 
29c6a08
 
8863499
29c6a08
8863499
 
 
 
 
 
29c6a08
3da8a03
8863499
3da8a03
ff973d2
 
8863499
 
3da8a03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8863499
3da8a03
 
 
 
 
 
 
8863499
e56e240
34ec432
10d0e59
190e12e
34ec432
 
 
10d0e59
34ec432
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10d0e59
34ec432
 
29c6a08
 
 
 
 
 
 
 
 
 
 
 
3da8a03
190e12e
29c6a08
 
190e12e
29c6a08
 
 
190e12e
ff973d2
190e12e
ff973d2
190e12e
f200aaf
34ec432
 
 
 
 
190e12e
 
 
f200aaf
34ec432
190e12e
 
10d0e59
190e12e
10d0e59
190e12e
10d0e59
190e12e
 
 
34ec432
 
 
 
 
 
 
 
 
 
 
 
29c6a08
34ec432
 
 
 
 
 
 
 
190e12e
3da8a03
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
import os  # NOTE(review): unused in the visible code — candidate for removal
import spaces

# Set the device for model inference
# This will automatically use the GPU if one is available and configured
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face model repo id used by both the tokenizer and the model below.
MODEL = "pszemraj/medgemma-27b-text-heretic_med"

# Load the tokenizer
# local_files_only=True: fail fast if the model is not already cached locally
# instead of attempting a network download.
tokenizer = AutoTokenizer.from_pretrained(MODEL, local_files_only=True)

# Load the model and ensure it's on the correct device
# We've added `load_in_8bit=True` to reduce the memory footprint.
# We've also added `offload_folder` to explicitly enable disk offloading
# for the model when it can't fit into VRAM or system RAM.
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    dtype=torch.bfloat16,
    device_map="auto",  # let accelerate place layers across GPU/CPU/disk
    # load_in_8bit=True,
    offload_folder="./offload_dir",
    local_files_only=True,
)

# Dead code: this whole branch is disabled by `if False:` — kept here as the
# previous chat-style UI. If re-enabling, see the NOTE(review) comments below.
if False:
    def chat_interface(message, history):
        """
        Main chat function to interact with the model.
        """
        # Copy so we don't mutate Gradio's history list in place.
        chat_history = list(history)
        
        # Add the current user message to the chat history
        chat_history.append({"role": "user", "content": message})
        
        # Apply the tokenizer's chat template
        prompt = tokenizer.apply_chat_template(chat_history, tokenize=False, add_generation_prompt=True)
        
        # Generate the response
        input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
        outputs = model.generate(
            input_ids.to(device),
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95
        )
        
        # Decode the response and extract the model's part
        # NOTE(review): outputs[0] contains the prompt too, so for a multi-turn
        # history the prompt itself holds several "<end_of_turn>" markers and
        # index [1] picks an *earlier* turn, not the new reply — taking the
        # last segment (or slicing off input_ids first) looks correct; confirm
        # against the model's chat template before re-enabling.
        response = tokenizer.decode(outputs[0])
        response = response.split("<end_of_turn>")[1].strip()
        
        return response
    
    # Create the Gradio interface
    # NOTE(review): title says "4B" but MODEL above is a 27b checkpoint.
    gr.ChatInterface(
        fn=chat_interface,
        type="messages",
        title="MedGemma-4B-IT Medical Assistant",
        description="A fine-tuned model for medical-related questions."
    ).launch(share=True)

@spaces.GPU(duration=60)
def extend(text, max_new_tokens, chunk_size, progress=gr.Progress()):
    """Continue `text` with the model, yielding incremental updates.

    Generates up to `max_new_tokens` tokens in chunks of `chunk_size` so the
    Gradio event can be cancelled between chunks. The KV cache is reused
    across chunks within one call (not across calls).

    Yields:
        (document_text, status_message) tuples for the two output components.
    """
    PREFIX = "<bos>\n" # Model just repeats the last token without this
    progress(0, desc="Tokenizing...")
    token_ids = tokenizer.encode(PREFIX + text, add_special_tokens=False, return_tensors="pt")
    past_key_values = DynamicCache(config=model.config)
    done_tokens = 0
    try:
        # Generate in loop to allow it to be interrupted
        while done_tokens < max_new_tokens:
            progress(done_tokens / max_new_tokens, desc="Generating...")
            chunk_max_new_tokens = min(chunk_size, max_new_tokens - done_tokens)

            new_ids = model.generate(
                token_ids.to(device),
                max_new_tokens=chunk_max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_k=50,
                top_p=0.95,
                past_key_values=past_key_values, # continue from where we left off
            )

            chunk_new_tokens = new_ids.shape[1] - token_ids.shape[1]
            done_tokens += chunk_new_tokens
            token_ids = new_ids

            (unwrapped_new_ids,) = new_ids
            new_text = tokenizer.decode(unwrapped_new_ids).removeprefix(PREFIX)
            # Sanity check: streamed output must be a strict extension of the
            # current document; otherwise report rather than clobber it.
            if not new_text.startswith(text):
                yield text, "New text somehow deleted existing text!\n\n" + new_text
                return
            yield new_text, f"New tokens generated: {done_tokens}/{max_new_tokens}"

            # Model decided to stop early (e.g. emitted EOS). Bug fix: we now
            # yield the final partial chunk above *before* breaking, instead of
            # silently discarding it as the previous version did.
            if chunk_new_tokens < chunk_max_new_tokens:
                break
    except Exception as e:
        # Surface the error in the status pane; the document keeps its last
        # successfully yielded state.
        yield text, f"# ERROR: {e!r}"

# Feature flag for the Debug tab: when True, the tab executes arbitrary code.
DEBUG_ENABLED = False

if DEBUG_ENABLED:
    # SECURITY NOTE(review): this runs `exec` on text typed into the web UI —
    # full remote code execution for anyone who can reach the app (including
    # anyone with the link when launched with share=True). Keep DEBUG_ENABLED
    # False in any deployment; consider removing this branch entirely.
    def debug(cmd):
        """Run `result.append(...)` to display values."""
        # `result` is visible to the exec'd code via locals(); appending to it
        # mutates this same list, which is then repr'd for display.
        result = []
        exec(cmd, globals(), locals())
        return repr(result)
else:
    # Safe stand-in with the same interface: just echoes the input's repr.
    def debug(x):
        """Debug print the input."""
        return repr(x)

# Build the Gradio UI: an "Extend" tab wired to `extend`, and a "Debug" tab
# wired to whichever `debug` implementation the flag above selected.
with gr.Blocks() as demo:
    gr.Markdown("# Medical Text Generation")
    gr.Markdown(f"Model in use: {MODEL}")
    with gr.Tab("Extend"):
        gr.Markdown("Enter some medical text, and press Generate to continue it.")
        gr.Markdown("To allow interrupting the generation, it occurs in chunks, remembering the KV cache (only during the generation, not currently across executions).")
        gr.Markdown("Raising the chunk size will increase latency, but might make it go faster by reducing overhead.")
        # Editable document that `extend` both reads from and streams back into.
        document = gr.Code(
            language="markdown",
            interactive=True,
            wrap_lines=True,
        )
        max_new_tokens = gr.Slider(label="Max New Tokens", minimum=10, maximum=8192, step=10, value=128)
        chunk_size = gr.Slider(label="Streaming Chunk Size", minimum=1, maximum=100, step=1, value=5)
        with gr.Row():
            generate_button = gr.Button("Generate")
            abort_button = gr.Button("Abort")
        # Keep a handle on the click event so the Abort button can cancel it.
        generate_event = generate_button.click(
            fn=extend,
            inputs=[
                document,
                max_new_tokens,
                chunk_size,
            ],
            outputs=[
                document,
                # Second output of `extend`: the status/error pane.
                gr.Code(
                    label="Status",
                    language="markdown",
                    interactive=False,
                    wrap_lines=True,
                ),
            ],
            show_progress="minimal",
        )
        # Abort runs no function; it only cancels the in-flight generation.
        abort_button.click(
            fn=None,
            inputs=None,
            outputs=None,
            cancels=[generate_event],
        )
    with gr.Tab("Debug"):
        # The label doubles as usage help by showing the active debug docstring.
        gr.Interface(
            fn=debug,
            inputs=[gr.Code(
                label=debug.__doc__,
                language="python",
                interactive=True,
                wrap_lines=True,
            )],
            outputs=[gr.Code(
                language="python",
                wrap_lines=True,
            )],
        )
demo.launch()