owlninjam commited on
Commit
37dd92f
·
verified ·
1 Parent(s): 6647f42

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +143 -109
api.py CHANGED
@@ -1,34 +1,42 @@
1
- from fastapi import FastAPI, HTTPException, Depends, status
2
- from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
3
- from fastapi.middleware.cors import CORSMiddleware
4
- from pydantic import BaseModel
5
- from llama_cpp import Llama
6
  import os
7
  import uvicorn
8
- from typing import Optional, List, Dict, Union, Literal
9
  import time
10
  import json
11
- import uuid
12
  from datetime import datetime
 
13
 
14
- # Configuration
 
 
 
 
 
 
 
15
  VALID_API_KEYS = {
 
16
  "sk-adminkey02",
17
  "sk-testkey123",
18
  "sk-userkey456",
19
  "sk-demokey789"
20
  }
 
 
21
 
22
- # Global model variable
23
  llm = None
24
  security = HTTPBearer()
25
 
 
 
26
  class Message(BaseModel):
27
  role: Literal["system", "user", "assistant"]
28
  content: str
29
 
30
  class ChatCompletionRequest(BaseModel):
31
- model: str = "zephyr-quiklang-3b-4k"
32
  messages: List[Message]
33
  max_tokens: Optional[int] = 512
34
  temperature: Optional[float] = 0.7
@@ -40,7 +48,7 @@ class ChatCompletionRequest(BaseModel):
40
  class ChatCompletionChoice(BaseModel):
41
  index: int
42
  message: Message
43
- finish_reason: Literal["stop", "length", "content_filter"]
44
 
45
  class Usage(BaseModel):
46
  prompt_tokens: int
@@ -48,26 +56,28 @@ class Usage(BaseModel):
48
  total_tokens: int
49
 
50
  class ChatCompletionResponse(BaseModel):
51
- id: str
52
  object: str = "chat.completion"
53
- created: int
54
- model: str
55
  choices: List[ChatCompletionChoice]
56
  usage: Usage
57
 
58
- class Model(BaseModel):
59
  id: str
60
  object: str = "model"
61
- created: int
62
- owned_by: str
63
 
64
  class ModelsResponse(BaseModel):
65
  object: str = "list"
66
- data: List[Model]
 
 
67
 
