from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from llama_cpp import Llama
from huggingface_hub import hf_hub_download
import os

app = FastAPI(title="Llama 3.2 1B API")

# Model configuration
REPO_ID = "bartowski/Llama-3.2-1B-Instruct-GGUF"
FILENAME = "Llama-3.2-1B-Instruct-Q4_K_M.gguf"
MODEL_PATH = os.path.join(os.getcwd(), FILENAME)

def ensure_model_exists():
    """Download the GGUF weights on first run; later runs reuse the local copy."""
    if not os.path.exists(MODEL_PATH):
        print(f"Downloading model {FILENAME} from {REPO_ID}...")
        try:
            # local_dir_use_symlinks is deprecated in recent huggingface_hub
            # releases; hf_hub_download writes a real file into local_dir.
            hf_hub_download(
                repo_id=REPO_ID,
                filename=FILENAME,
                local_dir=os.getcwd(),
            )
            print("Download complete.")
        except Exception as e:
            raise RuntimeError(f"Failed to download model: {e}") from e
    else:
        print(f"Model found at {MODEL_PATH}")

# Ensure model is downloaded before initializing Llama
ensure_model_exists()

# Initialize the model once at startup.
# n_threads=4 caps CPU threads for inference; n_ctx=2048 gives a
# reasonable context window for short chat prompts.
llm = Llama(
    model_path=MODEL_PATH,
    n_threads=4,
    n_ctx=2048,
    verbose=False
)

class ChatRequest(BaseModel):
    prompt: str
    max_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.9

@app.get("/")
async def root():
    return {"message": "Llama 3.2 1B FastAPI server is running", "model": FILENAME}

import threading

# llama.cpp inference is blocking and a Llama instance is not thread-safe,
# so serialize requests with a lock. Declaring the endpoint as a plain `def`
# lets FastAPI run it in a worker thread, keeping the event loop responsive.
_llm_lock = threading.Lock()

@app.post("/v1/chat/completions")
def chat_completion(request: ChatRequest):
    try:
        # Llama 3.2 Instruct prompt template: one user turn, then the
        # assistant header so the model generates the reply.
        formatted_prompt = (
            "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
            f"{request.prompt}<|eot_id|>"
            "<|start_header_id|>assistant<|end_header_id|>\n\n"
        )

        with _llm_lock:
            output = llm(
                formatted_prompt,
                max_tokens=request.max_tokens,
                temperature=request.temperature,
                top_p=request.top_p,
                stop=["<|eot_id|>"],
            )

        return output
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
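The endpoint above renders a single user turn by hand. If you later need a system message or multi-turn history, the same Llama 3.2 header/eot structure generalizes; a sketch (note that llama-cpp-python's `create_chat_completion` can instead apply the GGUF's built-in chat template for you):

```python
def format_llama32_prompt(messages):
    """Render a list of {'role': ..., 'content': ...} dicts into the
    Llama 3.2 Instruct template, ending with an open assistant header
    so the model produces the next reply."""
    parts = ["<|begin_of_text|>"]
    for m in messages:
        parts.append(
            f"<|start_header_id|>{m['role']}<|end_header_id|>\n\n"
            f"{m['content']}<|eot_id|>"
        )
    parts.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
    return "".join(parts)
```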

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
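A minimal client call, assuming the server above is running on localhost:8000 (the payload fields match `ChatRequest`; the example prompt text is arbitrary):

```python
import json
import urllib.request

payload = {
    "prompt": "Write a haiku about llamas.",
    "max_tokens": 128,
    "temperature": 0.7,
    "top_p": 0.9,
}
req = urllib.request.Request(
    "http://localhost:8000/v1/chat/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
# Uncomment with the server running:
# with urllib.request.urlopen(req) as resp:
#     print(json.load(resp)["choices"][0]["text"])
```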