# Gemma3n / app.py
# rufatronics's picture
# Update app.py
# 72c6b5a verified
import gradio as gr
from transformers import AutoModelForImageTextToText, AutoProcessor
import torch
import os
# --- Model configuration and one-time loading (runs at import time) ---------
model_id = "google/gemma-3n-E2B-it"
# Access token read from the environment; None when HF_TOKEN is unset.
hf_token = os.getenv("HF_TOKEN")
# Single source of truth for placement: the model is loaded onto this device
# and chat_function moves its input tensors to the same one.
device = "cpu"

print("Loading Gemma 3n with Memory Optimizations...")

# 1. bfloat16 stores each weight in 2 bytes instead of float32's 4, roughly
#    halving the RAM needed for the weights.
# 2. low_cpu_mem_usage avoids the 'double loading' peak while the checkpoint
#    is being read.
# 3. device_map is tied to `device` (it used to be "auto"): automatic
#    placement could put the weights on an accelerator while the inputs are
#    still sent to "cpu", which fails at generation time with a device
#    mismatch.
processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    token=hf_token,
    torch_dtype=torch.bfloat16,  # half-precision weights
    low_cpu_mem_usage=True,  # don't use double RAM on load
    device_map=device,
)
def _text_of(content):
    """Return the plain text carried by one history ``content`` value.

    Accepts a plain string, or a list of ``{"type": "text", "text": ...}``
    parts (non-text parts are ignored). Anything else, including ``None``,
    yields an empty string.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        texts = [
            part.get("text")
            for part in content
            if isinstance(part, dict) and part.get("type") == "text"
        ]
        return "\n".join(t for t in texts if isinstance(t, str) and t)
    return ""


def _history_to_messages(history):
    """Convert a Gradio chat history into chat-template messages.

    Two history shapes are accepted:

    * pair style: ``[[user_text, assistant_text], ...]``
    * message style: ``[{"role": "user" | "assistant", "content": ...}, ...]``

    Empty or missing turns are skipped. Assistant turns are emitted with the
    role ``"model"``, as the original pair-based code did.
    """
    msgs = []
    for turn in history or []:
        if isinstance(turn, dict):
            # Message style: one role per entry; anything that is not the
            # user is treated as the model's turn.
            role = "user" if turn.get("role") == "user" else "model"
            entries = [(role, turn.get("content"))]
        else:
            # Pair style: one (user, assistant) tuple/list per exchange.
            user_msg, assistant_msg = turn
            entries = [("user", user_msg), ("model", assistant_msg)]
        for role, content in entries:
            text = _text_of(content)
            if text:
                msgs.append(
                    {"role": role, "content": [{"type": "text", "text": text}]}
                )
    return msgs


def chat_function(message, history):
    """Generate the model's reply to ``message`` given the prior ``history``.

    Args:
        message: The new user message (text).
        history: Previous turns as supplied by ``gr.ChatInterface``; both the
            pair format and the role/content dict format are handled.

    Returns:
        The decoded text of the newly generated tokens only (the prompt is
        sliced off before decoding).
    """
    msgs = _history_to_messages(history)
    msgs.append({"role": "user", "content": [{"type": "text", "text": message}]})

    # return_dict=True makes the processor hand back a mapping (input_ids,
    # attention_mask, ...) instead of a bare tensor, which is what both the
    # ``**inputs`` expansion and the ``inputs["input_ids"]`` lookup require.
    # Tensors go to wherever the model actually lives, so they can never end
    # up on a different device than the weights.
    inputs = processor.apply_chat_template(
        msgs,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    prompt_len = inputs["input_ids"].shape[-1]

    # No gradients are needed for inference; skipping them saves memory.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=400,
            do_sample=True,
            temperature=0.4,
        )

    # generate() returns prompt + continuation; keep only the continuation.
    return processor.decode(outputs[0][prompt_len:], skip_special_tokens=True)
# Chat UI wired to chat_function. The interface object is created at module
# level; the server is only started when this file is run as a script.
demo = gr.ChatInterface(
    fn=chat_function,
    title="Gemma 3n E2B (RAM Optimized)",
)

if __name__ == "__main__":
    demo.launch()