# gemma4tunnel / app.py
# Author: TUSTResearcher. Synced from GitHub via hub-sync (commit 8783e4d, verified).
import os
import gradio as gr
import torch
from peft import PeftModel
from transformers import AutoModelForMultimodalLM, AutoProcessor
# Identifiers handed to `from_pretrained` in load_model(); each can be
# overridden with the environment variable of the same name.
BASE_MODEL = os.getenv("BASE_MODEL", "google/gemma-4-12B-it")
ADAPTER_MODEL = os.getenv("ADAPTER_MODEL", "TUSTResearcher/gemma4tunnel")

# Shared model/processor pair; stays None until load_model() first runs.
model = processor = None
def load_model():
    """Return the shared ``(model, processor)`` pair, loading it on first use.

    On the first call this loads the processor for ``BASE_MODEL``. It then loads
    the base multimodal LM in bfloat16 when CUDA is available (float32
    otherwise), placed with ``device_map="auto"``. The base model is wrapped
    with the PEFT adapter ``ADAPTER_MODEL`` and put in eval mode. Both objects
    are stored in the module-level ``model`` / ``processor`` globals, so later
    calls return them without reloading.
    """
    # NOTE(review): the check-then-load below is unsynchronised; two first
    # calls running concurrently could each load the model. Confirm that
    # Gradio's concurrency settings make this impossible.
    global model, processor
    if model is None:
        processor = AutoProcessor.from_pretrained(BASE_MODEL)
        weight_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        backbone = AutoModelForMultimodalLM.from_pretrained(
            BASE_MODEL,
            torch_dtype=weight_dtype,
            device_map="auto",
        )
        model = PeftModel.from_pretrained(backbone, ADAPTER_MODEL)
        model.eval()
    return model, processor
_SYSTEM_PROMPT = (
    "You are Gemma4Tunnel, a careful tunnel boring machine and tunnelling "
    "engineering assistant. State assumptions and do not invent citations."
)


def _content_to_text(content):
    """Return the plain-text part of a Gradio chat ``content`` value.

    - Strings pass through unchanged.
    - Dicts contribute their ``"text"`` entry. This covers
      ``{"type": "text", "text": ...}`` parts and multimodal
      ``{"text": ..., "files": [...]}`` messages.
    - Lists are flattened part by part.
    - Anything else yields ``""``: ``None`` for a pending reply, file tuples,
      rendered components.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, dict):
        text = content.get("text")
        return text if isinstance(text, str) else ""
    if isinstance(content, list):
        return "\n".join(part for part in map(_content_to_text, content) if part)
    return ""


def _append_turn(messages, role, text):
    """Append a chat turn, merging it into the previous turn when the role repeats.

    Merging keeps user/assistant turns alternating. Gemma chat templates have
    historically rejected non-alternating roles (assumed to still hold for
    this checkpoint; verify).
    """
    if messages[-1]["role"] == role:
        messages[-1]["content"] += "\n\n" + text
    else:
        messages.append({"role": role, "content": text})


def _build_messages(message, history):
    """Convert Gradio's ``message``/``history`` into a chat-template message list."""
    messages = [{"role": "system", "content": _SYSTEM_PROMPT}]
    for item in history or []:
        if isinstance(item, dict):
            # "messages" history format: one {"role", "content", ...} dict per turn.
            turns = [(item.get("role"), item.get("content"))]
        else:
            # Legacy "tuples" format: one [user, assistant] pair per exchange.
            user_msg, assistant_msg = item
            turns = [("user", user_msg), ("assistant", assistant_msg)]
        for role, content in turns:
            text = _content_to_text(content)
            # Skip non-chat roles and turns without text (file-only uploads, pending replies).
            if role in ("user", "assistant") and text.strip():
                _append_turn(messages, role, text)
    _append_turn(messages, "user", _content_to_text(message))
    return messages


def respond(message, history, max_new_tokens=512, temperature=0.3):
    """Generate the assistant reply for one Gradio chat turn.

    Args:
        message: The new user message. Normally a string; a multimodal
            ``{"text", "files"}`` dict is reduced to its text.
        history: Prior conversation as supplied by Gradio, in either the
            "tuples" ([user, assistant] pairs) or "messages" (role/content
            dicts) format.
        max_new_tokens: Upper bound on generated tokens. Cast to ``int``
            because slider values may arrive as floats.
        temperature: Sampling temperature. A value ``<= 0`` selects greedy
            decoding.

    Returns:
        The decoded continuation after the prompt, with special tokens
        skipped and surrounding whitespace stripped.
    """
    m, proc = load_model()
    messages = _build_messages(message, history)
    inputs = proc.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(m.device)

    gen_kwargs = {"max_new_tokens": int(max_new_tokens)}
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=float(temperature), top_p=0.9)
    else:
        # Greedy decoding: omit sampling-only knobs. Transformers' generation
        # config validation flags them as unused when do_sample=False.
        gen_kwargs["do_sample"] = False

    with torch.no_grad():
        output = m.generate(**inputs, **gen_kwargs)

    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_len = inputs["input_ids"].shape[-1]
    text = proc.decode(output[0][prompt_len:], skip_special_tokens=True)
    return text.strip()
# Extra controls rendered with the chat box. ChatInterface passes their values
# to respond() after (message, history), i.e. as max_new_tokens, temperature.
_generation_controls = [
    gr.Slider(64, 2048, value=512, step=64, label="max_new_tokens"),
    gr.Slider(0.0, 1.0, value=0.3, step=0.05, label="temperature"),
]

demo = gr.ChatInterface(
    fn=respond,
    title="Gemma4Tunnel",
    description=(
        "Research assistant for tunnel boring machine and tunnelling engineering. "
        "Use source-grounded RAG for paper-specific citations."
    ),
    additional_inputs=_generation_controls,
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()