|
|
from fastapi import FastAPI, HTTPException |
|
|
from pydantic import BaseModel |
|
|
from llama_cpp import Llama |
|
|
from huggingface_hub import hf_hub_download |
|
|
import os |
|
|
|
|
|
# FastAPI application serving a quantized Llama 3.2 1B Instruct model.
app = FastAPI(title="Llama 3.2 1B API")


# Hugging Face Hub repository hosting the quantized GGUF builds.
REPO_ID = "bartowski/Llama-3.2-1B-Instruct-GGUF"


# Q4_K_M quantization: ~4-bit weights, a common quality/size trade-off.
FILENAME = "Llama-3.2-1B-Instruct-Q4_K_M.gguf"


# The model file lives in the process's current working directory.
MODEL_PATH = os.path.join(os.getcwd(), FILENAME)
|
|
|
|
|
def ensure_model_exists() -> None:
    """Download the GGUF model into the working directory if it is absent.

    Checks for ``MODEL_PATH`` on disk and, when missing, fetches
    ``FILENAME`` from ``REPO_ID`` on the Hugging Face Hub.

    Raises:
        RuntimeError: if the download fails for any reason (the original
            exception is chained as the cause).
    """
    # Guard clause: nothing to do when the file is already on disk.
    if os.path.exists(MODEL_PATH):
        print(f"Model found at {MODEL_PATH}")
        return

    print(f"Downloading model {FILENAME} from {REPO_ID}...")
    try:
        # NOTE: the former local_dir_use_symlinks=False argument was
        # dropped — it is deprecated and ignored by current
        # huggingface_hub releases; passing local_dir already writes a
        # real file into that directory.
        hf_hub_download(
            repo_id=REPO_ID,
            filename=FILENAME,
            local_dir=os.getcwd(),
        )
    except Exception as e:
        # Chain the cause so the underlying failure (network, auth,
        # missing file) stays visible in the traceback.
        raise RuntimeError(f"Failed to download model: {e}") from e
    print("Download complete.")
|
|
|
|
|
|
|
|
# Download the model on first run, before attempting to load it.
ensure_model_exists()


# Load the model once at module import time so every request reuses the
# same instance. n_ctx=2048 caps the context window; n_threads=4 limits
# CPU threads; verbose=False silences llama.cpp's startup logging.
llm = Llama(

    model_path=MODEL_PATH,

    n_threads=4,

    n_ctx=2048,

    verbose=False

)
|
|
|
|
|
class ChatRequest(BaseModel):
    """Request body for the /v1/chat/completions endpoint."""

    # Raw user message; the handler wraps it in the Llama 3 chat template.
    prompt: str

    # Maximum number of tokens to generate.
    max_tokens: int = 512

    # Sampling temperature; higher values give more varied output.
    temperature: float = 0.7

    # Nucleus-sampling probability cutoff.
    top_p: float = 0.9
|
|
|
|
|
@app.get("/")
async def root():
    """Health-check endpoint: reports server status and the served model file."""
    payload = {
        "message": "Llama 3.2 1B FastAPI server is running",
        "model": FILENAME,
    }
    return payload
|
|
|
|
|
@app.post("/v1/chat/completions")
async def chat_completion(request: ChatRequest):
    """Run one chat completion against the loaded Llama model.

    Wraps the user prompt in the Llama 3 instruct chat template,
    generates a completion, and returns llama-cpp-python's raw
    completion dict unchanged.

    Raises:
        HTTPException: 500 with the error message if generation fails.
    """
    # Llama 3 instruct template: one user turn, then the assistant
    # header so the model continues as the assistant. Formatting cannot
    # raise, so it stays outside the try block.
    formatted_prompt = (
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
        f"{request.prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    )

    try:
        # NOTE(review): llm() is a blocking, CPU-bound call inside an
        # async handler, so it blocks the event loop and serializes all
        # requests. That also happens to prevent concurrent access to
        # the single Llama instance — confirm thread-safety before
        # offloading this to a thread pool.
        output = llm(
            formatted_prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            stop=["<|eot_id|>"]
        )
    except Exception as e:
        # Chain the cause so server logs keep the original traceback.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return output
|
|
|
|
|
if __name__ == "__main__":
    # Run a standalone development server when executed as a script.
    import uvicorn

    host, port = "0.0.0.0", 8000
    uvicorn.run(app, host=host, port=port)
|
|
|