| import torch |
| import uvicorn |
| import gradio as gr |
| from fastapi import FastAPI |
| from fastapi.responses import StreamingResponse |
| from pydantic import BaseModel |
| from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer |
| from threading import Thread |
|
|
| |
# Hugging Face Hub identifier of the checkpoint this app serves.
MODEL_ID = "google/gemma-4-E2B-it"

# Tokenizer and model are loaded once at import time and shared by every
# request handled by this process.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# This process only runs inference, so autograd bookkeeping is switched off.
torch.set_grad_enabled(False)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="cpu",
    # bfloat16 instead of float16: half-precision (float16) CPU kernels are
    # incomplete or slow in many PyTorch builds, and float16's narrow
    # exponent range is prone to overflow. bfloat16 keeps the same memory
    # footprint. NOTE(review): switch to torch.float32 if the host CPU
    # handles bfloat16 poorly and memory allows.
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
)

# ASGI application; the REST route and the Gradio UI are attached to it below.
app = FastAPI()
|
|
class Request(BaseModel):
    """JSON request body for the ``/generate`` endpoint."""

    # User text that is forwarded to the model as a single user message.
    prompt: str
|
|
| |
def stream_llm(prompt):
    """Stream the model's reply to a single-turn user prompt.

    The prompt is wrapped in the tokenizer's chat template, generation runs
    in a background thread, and decoded text is yielded as soon as the
    streamer makes it available.

    Args:
        prompt: User message text.

    Yields:
        str: Successive pieces of generated text (prompt and special tokens
        are skipped by the streamer).

    Raises:
        RuntimeError: If ``model.generate`` fails in the worker thread; the
        original exception is chained as the cause.
    """
    messages = [{"role": "user", "content": prompt}]

    # add_generation_prompt=True opens the assistant turn so the model
    # answers rather than continuing the user's text. return_dict=True
    # guarantees a mapping with both input_ids and attention_mask; without
    # it, some transformers versions return a bare tensor, which cannot be
    # indexed by "input_ids".
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    # eos_token_id is deliberately not overridden: the model's own
    # generation config supplies the stop tokens, which for chat checkpoints
    # typically include the end-of-turn token that tokenizer.eos_token_id
    # alone does not cover.
    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=128,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        streamer=streamer,
    )

    failures = []

    def _run_generation():
        # generate() only signals end-of-stream on success. If it raises,
        # close the stream here so the consumer loop below cannot block
        # forever on the streamer's queue.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            failures.append(exc)
            streamer.end()

    worker = Thread(target=_run_generation)
    worker.start()

    for text in streamer:
        yield text

    # The stream is exhausted, so the worker is finishing (or has failed);
    # joining here is quick and avoids leaving a dangling thread.
    worker.join()
    if failures:
        raise RuntimeError("text generation failed") from failures[0]
|
|
| |
@app.post("/generate")
def generate(req: Request):
    """Handle POST /generate by streaming the model's reply as plain text."""
    text_chunks = stream_llm(req.prompt)
    return StreamingResponse(text_chunks, media_type="text/plain")
|
|
| |
def chat_fn(message, history):
    """Yield the growing reply text for the chat interface.

    Each yielded value is the full reply accumulated so far, not just the
    newest piece. ``history`` is accepted but not used, so every message is
    answered as a fresh single-turn prompt.
    """
    pieces = []
    for piece in stream_llm(message):
        pieces.append(piece)
        yield "".join(pieces)
|
|
# Chat UI backed by chat_fn, mounted at the root path of the FastAPI app.
demo = gr.ChatInterface(fn=chat_fn)
app = gr.mount_gradio_app(app=app, blocks=demo, path="/")


if __name__ == "__main__":
    # Serve the API and UI on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)