import gradio as gr
import cv2
import torch
from PIL import Image
from pathlib import Path
from threading import Thread
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
import spaces


# Load both Gemma 3 checkpoints once at startup so the model dropdown can
# switch between them without reloading weights.
model_12b_name = "google/gemma-3-12b-it"
model_4b_name = "google/gemma-3-4b-it"

model_12b = Gemma3ForConditionalGeneration.from_pretrained(
    model_12b_name,
    device_map="auto",
    torch_dtype=torch.bfloat16
).eval()
processor_12b = AutoProcessor.from_pretrained(model_12b_name)

model_4b = Gemma3ForConditionalGeneration.from_pretrained(
    model_4b_name,
    device_map="auto",
    torch_dtype=torch.bfloat16
).eval()
processor_4b = AutoProcessor.from_pretrained(model_4b_name)


def extract_video_frames(video_path, num_frames=8):
    """Sample up to `num_frames` frames, evenly spaced across a video, as PIL images."""
    cap = cv2.VideoCapture(video_path)
    frames = []
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    step = max(total_frames // num_frames, 1)

    for i in range(num_frames):
        cap.set(cv2.CAP_PROP_POS_FRAMES, i * step)
        ret, frame = cap.read()
        if ret:
            # OpenCV decodes to BGR; convert to RGB before handing off to PIL.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame))
    cap.release()
    return frames


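# Build the multimodal content list for a single user turn: text chunks,
# images placed inline via `<image>` placeholders, and video files expanded
# into sampled frames.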
def format_message(content, files):
    message_content = []

    if content:
        parts = content.split('<image>')
        for i, part in enumerate(parts):
            if part.strip():
                message_content.append({"type": "text", "text": part.strip()})
            if i < len(parts) - 1 and files:
                file = files.pop(0)
                file_path = file if isinstance(file, str) else file.name
                message_content.append({"type": "image", "image": Image.open(file_path)})

    # Any files not consumed by placeholders are appended after the text.
    for file in files:
        file_path = file if isinstance(file, str) else file.name
        suffix = Path(file_path).suffix.lower()
        if suffix in ['.jpg', '.jpeg', '.png']:
            message_content.append({"type": "image", "image": Image.open(file_path)})
        elif suffix in ['.mp4', '.mov']:
            for frame in extract_video_frames(file_path):
                message_content.append({"type": "image", "image": frame})
    return message_content


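# Convert Gradio's message-style chat history into the schema expected by
# `apply_chat_template`, merging consecutive user entries (text and files
# arrive as separate history items) into a single user message.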
def format_conversation_history(chat_history):
    messages = []
    current_user_content = []
    for item in chat_history:
        role = item["role"]
        content = item["content"]
        if role == "user":
            if isinstance(content, str):
                current_user_content.append({"type": "text", "text": content})
            elif isinstance(content, list):
                current_user_content.extend(content)
            else:
                current_user_content.append({"type": "text", "text": str(content)})
        elif role == "assistant":
            if current_user_content:
                messages.append({"role": "user", "content": current_user_content})
                current_user_content = []
            messages.append({"role": "assistant", "content": [{"type": "text", "text": str(content)}]})
    if current_user_content:
        messages.append({"role": "user", "content": current_user_content})
    return messages


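# `@spaces.GPU` requests ZeroGPU hardware on Hugging Face Spaces for up to
# `duration` seconds per call; generation streams token by token to the UI.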
@spaces.GPU(duration=120)
def generate_response(input_data, chat_history, model_choice, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
    # MultimodalTextbox delivers a dict with "text" and "files"; plain text
    # arrives as a string.
    if isinstance(input_data, dict) and "text" in input_data:
        text = input_data["text"]
        files = input_data.get("files", [])
    else:
        text = str(input_data)
        files = []

    new_message_content = format_message(text, files)
    new_message = {"role": "user", "content": new_message_content}
    system_message = [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}] if system_prompt else []
    processed_history = format_conversation_history(chat_history)
    messages = system_message + processed_history
    # If the history already ends with an open user turn, merge the new
    # content into it rather than starting a second consecutive user message.
    if messages and messages[-1]["role"] == "user":
        messages[-1]["content"].extend(new_message["content"])
    else:
        messages.append(new_message)

    if model_choice == "Gemma 3 12B":
        model = model_12b
        processor = processor_12b
    else:
        model = model_4b
        processor = processor_4b
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True
    ).to(model.device, dtype=torch.bfloat16)  # cast pixel_values to the model dtype

    # Run generation on a background thread and stream decoded text chunks.
    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    outputs = []
    for new_text in streamer:
        outputs.append(new_text)
        yield "".join(outputs)


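# Gradio UI: a multimodal ChatInterface with the model choice and sampling
# parameters exposed as additional inputs.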
demo = gr.ChatInterface(
    fn=generate_response,
    additional_inputs=[
        gr.Dropdown(
            label="Model",
            choices=["Gemma 3 12B", "Gemma 3 4B"],
            value="Gemma 3 12B"
        ),
        gr.Slider(label="Max new tokens", minimum=100, maximum=2000, step=1, value=512),
        gr.Textbox(
            label="System Prompt",
            value="You are a friendly chatbot.",
            lines=4,
            placeholder="Change system prompt"
        ),
        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7),
        gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
        gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=50),
        gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.0),
    ],
    examples=[
        [{"text": "Explain this image", "files": ["examples/image1.jpg"]}],
    ],
    cache_examples=False,
    type="messages",
    description="""
# Gemma 3
Pick the 12B or 4B model, upload images or videos, and adjust the settings below to customize your experience.
""",
    fill_height=True,
    textbox=gr.MultimodalTextbox(
        label="Query Input",
        file_types=["image", "video"],
        file_count="multiple",
        placeholder="Type your message or upload media"
    ),
    stop_btn="Stop Generation",
    multimodal=True,
    theme=gr.themes.Soft(),
)


if __name__ == "__main__":
    demo.launch()