68
  app = FastAPI(
69
- title="Zephyr Quiklang OpenAI API",
70
- description="OpenAI-compatible API for Zephyr-Quiklang-3B-4K",
71
  version="1.0.0",
72
  docs_url="/v1/docs",
73
  redoc_url="/v1/redoc"
@@ -81,127 +91,151 @@ app.add_middleware(
81
  allow_headers=["*"],
82
  )
83
 
 
 
84
  def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
85
  if credentials.credentials not in VALID_API_KEYS:
86
- raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key")
 
 
 
87
  return credentials.credentials
88
 
 
 
 
89
  def load_model():
90
  global llm
91
- model_path = "zephyr-quiklang-3b-4k.Q4_K_M.gguf"
92
-
93
- if not os.path.exists(model_path):
94
- raise Exception(f"Model file {model_path} not found!")
95
-
96
  llm = Llama(
97
- model_path=model_path,
98
- n_ctx=4096,
99
  n_threads=2,
100
  n_batch=512,
101
  verbose=False,
102
  use_mlock=True,
103
  n_gpu_layers=0,
104
  )
 
 
 
105
 
106
  def format_messages(messages: List[Message]) -> str:
107
- formatted = ""
 
 
 
108
  for message in messages:
109
- formatted += f"<|im_start|>{message.role}\n{message.content}\n<|im_end|>\n"
110
- formatted += "<|im_start|>assistant\n"
111
- return formatted
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  def count_tokens_rough(text: str) -> int:
 
114
  return len(text.split())
115
 
116
- @app.on_event("startup")
117
- async def startup_event():
118
- print("🚀 Starting Zephyr Quiklang API...")
119
- load_model()
120
- print(" Model loaded.")
 
121
 
122
  @app.get("/v1/models", response_model=ModelsResponse)
123
  async def list_models(api_key: str = Depends(verify_api_key)):
124
- return ModelsResponse(data=[
125
- Model(
126
- id="zephyr-quiklang-3b-4k",
127
- created=int(datetime.now().timestamp()),
128
- owned_by="local"
129
- )
130
- ])
131
-
132
- @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
133
- async def create_chat_completion(request: ChatCompletionRequest, api_key: str = Depends(verify_api_key)):
134
  if llm is None:
135
- raise HTTPException(status_code=503, detail="Model not loaded")
136
 
137
  prompt = format_messages(request.messages)
138
- prompt_tokens = count_tokens_rough(prompt)
139
- start_time = time.time()
140
-
141
- try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  response = llm(
143
  prompt,
144
  max_tokens=request.max_tokens,
145
  temperature=request.temperature,
146
  top_p=request.top_p,
147
- stop=["<|im_end|>", "<|im_start|>"] + (request.stop or []),
148
  echo=False
149
  )
150
- except Exception as e:
151
- raise HTTPException(status_code=500, detail=f"Error generating response: {str(e)}")
152
-
153
- end_time = time.time()
154
- generation_time = end_time - start_time
155
- response_text = response['choices'][0]['text'].strip()
156
- completion_tokens = count_tokens_rough(response_text)
157
-
158
- return ChatCompletionResponse(
159
- id=f"chatcmpl-{uuid.uuid4().hex[:8]}",
160
- created=int(time.time()),
161
- model=request.model,
162
- choices=[
163
- ChatCompletionChoice(
164
- index=0,
165
- message=Message(role="assistant", content=response_text),
166
- finish_reason="stop"
167
  )
168
- ],
169
- usage=Usage(
170
- prompt_tokens=prompt_tokens,
171
- completion_tokens=completion_tokens,
172
- total_tokens=prompt_tokens + completion_tokens
173
  )
174
- )
175
-
176
- @app.get("/v1/health")
177
- async def health_check():
178
- if llm is None:
179
- raise HTTPException(status_code=503, detail="Model not loaded")
180
- return {
181
- "status": "healthy",
182
- "model_loaded": True,
183
- "model": "zephyr-quiklang-3b-4k",
184
- "timestamp": datetime.now().isoformat()
185
- }
186
-
187
- @app.get("/v1")
188
- async def api_info():
189
- return {
190
- "message": "Zephyr Quiklang OpenAI-Compatible API",
191
- "model": "zephyr-quiklang-3b-4k (Q4_K_M)",
192
- "endpoints": {
193
- "chat_completions": "/v1/chat/completions",
194
- "models": "/v1/models",
195
- "health": "/v1/health",
196
- "docs": "/v1/docs"
197
- },
198
- "authentication": {
199
- "required": True,
200
- "type": "Bearer token",
201
- "valid_keys": list(VALID_API_KEYS)
202
- },
203
- "performance": {
204
- "context_length": 4096,
205
- "expected_speed": "2–8 tok/s (CPU)"
206
- }
207
- }
 
1
+ # api.py
 
 
 
 
2
  import os
3
  import uvicorn
4
+ import uuid
5
  import time
6
  import json
 
7
  from datetime import datetime
8
+ from typing import Optional, List, Union, Literal
9
 
10
+ from fastapi import FastAPI, HTTPException, Depends, status
11
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
12
+ from fastapi.middleware.cors import CORSMiddleware
13
+ from fastapi.responses import StreamingResponse
14
+ from pydantic import BaseModel, Field
15
+ from llama_cpp import Llama
16
+
17
# --- Configuration for NEW Model ---
# NOTE(review): API keys are hardcoded in source — fine for a demo, but
# consider loading from environment variables for anything shared.
VALID_API_KEYS = {
    # You can keep the same keys or change them
    "sk-adminkey02",
    "sk-testkey123",
    "sk-userkey456",
    "sk-demokey789"
}
# Path to the local GGUF weights file and the model name exposed via the API.
MODEL_PATH = "zephyr-quiklang-3b-4k.Q4_K_M.gguf"
MODEL_NAME = "zephyr-quiklang-3b-4k"

# --- Global Model Variable ---
# llm is populated at startup by load_model(); it stays None until then.
llm = None
security = HTTPBearer()
31
 
32
+ # --- Pydantic Models for OpenAI Compatibility (No changes needed here) ---
33
+
34
class Message(BaseModel):
    """A single chat message in the OpenAI wire format."""
    # Only the three standard OpenAI roles are accepted.
    role: Literal["system", "user", "assistant"]
    content: str
37
 
38
  class ChatCompletionRequest(BaseModel):
39
+ model: str = MODEL_NAME
40
  messages: List[Message]
41
  max_tokens: Optional[int] = 512
42
  temperature: Optional[float] = 0.7
 
48
class ChatCompletionChoice(BaseModel):
    """One generated completion within a chat response."""
    index: int
    message: Message
    # None while streaming; "stop" or "length" once generation finishes.
    finish_reason: Optional[Literal["stop", "length"]] = None
52
 
53
  class Usage(BaseModel):
54
  prompt_tokens: int
 
56
  total_tokens: int
57
 
58
class ChatCompletionResponse(BaseModel):
    """OpenAI-compatible body for /v1/chat/completions (non-streaming)."""
    # Unique completion id, generated fresh for each response.
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex}")
    object: str = "chat.completion"
    # Unix timestamp at response-creation time.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str = MODEL_NAME
    choices: List[ChatCompletionChoice]
    usage: Usage
