# gemma/main.py — FastAPI server that streams chat completions from a local
# Gemma GGUF model via llama-cpp-python. (Deployed as a Hugging Face Space.)
import json
import os

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from llama_cpp import Llama
from pydantic import BaseModel
app = FastAPI()

# 1. Load the model once at import time so every request reuses the same
#    in-memory weights.
#    n_gpu_layers=0 keeps inference entirely on CPU (use -1 to offload all
#    layers to GPU when one is available).
#    n_ctx=2048 is the context window size in tokens.
print("Loading model...")
llm = Llama(
    # Allow the model path to be overridden without editing code; falls back
    # to the original hard-coded container path.
    model_path=os.environ.get("MODEL_PATH", "/gemma-3-1b-it-Q8_0.gguf"),
    n_gpu_layers=0,
    n_ctx=2048,
    verbose=False,
)
# 2. Define the request body schema for POST /generate.
class ChatRequest(BaseModel):
    """Validated JSON payload accepted by the /generate endpoint."""

    # The user's prompt text (required).
    message: str
    # Sampling temperature forwarded to the model.
    temperature: float = 0.7
    # Upper bound on generated tokens, forwarded to the model.
    max_tokens: int = 512
# 3. The Streaming Generator
def stream_text(prompt: str, temperature: float, max_tokens: int):
    """Yield the model's reply incrementally, one text chunk at a time.

    Args:
        prompt: The user's message, sent as a single-turn chat.
        temperature: Sampling temperature passed through to the model.
        max_tokens: Maximum number of tokens to generate.

    Yields:
        str: Non-empty pieces of the assistant's reply, in order.
    """
    # Gemma needs chat-formatted input; create_chat_completion applies the
    # model-specific prompt template automatically.
    messages = [
        {"role": "user", "content": prompt}
    ]
    stream = llm.create_chat_completion(
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        stream=True,  # <--- vital: yield chunks instead of one final response
    )
    for chunk in stream:
        delta = chunk["choices"][0]["delta"]
        # The first streamed chunk carries only {"role": ...}, and a chunk may
        # carry "content": None. The original `"content" in delta` check would
        # yield None in the latter case and corrupt the plain-text stream;
        # .get() plus a truthiness test skips both safely.
        text_chunk = delta.get("content")
        if text_chunk:
            yield text_chunk
@app.post("/generate")
async def generate_stream(request: ChatRequest):
    """Stream the model's reply for the posted chat message as plain text."""
    token_stream = stream_text(
        request.message,
        request.temperature,
        request.max_tokens,
    )
    return StreamingResponse(token_stream, media_type="text/plain")
# Run a development server when executed directly (e.g. `python main.py`).
# Port 7860 is the conventional Hugging Face Spaces port.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)