# gemma-RL / app.py — Hugging Face Space demo (author: ghitaben, commit 86dc3aa, verified)
import gradio as gr
import torch
from PIL import Image
from transformers import (
    AutoModelForImageTextToText,
    AutoModelForMultimodalLM,  # NOTE(review): not a documented transformers export; unused after fix — remove once confirmed
    AutoProcessor,
    AutoTokenizer,  # NOTE(review): superseded by AutoProcessor below; kept for compatibility
)
MODEL_NAME = "ghitaben/unsloth_finetune_gemma3"

# A plain tokenizer cannot preprocess images, but generate() calls
# `tokenizer(..., images=...)`. AutoProcessor bundles the tokenizer with the
# image processor and accepts that argument. Kept under the original name
# `tokenizer` so the rest of the script works unchanged.
tokenizer = AutoProcessor.from_pretrained(MODEL_NAME)

# NOTE(review): `AutoModelForMultimodalLM` is not a documented transformers
# class; the documented auto class for image+text generation (Gemma 3
# included) is `AutoModelForImageTextToText`.
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,  # bf16 halves memory; requires bf16-capable hardware
)

# Move to GPU when available and switch to inference mode (disables dropout etc.).
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()
def generate(image, instruction):
    """Run the multimodal model on (image, instruction) and return decoded text.

    Parameters
    ----------
    image : numpy.ndarray | None
        RGB image from the Gradio widget (``type="numpy"``); ``None`` when the
        user presses Generate without uploading one.
    instruction : str
        Free-form question/instruction about the image.

    Returns
    -------
    str
        Only the newly generated tokens, decoded to text — or an
        ``"Error: ..."`` string. The Gradio output box displays whatever is
        returned, so failures are reported in-band rather than raised.
    """
    try:
        # Guard clauses: a readable message beats a raw traceback string.
        if image is None:
            return "Error: please provide an image."
        if not instruction or not instruction.strip():
            return "Error: please provide an instruction."

        # Gradio hands us a numpy array; the processor expects a PIL image.
        image = Image.fromarray(image)

        # Chat template: an image placeholder followed by the user's text.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": instruction},
                ],
            }
        ]
        input_text = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
        )

        # Pass `text=` by keyword: processor __call__ signatures commonly take
        # `images` as the first positional parameter, so a positional string
        # could be misread as the image. A plain tokenizer's first parameter
        # is also named `text`, so the keyword form is safe either way.
        inputs = tokenizer(
            text=input_text,
            images=image,
            add_special_tokens=False,
            return_tensors="pt",
        ).to(device)

        # Remember the prompt length so only the completion is returned.
        input_len = inputs.input_ids.shape[1]

        output_ids = model.generate(
            **inputs,
            max_new_tokens=128,
            do_sample=True,
            temperature=1.0,
            top_p=0.95,
            top_k=64,
            use_cache=True,
        )

        # Slice off the prompt tokens and decode only the new ones.
        generated_ids = output_ids[0][input_len:]
        return tokenizer.decode(generated_ids, skip_special_tokens=True)
    except Exception as e:  # demo boundary: surface any failure in the output box
        return f"Error: {e}"
# Assemble the Gradio UI: image + instruction on one row, button, output box.
with gr.Blocks() as demo:
    gr.Markdown("## 🧠 Multimodal Reasoning Demo")

    with gr.Row():
        image_input = gr.Image(type="numpy", label="Image")
        instruction_box = gr.Textbox(
            label="Instruction",
            placeholder="Ask a question about the image...",
        )

    run_button = gr.Button("Generate")
    answer_box = gr.Textbox(label="Output", lines=8)

    # Wire the button to the inference function.
    run_button.click(generate, inputs=[image_input, instruction_box], outputs=answer_box)

demo.launch()