File size: 4,517 Bytes
83b24f3 e7dd500 1491ed2 cbc595a 83b24f3 e7dd500 83b24f3 e7dd500 83b24f3 1a97454 83b24f3 e7dd500 cbc595a e7dd500 cbc595a e7dd500 83b24f3 e7dd500 83b24f3 e7dd500 83b24f3 1491ed2 83b24f3 1491ed2 83b24f3 1491ed2 83b24f3 1491ed2 83b24f3 1491ed2 cbc595a 1491ed2 83b24f3 e7dd500 83b24f3 e7dd500 83b24f3 e7dd500 83b24f3 e7dd500 83b24f3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 | from fastapi import FastAPI, HTTPException, Query
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
from llama_cpp import Llama
import time
import io
import asyncio
import logging
# Initialize FastAPI app
app = FastAPI()
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Load the Phi-3 model using llama-cpp-python
model_path = "./Qwen2-1.5B-Instruct.IQ3_M.gguf" # Ensure this path is correct
try:
llama_model = Llama(
model_path=model_path,
n_ctx=4096,
n_threads=8,
n_gpu_layers=35
)
logger.info(f"Model loaded successfully from {model_path}")
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise RuntimeError(f"Could not load model from {model_path}")
# Response model based on OpenAI API structure
class Choice(BaseModel):
text: str
index: int
logprobs: int = None
finish_reason: str
class ResponseModel(BaseModel):
id: str
object: str
created: int
model: str
choices: list[Choice]
@app.get("/v1/completions")
async def create_completion(
prompt: str = Query(..., description="The prompt to complete"),
model: str = "default",
max_tokens: int = 256,
temperature: float = 0.7,
top_p: float = 1.0,
n: int = 1,
stream: bool = False
):
try:
logger.info(f"Received GET request with prompt: {prompt}")
if stream:
# Streaming response using GET
async def generate():
logger.info(f"Generating streaming response for prompt: {prompt}")
try:
response = llama_model(
prompt=prompt,
max_tokens=max_tokens,
temperature=temperature,
stop=None,
echo=True,
)
generated_text = response["choices"][0]["text"]
for chunk in generated_text.split('\n'):
yield chunk
await asyncio.sleep(0.1) # Simulate delay for streaming effect
except Exception as e:
logger.error(f"Error during model inference: {e}")
raise HTTPException(status_code=500, detail="Error generating response.")
return StreamingResponse(generate(), media_type="text/plain")
else:
# Non-streaming JSON response
logger.info(f"Generating non-streaming response for prompt: {prompt}")
response = llama_model(
prompt=prompt,
max_tokens=max_tokens,
temperature=temperature,
stop=None,
echo=True,
)
generated_text = response["choices"][0]["text"]
# Build the response in OpenAI's format
response_data = ResponseModel(
id="cmpl-xxxx", # Replace with unique ID generator if needed
object="text_completion",
created=int(time.time()),
model=model,
choices=[
Choice(
text=generated_text,
index=0,
logprobs=None,
finish_reason="stop"
)
]
)
return JSONResponse(content=response_data.dict())
except Exception as e:
logger.error(f"Internal error: {e}")
raise HTTPException(status_code=500, detail="Server error occurred.")
# Handle root route to avoid 404 errors
@app.get("/")
async def root():
return {"message": "Welcome to the Phi-3 API"}
# Handle robots.txt to avoid unnecessary 404 errors
@app.get("/robots.txt")
async def robots():
return StreamingResponse(io.StringIO("User-agent: *\nDisallow: /"), media_type="text/plain")
# Test the model in isolation (optional, you can use this to debug locally)
@app.get("/test_model")
async def test_model():
test_prompt = "This is a test prompt."
try:
response = llama_model(prompt=test_prompt, max_tokens=10, temperature=0.7, echo=True)
return {"response": response["choices"][0]["text"]}
except Exception as e:
logger.error(f"Model test error: {e}")
return {"error": str(e)}
# Main entry point
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)
|