import os
from threading import Thread
from typing import Iterator
from PIL import Image

import gradio as gr
import spaces
import torch
from transformers import (
    AutoProcessor,
    Gemma3ForConditionalGeneration,
    TextIteratorStreamer,
)
import subprocess

# Install FlashAttention at runtime; FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE skips
# compiling the CUDA kernels locally, the usual workaround on Spaces hardware.
# os.environ is preserved so pip and the CUDA toolchain stay on PATH.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)


DESCRIPTION = """\
# MamayLM-Gemma-3-12B-IT-v1.0 demo

[🪪 **Model card**](https://huggingface.co/INSAIT-Institute/MamayLM-Gemma-3-12B-IT-v1.0)
"""

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 2048
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

model_id = "INSAIT-Institute/MamayLM-Gemma-3-12B-IT-v1.0"

processor = AutoProcessor.from_pretrained(model_id)
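# Prefer FlashAttention-2 kernels when a GPU is available; fall back to eager
# attention otherwise so the demo can still start on CPU.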
attn_impl = "flash_attention_2" if torch.cuda.is_available() else "eager"

model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation=attn_impl,
)
model.eval()

# Stop on both <eos> (id 1) and <end_of_turn> (id 106), Gemma-3's turn delimiter.
eos_token_ids = [1, 106]

print("Model loaded successfully.")


@spaces.GPU(duration=120)  # Increased duration for the larger model
def generate(
    message: dict,
    chat_history: list[list],
    system_message: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.95,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """
    Generates a response from the model based on the user's message and chat history.
    This function is designed to work with Gradio's multimodal ChatInterface.
    """
    conversation = []
    all_images = []

    # Add system message if provided. This guides the model's behavior.
    if system_message:
        conversation.append(
            {"role": "system", "content": [{"type": "text", "text": system_message}]}
        )

    # Process past turns from Gradio's chat_history
    for user_turn, bot_turn in chat_history:
        # Reconstruct the user's turn, which might include an image
        user_content = []
        print("---")
        print(user_turn)
        print(bot_turn)
        print("---")
        if isinstance(user_turn, tuple):  # User turn containing an image
            # Gradio may pass a 1-tuple (just the file path) when an image is
            # sent without accompanying text; treat the text as empty instead
            # of dropping the whole turn.
            if len(user_turn) == 1:
                img_path, txt = user_turn[0], None
            else:
                img_path, txt = user_turn

            pil_img = Image.open(img_path).convert("RGB")
            all_images.append(pil_img)
            user_content.append({"type": "image"})
            if txt:
                user_content.append({"type": "text", "text": txt})
        elif user_turn:  # Text-only user turn
            user_content.append({"type": "text", "text": user_turn})

        if user_content:
            conversation.append({"role": "user", "content": user_content})

        # Reconstruct the assistant's turn
        if bot_turn:
            conversation.append(
                {"role": "assistant", "content": [{"type": "text", "text": bot_turn}]}
            )

    # Process the current user message, which can include new images
    current_user_content = []
    text = message["text"]
    for img_path in message["files"]:
        pil_img = Image.open(img_path).convert("RGB")
        all_images.append(pil_img)
        current_user_content.append({"type": "image"})

    if text:
        current_user_content.append({"type": "text", "text": text})

    if current_user_content:
        conversation.append({"role": "user", "content": current_user_content})

    print("### DEBUG:")
    print(conversation)
    print("####")

    # Use the processor to create the prompt and preprocess images
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

    if len(all_images) > 0:
        inputs = processor(text=prompt, images=all_images, return_tensors="pt").to(
            model.device
        )
    else:
        inputs = processor(text=prompt, return_tensors="pt").to(model.device)
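
    # Minimal guard so MAX_INPUT_TOKEN_LENGTH is actually enforced: trim
    # text-only prompts to their most recent tokens. Trimming is skipped when
    # images are present, since it could cut through image placeholder tokens.
    if len(all_images) == 0 and inputs["input_ids"].shape[1] > MAX_INPUT_TOKEN_LENGTH:
        inputs["input_ids"] = inputs["input_ids"][:, -MAX_INPUT_TOKEN_LENGTH:]
        inputs["attention_mask"] = inputs["attention_mask"][:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Input trimmed to the last {MAX_INPUT_TOKEN_LENGTH} tokens.")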

    # Set up the streamer for text generation
    streamer = TextIteratorStreamer(
        processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
    )

    # `inputs` is dict-like (a BatchEncoding), so dict(inputs, ...) merges the
    # tokenized inputs with the generation settings below.
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        eos_token_id=eos_token_ids,
        temperature=temperature
        if temperature > 0
        else 0.001,  # Temperature must be > 0
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )
    # Run generation in a separate thread
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Yield generated text chunks
    outputs = []
    for text_chunk in streamer:
        outputs.append(text_chunk)
        yield "".join(outputs)


chat_interface = gr.ChatInterface(
    multimodal=True,
    fn=generate,
    additional_inputs=[
        gr.Textbox(
            value="",
            label="System message",
            render=False,
        ),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0,
            maximum=4.0,
            step=0.1,
            value=0.1,  # conservative default for this demo (transformers' own default is 1.0)
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=1.0,  # 1.0 leaves nucleus filtering effectively disabled
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=25,  # demo default
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.1,  # mild penalty; transformers' own default is 1.0
        ),
    ],
    stop_btn="Stop Generation",
    examples=[
        ["Привіт! Як справи?"],
        [
            "Плюси та мінуси довгострокових стосунків. Маркований список із максимум 3 перевагами та 3 недоліками, стисло."
        ],
        ["Скільки годин потрібно людині, щоб з'їсти гелікоптер?"],
        ["Як відкрити файл JSON у Python?"],
        [
            "Створіть маркований список переваг і недоліків життя в Сан-Франциско. Максимум 2 переваги та 2 недоліки."
        ],
        ["Придумай коротке оповідання з тваринами про цінність дружби."],
        ["Чи можеш ти коротко пояснити, що таке мова програмування Python?"],
        [
            "Напишіть статтю на 100 слів на тему 'Переваги відкритого коду в дослідженнях ШІ'."
        ],
    ],
    cache_examples=False,
)

with gr.Blocks(css="style.css", fill_height=True, theme="soft") as demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()

if __name__ == "__main__":
    demo.queue(max_size=20).launch()