65
 
66
class ModelData(BaseModel):
    """A single entry in the /v1/models listing."""
    id: str
    object: str = "model"
    # Unix timestamp; computed when the response object is built.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "user"
71
 
72
class ModelsResponse(BaseModel):
    """OpenAI-compatible body for /v1/models."""
    object: str = "list"
    data: List[ModelData]
75
+
76
+ # --- FastAPI App Initialization ---
77
 
78
  app = FastAPI(
79
+ title="Zephyr-3B OpenAI-Compatible API",
80
+ description=f"An OpenAI-compatible API for the {MODEL_NAME} model.",
81
  version="1.0.0",
82
  docs_url="/v1/docs",
83
  redoc_url="/v1/redoc"
 
91
  allow_headers=["*"],
92
  )
93
 
94
+ # --- Dependency for API Key Verification ---
95
+
96
def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """FastAPI dependency: validate the Bearer token against VALID_API_KEYS.

    Returns the key string on success; raises HTTP 401 otherwise.
    """
    token = credentials.credentials
    if token in VALID_API_KEYS:
        return token
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid or missing API key"
    )
103
 
104
+ # --- Model Loading ---
105
+
106
@app.on_event("startup")
def load_model():
    """Load the GGUF model into the module-global `llm` at app startup.

    Raises FileNotFoundError if the weights file is missing so the server
    fails fast rather than serving 503s forever.
    """
    # NOTE(review): @app.on_event is deprecated in newer FastAPI in favor of
    # lifespan handlers — worth migrating eventually.
    global llm
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Model file not found at {MODEL_PATH}")

    print("🚀 Loading GGUF model...")
    # CPU-only inference settings: no GPU layers, two threads, and mlock to
    # keep the weights resident in RAM.
    llama_settings = {
        "model_path": MODEL_PATH,
        "n_ctx": 4096,  # the model's 4K context limit
        "n_threads": 2,
        "n_batch": 512,
        "verbose": False,
        "use_mlock": True,
        "n_gpu_layers": 0,
    }
    llm = Llama(**llama_settings)
    print("✅ Model loaded successfully!")
123
+
124
+ # --- Helper Functions ---
125
 
126
def format_messages(messages: List[Message]) -> str:
    """Render a conversation into the Zephyr chat prompt template."""
    parts = []

    # The Zephyr template expects a system block first; use the first system
    # message if one exists, otherwise emit an empty system block.
    system = next((m for m in messages if m.role == "system"), None)
    if system is not None:
        parts.append(f"<|system|>\n{system.content}</s>\n")
    else:
        parts.append("<|system|>\n</s>\n")

    # Then the user/assistant turns, in their original order (any extra
    # system messages are intentionally skipped here, as before).
    for m in messages:
        if m.role == "user":
            parts.append(f"<|user|>\n{m.content}</s>\n")
        elif m.role == "assistant":
            parts.append(f"<|assistant|>\n{m.content}</s>\n")

    # Cue the model to begin generating the assistant's reply.
    parts.append("<|assistant|>\n")
    return "".join(parts)
