# File size: 2,171 bytes
# Revision: e36439e
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from llama_cpp import Llama
from huggingface_hub import hf_hub_download
import os
# FastAPI application instance; endpoints are registered on it below.
app = FastAPI(title="Llama 3.2 1B API")
# Model configuration: a quantized (Q4_K_M) GGUF build of Llama 3.2 1B Instruct
# pulled from the Hugging Face Hub and stored in the current working directory.
REPO_ID = "bartowski/Llama-3.2-1B-Instruct-GGUF"
FILENAME = "Llama-3.2-1B-Instruct-Q4_K_M.gguf"
MODEL_PATH = os.path.join(os.getcwd(), FILENAME)
def ensure_model_exists() -> None:
    """Download the GGUF model into the working directory if it is absent.

    Raises:
        RuntimeError: if the download from the Hugging Face Hub fails;
            the original exception is chained as the cause.
    """
    # Guard clause: nothing to do when the file is already on disk.
    if os.path.exists(MODEL_PATH):
        print(f"Model found at {MODEL_PATH}")
        return
    print(f"Downloading model {FILENAME} from {REPO_ID}...")
    try:
        # NOTE: local_dir_use_symlinks is deprecated and ignored by current
        # huggingface_hub; local_dir alone places the real file here.
        hf_hub_download(
            repo_id=REPO_ID,
            filename=FILENAME,
            local_dir=os.getcwd(),
        )
    except Exception as e:
        # Chain the cause so the underlying HTTP/auth error stays visible.
        raise RuntimeError(f"Failed to download model: {e}") from e
    print("Download complete.")
# Ensure model is downloaded before initializing Llama
# (llama_cpp would otherwise fail to open a missing file at import time).
ensure_model_exists()
# Initialize the model once at module load so every request reuses it.
# n_threads=4 as requested by the user
# n_ctx=2048 for a reasonable context window
llm = Llama(
model_path=MODEL_PATH,
n_threads=4,
n_ctx=2048,
verbose=False
)
class ChatRequest(BaseModel):
    """Request body for the /v1/chat/completions endpoint."""

    # Raw user message; the endpoint wraps it in the Llama 3.2 chat template.
    prompt: str
    # Sampling parameters, forwarded unchanged to llama_cpp.
    max_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.9
@app.get("/")
async def root():
    """Health-check endpoint reporting which model file is being served."""
    return {
        "message": "Llama 3.2 1B FastAPI server is running",
        "model": FILENAME,
    }
@app.post("/v1/chat/completions")
def chat_completion(request: ChatRequest):
    """Run one single-turn chat completion and return llama_cpp's raw output.

    Declared as a plain ``def`` (not ``async def``): llama_cpp inference is
    blocking and CPU-bound, so FastAPI executes it in its threadpool instead
    of stalling the event loop for every concurrent request.

    Raises:
        HTTPException: 500 with the underlying error message on any
            inference failure.
    """
    # Llama 3.2 Instruct chat template: one user turn, then the assistant
    # header so the model generates the reply.
    formatted_prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{request.prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    try:
        # Keep the try body minimal: only the inference call can fail here.
        output = llm(
            formatted_prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            stop=["<|eot_id|>"],  # stop at the end-of-turn token
        )
    except Exception as e:
        # Boundary handler: surface inference failures as an HTTP 500.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return output
if __name__ == "__main__":
    # Local import keeps uvicorn optional when this module is imported
    # by an external ASGI server instead of run directly.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)