"""Serve a Hugging Face chat model through a FastAPI endpoint and a Gradio UI.

The module loads the tokenizer and model once at import time, exposes
``POST /generate`` which streams generated text as ``text/plain``, and mounts
a Gradio ``ChatInterface`` at ``/`` on the same FastAPI application.
"""

import logging
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import torch
import uvicorn
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

logger = logging.getLogger(__name__)

# -------- CONFIG --------
MODEL_ID = "google/gemma-4-E2B-it"

# -------- LOAD TOKENIZER --------
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# -------- LOAD MODEL (LOW MEMORY MODE) --------
# Inference only. NOTE: PyTorch's grad mode is thread-local, so this call does
# not reach the worker thread started in ``stream_llm``; presumably
# ``model.generate`` disables gradients itself -- confirm against the
# installed transformers version.
torch.set_grad_enabled(False)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="cpu",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

# -------- FASTAPI --------
app = FastAPI()


class Request(BaseModel):
    """JSON body accepted by ``POST /generate``."""

    # Raw user message; it is wrapped in the chat template before generation.
    prompt: str


# -------- STREAM FUNCTION --------
def stream_llm(prompt: str) -> Iterator[str]:
    """Yield the model's reply to *prompt* as decoded text chunks.

    The prompt is wrapped as a single user turn with the tokenizer's chat
    template, generation runs in a background thread, and the decoded text is
    yielded as the streamer produces it.

    Args:
        prompt: The user's message.

    Yields:
        Successive pieces of generated text (prompt and special tokens are
        skipped by the streamer).
    """
    messages = [{"role": "user", "content": prompt}]

    # add_generation_prompt=True appends the model-turn header so the model
    # answers instead of continuing the user's turn. return_dict=True makes
    # the result indexable by "input_ids" regardless of the library default.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        # Forward the mask when the tokenizer produced one; None is accepted.
        attention_mask=inputs.get("attention_mask"),
        max_new_tokens=128,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        streamer=streamer,
        eos_token_id=tokenizer.eos_token_id,
    )

    def _run_generation() -> None:
        """Run generation; on failure, unblock the consumer of the streamer."""
        try:
            model.generate(**generation_kwargs)
        except Exception:
            logger.exception("Text generation failed")
            # Without the end-of-stream signal the iterator below would block
            # forever and the HTTP response would never finish.
            streamer.end()

    worker = Thread(target=_run_generation)
    worker.start()

    yield from streamer

    # The stream is exhausted, so generation has finished (or failed); reap
    # the worker thread.
    worker.join()


# -------- API --------
@app.post("/generate")
def generate(req: Request) -> StreamingResponse:
    """Stream the model's reply to ``req.prompt`` as plain text."""
    return StreamingResponse(stream_llm(req.prompt), media_type="text/plain")


# -------- UI --------
def chat_fn(message, history):
    """Gradio chat callback: yield the reply accumulated so far.

    Each yield is the full text generated up to that point, built by
    appending every new chunk. ``history`` is accepted to match the callback
    signature but is not used: every message is answered without earlier
    turns.
    """
    response = ""
    for chunk in stream_llm(message):
        response += chunk
        yield response


demo = gr.ChatInterface(chat_fn)

# Mount UI
app = gr.mount_gradio_app(app, demo, path="/")

# -------- SERVER START --------
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)