148
 
149
def count_tokens_rough(text: str) -> int:
    """Whitespace-split word count as a cheap stand-in for real tokenization."""
    words = text.split()
    return len(words)
152
 
153
+ # --- API Endpoints ---
154
+
155
@app.get("/v1/health")
async def health_check():
    """Liveness probe; also reports whether the model has finished loading."""
    model_ready = llm is not None
    return {"status": "healthy", "model_loaded": model_ready}
159
 
160
@app.get("/v1/models", response_model=ModelsResponse)
async def list_models(api_key: str = Depends(verify_api_key)):
    """Return the single locally served model in OpenAI list format."""
    available = [ModelData(id=MODEL_NAME)]
    return ModelsResponse(data=available)
164
+
165
@app.post("/v1/chat/completions")
async def create_chat_completion(
    request: ChatCompletionRequest,
    api_key: str = Depends(verify_api_key)
):
    """Creates a model response for the given chat conversation.

    Supports both streaming (SSE chunks) and non-streaming replies,
    mirroring the OpenAI /v1/chat/completions contract.

    Raises:
        HTTPException 503: model not loaded yet.
        HTTPException 500: inference failed (non-streaming path).
    """
    if llm is None:
        raise HTTPException(status_code=503, detail="Model is not loaded yet")

    prompt = format_messages(request.messages)
    stop_tokens = ["</s>"]  # The stop token for Zephyr is </s>
    if isinstance(request.stop, str):
        stop_tokens.append(request.stop)
    elif isinstance(request.stop, list):
        stop_tokens.extend(request.stop)

    # Streaming response
    if request.stream:
        # NOTE(review): llm(..., stream=True) is a blocking generator inside an
        # async generator, so it occupies the event loop while generating —
        # consider an executor if concurrent requests matter.
        async def stream_generator():
            completion_id = f"chatcmpl-{uuid.uuid4().hex}"
            created_time = int(time.time())
            try:
                stream = llm(
                    prompt,
                    max_tokens=request.max_tokens,
                    temperature=request.temperature,
                    top_p=request.top_p,
                    stop=stop_tokens,
                    stream=True,
                    echo=False
                )
                for output in stream:
                    if 'choices' in output and len(output['choices']) > 0:
                        delta_content = output['choices'][0].get('text', '')
                        chunk = {
                            "id": completion_id,
                            "object": "chat.completion.chunk",
                            "created": created_time,
                            "model": MODEL_NAME,
                            "choices": [{"index": 0, "delta": {"content": delta_content}, "finish_reason": None}]
                        }
                        yield f"data: {json.dumps(chunk)}\n\n"
            except Exception as e:
                # Surface generation errors as a terminal SSE event rather than
                # silently dropping the connection mid-stream.
                error_chunk = {"error": {"message": f"Error generating response: {str(e)}"}}
                yield f"data: {json.dumps(error_chunk)}\n\n"
            final_chunk = {
                "id": completion_id, "object": "chat.completion.chunk", "created": created_time,
                "model": MODEL_NAME, "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]
            }
            yield f"data: {json.dumps(final_chunk)}\n\n"
            yield "data: [DONE]\n\n"
        return StreamingResponse(stream_generator(), media_type="text/event-stream")

    # Non-streaming response: wrap inference so failures become a clean 500
    # instead of an unhandled server error.
    try:
        response = llm(
            prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            stop=stop_tokens,
            echo=False
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating response: {str(e)}")

    choice = response['choices'][0]
    response_text = choice['text'].strip()
    # Report the real finish reason ("length" when max_tokens was hit)
    # instead of always claiming "stop".
    finish_reason = "length" if choice.get('finish_reason') == "length" else "stop"
    prompt_tokens = count_tokens_rough(prompt)
    completion_tokens = count_tokens_rough(response_text)
    return ChatCompletionResponse(
        model=MODEL_NAME,
        choices=[
            ChatCompletionChoice(
                index=0,
                message=Message(role="assistant", content=response_text),
                finish_reason=finish_reason
            )
        ],
        usage=Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens
        )
    )