Yehor's picture
Init
249ef4b
raw
history blame
5.71 kB
import os
from threading import Thread
from typing import Iterator
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import subprocess
subprocess.run(
"pip install flash-attn --no-build-isolation",
env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
shell=True,
)
DESCRIPTION = """\
# INSAIT-Institute/MamayLM-Gemma-3-12B-IT-v1.0
[🪪 **Model card**](https://huggingface.co/INSAIT-Institute/MamayLM-Gemma-3-12B-IT-v1.0)
"""
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 2048
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
model_id = "INSAIT-Institute/MamayLM-Gemma-3-12B-IT-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_id)
attn_impl = "flash_attention_2" if torch.cuda.is_available() else "eager"
model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
attn_implementation=attn_impl,
)
# model.config.sliding_window = 4096
model.eval()
@spaces.GPU(duration=90)
def generate(
message: str,
chat_history: list[tuple[str, str]],
system_message: str = "",
max_new_tokens: int = 1024,
temperature: float = 0.001,
top_p: float = 1.0,
top_k: int = 50,
repetition_penalty: float = 1.0,
) -> Iterator[str]:
conversation = [{"role": "system", "content": system_message}]
for user, assistant in chat_history:
conversation.extend(
[
{"role": "user", "content": user},
{"role": "assistant", "content": assistant},
]
)
conversation.append({"role": "user", "content": message})
input_ids = tokenizer.apply_chat_template(
conversation, add_generation_prompt=True, return_tensors="pt"
)
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
gr.Warning(
f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens."
)
input_ids = input_ids.to(model.device)
streamer = TextIteratorStreamer(
tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
)
generate_kwargs = dict(
{"input_ids": input_ids},
streamer=streamer,
max_new_tokens=max_new_tokens,
do_sample=True,
eos_token_id=[1, 107],
top_p=top_p,
top_k=top_k,
temperature=temperature,
num_beams=1,
repetition_penalty=repetition_penalty,
disable_compile=True, # https://ai.google.dev/gemma/docs/core/huggingface_text_full_finetune#test_model_inference
)
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
outputs = []
for text in streamer:
outputs.append(text)
yield "".join(outputs)
chat_interface = gr.ChatInterface(
fn=generate,
additional_inputs=[
gr.Textbox(
value="",
label="System message",
render=False,
),
gr.Slider(
label="Max new tokens",
minimum=1,
maximum=MAX_MAX_NEW_TOKENS,
step=1,
value=DEFAULT_MAX_NEW_TOKENS,
),
gr.Slider(
label="Temperature",
minimum=0,
maximum=4.0,
step=0.1,
value=0.1, # default from https://huggingface.co/docs/transformers/en/main_classes/text_generation
),
gr.Slider(
label="Top-p (nucleus sampling)",
minimum=0.05,
maximum=1.0,
step=0.05,
value=1, # from https://huggingface.co/google/gemma-3-270m-it/blob/main/generation_config.json
),
gr.Slider(
label="Top-k",
minimum=1,
maximum=1000,
step=1,
value=25, # from https://huggingface.co/google/gemma-3-270m-it/blob/main/generation_config.json
),
gr.Slider(
label="Repetition penalty",
minimum=1.0,
maximum=2.0,
step=0.05,
value=1.1, # default from https://huggingface.co/docs/transformers/en/main_classes/text_generation
),
],
stop_btn=None,
examples=[
["Привіт! Як справи?"],
[
"Плюси та мінуси довгострокових стосунків. Маркований список із максимум 3 перевагами та 3 недоліками, стисло."
],
["Скільки годин потрібно людині, щоб з'їсти гелікоптер?"],
["Як відкрити файл JSON у Python?"],
[
"Створіть маркований список переваг і недоліків життя в Сан-Франциско. Максимум 2 переваги та 2 недоліки."
],
["Придумай коротке оповідання з тваринами про цінність дружби."],
["Чи можеш ти коротко пояснити, що таке мова програмування Python?"],
[
"Напишіть статтю на 100 слів на тему 'Переваги відкритого коду в дослідженнях ШІ'."
],
],
cache_examples=False,
)
with gr.Blocks(css="style.css", fill_height=True, theme="soft") as demo:
gr.Markdown(DESCRIPTION)
gr.DuplicateButton(
value="Duplicate Space for private use", elem_id="duplicate-button"
)
chat_interface.render()
if __name__ == "__main__":
demo.queue(max_size=20).launch()