import os
import subprocess
from threading import Thread
from typing import Iterator

from PIL import Image

import gradio as gr
import spaces
import torch
from transformers import (
    AutoProcessor,
    Gemma3ForConditionalGeneration,
    TextIteratorStreamer,
)

# Install flash-attn at startup (prebuilt wheel only; skip the CUDA build).
# Merge os.environ so PATH survives and the shell can still find pip.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
DESCRIPTION = """\
# MamayLM-Gemma-3-12B-IT-v1.0 demo
[🪪 **Model card**](https://huggingface.co/INSAIT-Institute/MamayLM-Gemma-3-12B-IT-v1.0)
"""
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 2048
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

model_id = "INSAIT-Institute/MamayLM-Gemma-3-12B-IT-v1.0"
processor = AutoProcessor.from_pretrained(model_id)
attn_impl = "flash_attention_2" if torch.cuda.is_available() else "eager"
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation=attn_impl,
)
model.eval()
# Stop generation at Gemma 3's <eos> (id 1) and <end_of_turn> (id 106) tokens.
eos_token_ids = [1, 106]
print("Model loaded successfully.")


# Increased duration for the larger model
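@spaces.GPU(duration=120)  # ZeroGPU decorator implied by `import spaces` and the comment above; the 120 s duration is an assumed value, not from the source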
def generate(
    message: dict,
    chat_history: list[list],
    system_message: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.95,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """
    Generates a response from the model based on the user's message and chat history.
    This function is designed to work with Gradio's multimodal ChatInterface.
    """
    conversation = []
    all_images = []

    # Add system message if provided. This guides the model's behavior.
    if system_message:
        conversation.append(
            {"role": "system", "content": [{"type": "text", "text": system_message}]}
        )
    # Process past turns from Gradio's chat_history.
    for user_turn, bot_turn in chat_history:
        # Reconstruct the user's turn, which might include an image.
        user_content = []
        if isinstance(user_turn, tuple):  # User turn with an image
            # Gradio passes image-only turns as a 1-tuple (path,) and
            # image+text turns as a 2-tuple (path, text).
            img_path = user_turn[0]
            txt = user_turn[1] if len(user_turn) > 1 else None
            pil_img = Image.open(img_path).convert("RGB")
            all_images.append(pil_img)
            user_content.append({"type": "image"})
            if txt:
                user_content.append({"type": "text", "text": txt})
        elif user_turn:  # Text-only user turn
            user_content.append({"type": "text", "text": user_turn})
        if user_content:
            conversation.append({"role": "user", "content": user_content})
        # Reconstruct the assistant's turn.
        if bot_turn:
            conversation.append(
                {"role": "assistant", "content": [{"type": "text", "text": bot_turn}]}
            )
    # Process the current user message, which can include new images.
    current_user_content = []
    text = message["text"]
    for img_path in message["files"]:
        pil_img = Image.open(img_path).convert("RGB")
        all_images.append(pil_img)
        current_user_content.append({"type": "image"})
    if text:
        current_user_content.append({"type": "text", "text": text})
    if current_user_content:
        conversation.append({"role": "user", "content": current_user_content})
    # Use the processor to create the prompt and preprocess images.
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    if len(all_images) > 0:
        inputs = processor(text=prompt, images=all_images, return_tensors="pt").to(
            model.device
        )
    else:
        inputs = processor(text=prompt, return_tensors="pt").to(model.device)
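
    # Enforce MAX_INPUT_TOKEN_LENGTH, which is otherwise unused. A sketch:
    # left-truncation is only safe for text-only prompts, since trimming
    # could drop image placeholder tokens and misalign pixel inputs.
    if len(all_images) == 0 and inputs["input_ids"].shape[1] > MAX_INPUT_TOKEN_LENGTH:
        inputs["input_ids"] = inputs["input_ids"][:, -MAX_INPUT_TOKEN_LENGTH:]
        inputs["attention_mask"] = inputs["attention_mask"][:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input as it exceeded {MAX_INPUT_TOKEN_LENGTH} tokens.")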
    # Set up the streamer for text generation.
    streamer = TextIteratorStreamer(
        processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        eos_token_id=eos_token_ids,
        temperature=temperature if temperature > 0 else 0.001,  # temperature must be > 0
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )
    # Run generation in a separate thread so this one can stream the output.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Yield the accumulated text as chunks arrive.
    outputs = []
    for text_chunk in streamer:
        outputs.append(text_chunk)
        yield "".join(outputs)
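    # The streamer is exhausted once generation finishes; reap the worker thread.
    t.join()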


chat_interface = gr.ChatInterface(
    multimodal=True,
    fn=generate,
    additional_inputs=[
        gr.Textbox(
            value="",
            label="System message",
            render=False,
        ),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0,
            maximum=4.0,
            step=0.1,
            value=0.1,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=1.0,  # from https://huggingface.co/google/gemma-3-270m-it/blob/main/generation_config.json
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=25,  # from https://huggingface.co/google/gemma-3-270m-it/blob/main/generation_config.json
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.1,
        ),
    ],
    stop_btn="Stop Generation",
    examples=[
        ["Привіт! Як справи?"],
        [
            "Плюси та мінуси довгострокових стосунків. Маркований список із максимум 3 перевагами та 3 недоліками, стисло."
        ],
        ["Скільки годин потрібно людині, щоб з'їсти гелікоптер?"],
        ["Як відкрити файл JSON у Python?"],
        [
            "Створіть маркований список переваг і недоліків життя в Сан-Франциско. Максимум 2 переваги та 2 недоліки."
        ],
        ["Придумай коротке оповідання з тваринами про цінність дружби."],
        ["Чи можеш ти коротко пояснити, що таке мова програмування Python?"],
        [
            "Напишіть статтю на 100 слів на тему 'Переваги відкритого коду в дослідженнях ШІ'."
        ],
    ],
    cache_examples=False,
)

with gr.Blocks(css="style.css", fill_height=True, theme="soft") as demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()

if __name__ == "__main__":
    demo.queue(max_size=20).launch()