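# gemma-4 / app.py
# Hugging Face file-viewer header captured along with the source (not part of the program):
#   Valtry's picture
#   Update app.py -- 60b835a verified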
import torch
import uvicorn
import gradio as gr
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread
# -------- CONFIG --------
# Hub identifier of the checkpoint used for both the tokenizer and the model.
MODEL_ID = "google/gemma-4-E2B-it"

# -------- LOAD TOKENIZER --------
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# -------- LOAD MODEL (LOW MEMORY MODE) --------
# Inference only: switch autograd off for the thread that imports this module.
torch.set_grad_enabled(False)

# bfloat16 instead of float16: both use 2 bytes per weight, but bfloat16 keeps
# float32's exponent range, so large activations do not overflow to inf/NaN the
# way they can in float16.
# NOTE(review): Gemma checkpoints are normally published and run in bfloat16 --
# confirm this holds for MODEL_ID.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="cpu",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
)
# -------- FASTAPI --------
app = FastAPI()


class Request(BaseModel):
    """Request body schema: a single text prompt for the model."""

    # Text of the user's message.
    prompt: str
# -------- STREAM FUNCTION --------
def stream_llm(prompt):
    """Stream the model's reply to a single-turn user prompt.

    The prompt is wrapped as one user message, rendered through the
    tokenizer's chat template, and handed to ``model.generate`` on a worker
    thread. Decoded text is yielded as soon as the streamer produces it.

    Args:
        prompt: Text of the user's message.

    Yields:
        Decoded text chunks of the reply; the prompt and special tokens are
        skipped by the streamer.

    Raises:
        Exception: Whatever ``model.generate`` raised on the worker thread is
            re-raised here once the stream has ended.
    """
    messages = [{"role": "user", "content": prompt}]

    # add_generation_prompt=True appends the template's "assistant turn starts
    # here" marker so the model answers instead of continuing the user turn.
    # return_dict=True guarantees a mapping with input_ids and attention_mask,
    # independent of the library's default return type.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        # Pass the mask explicitly so generate() does not have to infer it.
        attention_mask=inputs["attention_mask"],
        max_new_tokens=128,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        streamer=streamer,
        # NOTE(review): this overrides any EOS ids configured in the model's
        # generation config -- confirm the chat end-of-turn token still stops
        # generation for this checkpoint.
        eos_token_id=tokenizer.eos_token_id,
    )

    failures = []

    def _run_generation():
        # If generate() dies, the streamer never receives its stop signal and
        # the consumer loop below would block forever; end the stream
        # ourselves and keep the error so it can be re-raised to the caller.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            failures.append(exc)
            streamer.end()

    worker = Thread(target=_run_generation)
    worker.start()

    for text in streamer:
        yield text

    # The stop signal has been consumed, so the worker is finished or about to
    # be; joining here makes sure any recorded failure is visible.
    worker.join()
    if failures:
        raise failures[0]
# -------- API --------
@app.post("/generate")
def generate(req: Request):
    """Return the model's reply to ``req.prompt`` as a streamed plain-text body."""
    reply_chunks = stream_llm(req.prompt)
    return StreamingResponse(reply_chunks, media_type="text/plain")
# -------- UI --------
def chat_fn(message, history):
    """Yield the growing reply text for a chat message.

    Each yielded value is the full reply accumulated so far, not just the
    newest chunk. ``history`` is accepted but not used: only ``message`` is
    sent to the model.
    """
    reply_so_far = ""
    chunk_stream = stream_llm(message)
    for piece in chunk_stream:
        reply_so_far = reply_so_far + piece
        yield reply_so_far
# Chat UI driven by chat_fn.
demo = gr.ChatInterface(fn=chat_fn)

# Mount UI: serve the Gradio interface from the FastAPI app at the root path.
app = gr.mount_gradio_app(app, demo, path="/")

# -------- SERVER START --------
if __name__ == "__main__":
    # Listen on all interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")