import logging
from time import time

import fastapi
from fastapi.responses import JSONResponse, RedirectResponse
import llama_cpp
import llama_cpp.llama_tokenizer

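
# Download the GGUF weights from the Hugging Face Hub (cached locally after
# the first run) and pair them with the original HF tokenizer, which avoids
# tokenization mismatches against the GGUF-embedded vocabulary.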
llama = llama_cpp.Llama.from_pretrained(
    repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
    filename="*q4_0.gguf",
    tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B"),
    verbose=False,
    n_ctx=4096,
    n_gpu_layers=0,  # CPU-only; raise this to offload layers on a GPU build
    chat_format="chatml",  # Qwen1.5 chat models use ChatML, not the llama-2 template
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

| """ |
| try: |
| llm = Llama.from_pretrained( |
| repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF", |
| filename="*q4_0.gguf", |
| verbose=False, |
| n_ctx=4096, |
| n_threads=4, |
| n_gpu_layers=0, |
| ) |
| |
| llm = Llama( |
| model_path=MODEL_PATH, |
| chat_format="llama-2", |
| n_ctx=4096, |
| n_threads=8, |
| n_gpu_layers=0, |
| ) |
| |
| except Exception as e: |
| logger.error(f"Failed to load model: {e}") |
| raise |
| """ |


app = fastapi.FastAPI()

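
# FastAPI serves interactive API docs at /docs; the index route below simply
# redirects there.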
| @app.get("/") |
| def index(): |
| return fastapi.responses.RedirectResponse(url="/docs") |
|
|
|
|
| @app.get("/health") |
| def health(): |
| return {"status": "ok"} |
|
|
|
|
@app.get("/generate")
def complete(
    question: str,
    system: str = "You are a story writing assistant.",
    temperature: float = 0.7,
    seed: int = 42,
) -> dict:
    # Plain `def` rather than `async def`: create_chat_completion blocks, and
    # FastAPI runs sync endpoints in a threadpool instead of stalling the
    # event loop.
    try:
        st = time()
        output = llama.create_chat_completion(
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": question},
            ],
            temperature=temperature,
            seed=seed,
        )
        # With stream=False (the default) this is a plain dict, so the timing
        # can be attached directly to the response.
        output["time"] = time() - st
        return output
    except Exception as e:
        logger.error(f"Error in /generate endpoint: {e}")
        return JSONResponse(
            status_code=500, content={"message": "Internal Server Error"}
        )
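

# A minimal sketch of a streaming variant of /generate, for when
# token-by-token output is wanted; the "/generate_stream" route name is
# illustrative and not part of the original app.
@app.get("/generate_stream")
def complete_stream(
    question: str,
    system: str = "You are a story writing assistant.",
    temperature: float = 0.7,
    seed: int = 42,
):
    def token_stream():
        stream = llama.create_chat_completion(
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": question},
            ],
            temperature=temperature,
            seed=seed,
            stream=True,
        )
        for chunk in stream:
            delta = chunk["choices"][0]["delta"]
            # The first chunk carries the role; later chunks carry text.
            if "content" in delta:
                yield delta["content"]

    return fastapi.responses.StreamingResponse(token_stream(), media_type="text/plain")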


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
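
# Example request once the server is running (question must be URL-encoded):
#   curl "http://localhost:7860/generate?question=Tell%20me%20a%20short%20story"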