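"""Gradio demo for MamayLM-Gemma-3-12B-IT-v1.0: multimodal chat with streamed output."""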
import os
import subprocess
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import (
    AutoProcessor,
    Gemma3ForConditionalGeneration,
    TextIteratorStreamer,
)

# Install FlashAttention 2 at startup, skipping the CUDA-kernel compile at
# install time (the usual pattern on ZeroGPU Spaces). os.environ is merged in
# so pip still sees PATH and the rest of the environment.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
DESCRIPTION = """\
# MamayLM-Gemma-3-12B-IT-v1.0 demo
[🪪 **Model card**](https://huggingface.co/INSAIT-Institute/MamayLM-Gemma-3-12B-IT-v1.0)
"""
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 2048
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
model_id = "INSAIT-Institute/MamayLM-Gemma-3-12B-IT-v1.0"
processor = AutoProcessor.from_pretrained(model_id)
attn_impl = "flash_attention_2" if torch.cuda.is_available() else "eager"
model = Gemma3ForConditionalGeneration.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
attn_implementation=attn_impl,
)
model.eval()
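# Stop on both <eos> (id 1) and <end_of_turn> (id 106), the stop tokens listed
# in Gemma 3's generation_config.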
eos_token_ids = [1, 106]
print("Model loaded successfully.")
@spaces.GPU(duration=120) # Increased duration for the larger model
def generate(
message: dict,
chat_history: list[list],
system_message: str,
max_new_tokens: int = 1024,
temperature: float = 0.6,
top_p: float = 0.95,
top_k: int = 50,
repetition_penalty: float = 1.2,
) -> Iterator[str]:
"""
Generates a response from the model based on the user's message and chat history.
This function is designed to work with Gradio's multimodal ChatInterface.
"""
conversation = []
all_images = []
# Add system message if provided. This guides the model's behavior.
if system_message:
conversation.append(
{"role": "system", "content": [{"type": "text", "text": system_message}]}
)
# Process past turns from Gradio's chat_history
for user_turn, bot_turn in chat_history:
# Reconstruct the user's turn, which might include an image
user_content = []
        if isinstance(user_turn, tuple):  # User turn that includes an image
            # Gradio stores an image-only turn as (file_path,) and an image
            # with a caption as (file_path, caption); handle both arities.
            img_path = user_turn[0]
            txt = user_turn[1] if len(user_turn) > 1 else None
            pil_img = Image.open(img_path).convert("RGB")
            all_images.append(pil_img)
            user_content.append({"type": "image"})
            if txt:
                user_content.append({"type": "text", "text": txt})
elif user_turn: # Text-only user turn
user_content.append({"type": "text", "text": user_turn})
if user_content:
conversation.append({"role": "user", "content": user_content})
# Reconstruct the assistant's turn
if bot_turn:
conversation.append(
{"role": "assistant", "content": [{"type": "text", "text": bot_turn}]}
)
# Process the current user message, which can include new images
current_user_content = []
text = message["text"]
for img_path in message["files"]:
pil_img = Image.open(img_path).convert("RGB")
all_images.append(pil_img)
current_user_content.append({"type": "image"})
if text:
current_user_content.append({"type": "text", "text": text})
if current_user_content:
conversation.append({"role": "user", "content": current_user_content})
# Use the processor to create the prompt and preprocess images
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    if all_images:
        inputs = processor(text=prompt, images=all_images, return_tensors="pt").to(
            model.device
        )
    else:
        inputs = processor(text=prompt, return_tensors="pt").to(model.device)
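    # Guard against over-long prompts; without this, MAX_INPUT_TOKEN_LENGTH is
    # never enforced. Left-trimming is only safe for text-only prompts, since
    # cutting image placeholder tokens would desync input_ids from pixel_values.
    if not all_images and inputs["input_ids"].shape[1] > MAX_INPUT_TOKEN_LENGTH:
        inputs["input_ids"] = inputs["input_ids"][:, -MAX_INPUT_TOKEN_LENGTH:]
        inputs["attention_mask"] = inputs["attention_mask"][:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(
            f"Trimmed input from conversation as it was longer than "
            f"{MAX_INPUT_TOKEN_LENGTH} tokens."
        )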
# Set up the streamer for text generation
streamer = TextIteratorStreamer(
processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
)
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        eos_token_id=eos_token_ids,
        temperature=max(temperature, 0.001),  # temperature must be > 0 when sampling
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )
# Run generation in a separate thread
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
# Yield generated text chunks
outputs = []
for text_chunk in streamer:
outputs.append(text_chunk)
yield "".join(outputs)
chat_interface = gr.ChatInterface(
multimodal=True,
fn=generate,
additional_inputs=[
gr.Textbox(
value="",
label="System message",
render=False,
),
gr.Slider(
label="Max new tokens",
minimum=1,
maximum=MAX_MAX_NEW_TOKENS,
step=1,
value=DEFAULT_MAX_NEW_TOKENS,
),
gr.Slider(
label="Temperature",
minimum=0,
maximum=4.0,
step=0.1,
            value=0.1,  # conservative demo default (transformers' own default is 1.0)
),
gr.Slider(
label="Top-p (nucleus sampling)",
minimum=0.05,
maximum=1.0,
step=0.05,
value=1, # from https://huggingface.co/google/gemma-3-270m-it/blob/main/generation_config.json
),
gr.Slider(
label="Top-k",
minimum=1,
maximum=1000,
step=1,
value=25, # from https://huggingface.co/google/gemma-3-270m-it/blob/main/generation_config.json
),
gr.Slider(
label="Repetition penalty",
minimum=1.0,
maximum=2.0,
step=0.05,
            value=1.1,  # mild penalty (transformers' own default is 1.0)
),
],
stop_btn="Stop Generation",
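    # Example prompts are in Ukrainian, the model's target language. In order:
    # a greeting; pros/cons of long-term relationships as a short bulleted list;
    # how many hours a person needs to eat a helicopter; how to open a JSON file
    # in Python; pros/cons of living in San Francisco; a short animal story
    # about friendship; a brief explanation of Python; and a 100-word article
    # on "The benefits of open source in AI research".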
examples=[
["Привіт! Як справи?"],
[
"Плюси та мінуси довгострокових стосунків. Маркований список із максимум 3 перевагами та 3 недоліками, стисло."
],
["Скільки годин потрібно людині, щоб з'їсти гелікоптер?"],
["Як відкрити файл JSON у Python?"],
[
"Створіть маркований список переваг і недоліків життя в Сан-Франциско. Максимум 2 переваги та 2 недоліки."
],
["Придумай коротке оповідання з тваринами про цінність дружби."],
["Чи можеш ти коротко пояснити, що таке мова програмування Python?"],
[
"Напишіть статтю на 100 слів на тему 'Переваги відкритого коду в дослідженнях ШІ'."
],
],
cache_examples=False,
)
with gr.Blocks(css="style.css", fill_height=True, theme="soft") as demo:
gr.Markdown(DESCRIPTION)
chat_interface.render()
if __name__ == "__main__":
demo.queue(max_size=20).launch()