import os
import subprocess
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import (
    AutoProcessor,
    Gemma3ForConditionalGeneration,
    TextIteratorStreamer,
)

# Install flash-attn at runtime from a prebuilt wheel (skip the CUDA build).
# Merge into os.environ so pip keeps PATH and CUDA variables visible; passing
# a bare one-entry env dict would drop them.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
DESCRIPTION = """\
# MamayLM-Gemma-3-12B-IT-v1.0 demo
[🪪 **Model card**](https://huggingface.co/INSAIT-Institute/MamayLM-Gemma-3-12B-IT-v1.0)
"""
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 2048
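# Prompt-length cap; can be overridden via the MAX_INPUT_TOKEN_LENGTH env var.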
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
model_id = "INSAIT-Institute/MamayLM-Gemma-3-12B-IT-v1.0"
processor = AutoProcessor.from_pretrained(model_id)
attn_impl = "flash_attention_2" if torch.cuda.is_available() else "eager"
model = Gemma3ForConditionalGeneration.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
attn_implementation=attn_impl,
)
model.eval()
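# Gemma-3 uses two stop tokens: <eos> (id 1) and <end_of_turn> (id 106).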
eos_token_ids = [1, 106]
print("Model loaded successfully.")
@spaces.GPU(duration=120) # Increased duration for the larger model
def generate(
message: dict,
chat_history: list[list],
system_message: str,
max_new_tokens: int = 1024,
temperature: float = 0.6,
top_p: float = 0.95,
top_k: int = 50,
repetition_penalty: float = 1.2,
) -> Iterator[str]:
"""
Generates a response from the model based on the user's message and chat history.
This function is designed to work with Gradio's multimodal ChatInterface.
"""
conversation = []
all_images = []
# Add system message if provided. This guides the model's behavior.
if system_message:
conversation.append(
{"role": "system", "content": [{"type": "text", "text": system_message}]}
)
# Process past turns from Gradio's chat_history
for user_turn, bot_turn in chat_history:
# Reconstruct the user's turn, which might include an image
user_content = []
print("---")
print(user_turn)
print(bot_turn)
print("---")
if isinstance(user_turn, tuple): # User turn with an image
if len(user_turn) == 1: # TODO: IDK now how to fix this bug
continue
img_path, txt = user_turn
pil_img = Image.open(img_path).convert("RGB")
all_images.append(pil_img)
user_content.append({"type": "image"})
if txt:
user_content.append({"type": "text", "text": txt})
elif user_turn: # Text-only user turn
user_content.append({"type": "text", "text": user_turn})
if user_content:
conversation.append({"role": "user", "content": user_content})
# Reconstruct the assistant's turn
if bot_turn:
conversation.append(
{"role": "assistant", "content": [{"type": "text", "text": bot_turn}]}
)
# Process the current user message, which can include new images
current_user_content = []
text = message["text"]
for img_path in message["files"]:
pil_img = Image.open(img_path).convert("RGB")
all_images.append(pil_img)
current_user_content.append({"type": "image"})
if text:
current_user_content.append({"type": "text", "text": text})
if current_user_content:
conversation.append({"role": "user", "content": current_user_content})
print("### DEBUG:")
print(conversation)
print("####")
# Use the processor to create the prompt and preprocess images
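    # (apply_chat_template on a processor defaults to tokenize=False, so
    # `prompt` is a string; the processor call below does the tokenization.)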
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
if len(all_images) > 0:
inputs = processor(text=prompt, images=all_images, return_tensors="pt").to(
model.device
)
else:
inputs = processor(text=prompt, return_tensors="pt").to(model.device)
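    # MAX_INPUT_TOKEN_LENGTH was defined but never enforced; a minimal guard
    # (sketch) that left-trims over-long text-only prompts. Image prompts are
    # left intact so image placeholder tokens stay aligned with pixel values.
    if not all_images and inputs["input_ids"].shape[1] > MAX_INPUT_TOKEN_LENGTH:
        inputs["input_ids"] = inputs["input_ids"][:, -MAX_INPUT_TOKEN_LENGTH:]
        inputs["attention_mask"] = inputs["attention_mask"][:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input as it exceeded {MAX_INPUT_TOKEN_LENGTH} tokens.")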
# Set up the streamer for text generation
streamer = TextIteratorStreamer(
processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
)
generate_kwargs = dict(
inputs,
streamer=streamer,
max_new_tokens=max_new_tokens,
do_sample=True,
eos_token_id=eos_token_ids,
        temperature=temperature if temperature > 0 else 0.001,  # must be > 0
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
)
# Run generation in a separate thread
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
# Yield generated text chunks
outputs = []
for text_chunk in streamer:
outputs.append(text_chunk)
yield "".join(outputs)
chat_interface = gr.ChatInterface(
multimodal=True,
fn=generate,
additional_inputs=[
gr.Textbox(
value="",
label="System message",
render=False,
),
gr.Slider(
label="Max new tokens",
minimum=1,
maximum=MAX_MAX_NEW_TOKENS,
step=1,
value=DEFAULT_MAX_NEW_TOKENS,
),
gr.Slider(
label="Temperature",
minimum=0,
maximum=4.0,
step=0.1,
            value=0.1,  # conservative demo default; transformers' own default is 1.0
),
gr.Slider(
label="Top-p (nucleus sampling)",
minimum=0.05,
maximum=1.0,
step=0.05,
            value=1,  # demo default (top-p filtering effectively disabled)
),
gr.Slider(
label="Top-k",
minimum=1,
maximum=1000,
step=1,
            value=25,  # demo default
),
gr.Slider(
label="Repetition penalty",
minimum=1.0,
maximum=2.0,
step=0.05,
            value=1.1,  # mild anti-repetition; transformers' own default is 1.0
),
],
stop_btn="Stop Generation",
examples=[
["Привіт! Як справи?"],
[
"Плюси та мінуси довгострокових стосунків. Маркований список із максимум 3 перевагами та 3 недоліками, стисло."
],
["Скільки годин потрібно людині, щоб з'їсти гелікоптер?"],
["Як відкрити файл JSON у Python?"],
[
"Створіть маркований список переваг і недоліків життя в Сан-Франциско. Максимум 2 переваги та 2 недоліки."
],
["Придумай коротке оповідання з тваринами про цінність дружби."],
["Чи можеш ти коротко пояснити, що таке мова програмування Python?"],
[
"Напишіть статтю на 100 слів на тему 'Переваги відкритого коду в дослідженнях ШІ'."
],
],
cache_examples=False,
)
with gr.Blocks(css="style.css", fill_height=True, theme="soft") as demo:
gr.Markdown(DESCRIPTION)
chat_interface.render()
if __name__ == "__main__":
demo.queue(max_size=20).launch()