Kyryll Kochkin committed
Commit 698373a · Parent: 2ad2929

added GPT4-dev-177M-1511-Instruct

README.md CHANGED
@@ -10,7 +10,7 @@ pinned: false
 # GPT3dev OpenAI-Compatible API
 **more detailed documentation is hosted on [DeepWiki](https://deepwiki.com/krll-corp/gpt3dev-api)**
 
-A production-ready FastAPI server that mirrors the OpenAI REST API surface while proxying requests to Hugging Face causal language models. The service implements the `/v1/completions`, `/v1/models`, and `/v1/embeddings` endpoints with full support for streaming Server-Sent Events (SSE) and OpenAI-style usage accounting. A `/v1/chat/completions` stub is included but currently returns a structured 501 error because the available models are completion-only.
+A production-ready FastAPI server that mirrors the OpenAI REST API surface while proxying requests to Hugging Face causal language models. The service implements the `/v1/completions`, `/v1/chat/completions`, `/v1/models`, and `/v1/embeddings` endpoints with full support for streaming Server-Sent Events (SSE) and OpenAI-style usage accounting. Chat completions are available for instruct-tuned models such as `GPT4-dev-177M-1511-Instruct`.
 
 ## The API is hosted on HuggingFace Spaces:
 ```bash
@@ -112,7 +112,23 @@ curl http://localhost:7860/v1/completions \
 
 ### Chat Completions
 
-The `/v1/chat/completions` endpoint is currently disabled and returns a 501 Not Implemented error instructing clients to use `/v1/completions` instead. I don't have any chat-tuned models yet, but I plan to enable this endpoint later with OpenAI Harmony-tuned models.
+The `/v1/chat/completions` endpoint is available for instruct-tuned models. Currently supported instruct models:
+
+- `GPT4-dev-177M-1511-Instruct` - instruction-tuned GPT-4-style model fine-tuned on HuggingFaceH4/no_robots
+
+```bash
+curl http://localhost:7860/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "GPT4-dev-177M-1511-Instruct",
+    "messages": [
+      {"role": "user", "content": "Write a short welcome message for new contributors."}
+    ],
+    "max_tokens": 128
+  }'
+```
+
+Non-instruct models will return an error directing clients to `/v1/completions` instead.
 
 ### Embeddings
 
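Since the server mirrors the OpenAI REST surface, the new endpoint should also work through the official `openai` Python client by pointing `base_url` at the server. A minimal sketch, assuming a local server on port 7860 that accepts a placeholder API key:

```python
# Sketch: calling the new chat endpoint via the openai client (>= 1.0).
# Assumes the server runs on localhost:7860 and ignores the API key.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:7860/v1", api_key="not-needed")

response = client.chat.completions.create(
    model="GPT4-dev-177M-1511-Instruct",
    messages=[
        {"role": "user", "content": "Write a short welcome message for new contributors."}
    ],
    max_tokens=128,
)
print(response.choices[0].message.content)
```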
app/core/engine.py CHANGED
@@ -140,14 +140,14 @@ class _ModelHandle:
         logger.info("Loading tokenizer for %s", spec.hf_repo)
         tokenizer = AutoTokenizer.from_pretrained(
             spec.hf_repo,
-            use_auth_token=token,
+            token=token,
             trust_remote_code=True,
         )
         if tokenizer.pad_token_id is None:
             tokenizer.pad_token = tokenizer.eos_token
         logger.info("Tokenizer ready in %.2fs", time.perf_counter() - t0)
         model_kwargs = {
-            "use_auth_token": token,
+            "token": token,
             "trust_remote_code": True,
         }
         # Resolve preferred device early so we can adjust dtype if needed
@@ -183,11 +183,30 @@ class _ModelHandle:
             device_pref,
             " (device_map=auto)" if device_map else "",
         )
-        model = AutoModelForCausalLM.from_pretrained(
-            spec.hf_repo,
-            device_map=device_map,
-            **model_kwargs,
-        )
+        # Patch _load_pretrained_model to fix tie_weights incompatibility
+        # with newer transformers that pass the missing_keys keyword argument
+        from transformers import modeling_utils
+        _orig_load_pretrained_func = modeling_utils.PreTrainedModel._load_pretrained_model.__func__
+
+        def _patched_load_pretrained_func(cls, model, *args, **kwargs):
+            # Patch model.tie_weights to accept and ignore unexpected kwargs
+            orig_tie_weights = model.tie_weights
+            def _compat_tie_weights(*tw_args, **tw_kwargs):
+                tw_kwargs.pop("missing_keys", None)
+                tw_kwargs.pop("recompute_mapping", None)
+                return orig_tie_weights(*tw_args, **tw_kwargs)
+            model.tie_weights = _compat_tie_weights
+            return _orig_load_pretrained_func(cls, model, *args, **kwargs)
+
+        modeling_utils.PreTrainedModel._load_pretrained_model = classmethod(_patched_load_pretrained_func)
+        try:
+            model = AutoModelForCausalLM.from_pretrained(
+                spec.hf_repo,
+                device_map=device_map,
+                **model_kwargs,
+            )
+        finally:
+            modeling_utils.PreTrainedModel._load_pretrained_model = classmethod(_orig_load_pretrained_func)
         logger.info("Model ready in %.2fs", time.perf_counter() - t1)
         if device_map is None:
             model = model.to(device_pref)
@@ -212,11 +231,12 @@ class _ModelHandle:
             kwargs.pop("cache_position", None)
             kwargs.pop("encoder_attention_mask", None)
             kwargs.pop("attention_mask", None)
+            kwargs.pop("head_mask", None)
             return _orig_forward(*args, **kwargs)
 
         model.forward = MethodType(_forward_compat, model)
         # Also patch submodules whose forward signatures include
-        # encoder_attention_mask to avoid duplicate passing (positional+kw)
+        # encoder_attention_mask or head_mask to avoid duplicate passing (positional+kw)
         for _, module in model.named_modules():
             fwd = getattr(module, "forward", None)
             if not callable(fwd):
@@ -225,7 +245,12 @@ class _ModelHandle:
                 sig = inspect.signature(fwd)
             except Exception:
                 continue
-            if "encoder_attention_mask" not in sig.parameters:
+            # Patch modules that have problematic parameters
+            needs_patch = any(
+                p in sig.parameters
+                for p in ("encoder_attention_mask", "head_mask")
+            )
+            if not needs_patch:
                 continue
             orig_fwd = fwd
 
@@ -234,6 +259,7 @@ class _ModelHandle:
                 kwargs.pop("encoder_attention_mask", None)
                 kwargs.pop("attention_mask", None)
                 kwargs.pop("cache_position", None)
+                kwargs.pop("head_mask", None)
                 return orig(*args, **kwargs)
 
             return _sub_forward_compat
@@ -292,6 +318,27 @@ def _apply_stop_sequences(text: str, stop_sequences: Sequence[str]) -> tuple[str, str]:
     return text, "length"
 
 
+def apply_chat_template(
+    model_name: str,
+    messages: List[dict],
+    add_generation_prompt: bool = True,
+) -> str:
+    """Apply the tokenizer's chat template to format messages for instruct models."""
+    handle = _get_handle(model_name)
+    tokenizer = handle.tokenizer
+    if hasattr(tokenizer, "apply_chat_template"):
+        return tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=add_generation_prompt,
+        )
+    # Fallback for tokenizers without chat_template
+    from .prompting import render_chat_prompt
+    from ..schemas.chat import ChatMessage
+    chat_messages = [ChatMessage(role=m["role"], content=m["content"]) for m in messages]
+    return render_chat_prompt(chat_messages)
+
+
 def _prepare_inputs(
     handle: _ModelHandle,
     prompt: str,
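The forward patches in this file follow one compatibility pattern: wrap the bound `forward`, drop keyword arguments that older model code cannot accept, and rebind with `MethodType`. A standalone sketch of the same idea, with illustrative names not taken from the repo:

```python
# Sketch of the kwargs-stripping pattern used in engine.py.
# `unsupported` lists kwargs newer transformers versions may pass
# that an older custom model's forward() does not declare.
from types import MethodType

def strip_unsupported_kwargs(model, unsupported=("cache_position", "head_mask")):
    orig_forward = model.forward  # capture the original bound method

    def _forward_compat(self, *args, **kwargs):
        for key in unsupported:
            kwargs.pop(key, None)  # silently drop what forward() can't take
        return orig_forward(*args, **kwargs)

    # Rebind so the wrapper receives the model as `self`.
    model.forward = MethodType(_forward_compat, model)
    return model
```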
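The new `apply_chat_template` helper delegates to the tokenizer when a chat template is available. Outside the server, the same call looks like this; a hedged sketch using the plain Hugging Face API (the repo name is the one added in this commit, but any chat-templated tokenizer works):

```python
# Sketch: the tokenizer call that engine.apply_chat_template wraps.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "k050506koch/GPT4-dev-177M-1511-Instruct",
    trust_remote_code=True,
)
messages = [{"role": "user", "content": "Hello!"}]
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,              # return a string, not token ids
    add_generation_prompt=True,  # append the assistant turn header
)
print(prompt)  # the formatted prompt that gets fed to generation
```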
app/core/model_registry.py CHANGED
@@ -55,9 +55,29 @@ class ModelSpec:
     device: Optional[str] = None
     max_context_tokens: Optional[int] = None
     metadata: Optional[ModelMetadata] = None
+    is_instruct: bool = False
 
 
 _DEFAULT_MODELS: List[ModelSpec] = [
+    ModelSpec(
+        name="GPT4-dev-177M-1511-Instruct",
+        hf_repo="k050506koch/GPT4-dev-177M-1511-Instruct",
+        dtype="float16",
+        device="auto",
+        max_context_tokens=512,
+        is_instruct=True,
+        metadata=ModelMetadata(
+            description="Instruction-tuned GPT-4-style model fine-tuned on the HuggingFaceH4/no_robots conversational dataset.",
+            parameter_count="177M",
+            training_datasets="HuggingFaceH4/no_robots",
+            training_steps="1,200 SFT steps · AdamW optimizer · cosine LR schedule · assistant-only loss masking",
+            evaluation="25.75% MMLU, 34.20% HellaSwag (author reported)",
+            notes="First instruct model. Uses Harmony-style chat formatting with apply_chat_template. Requires trust_remote_code.",
+            sources=(
+                "https://huggingface.co/k050506koch/GPT4-dev-177M-1511-Instruct",
+            ),
+        ),
+    ),
     ModelSpec(
         name="GPT4-dev-177M-1511",
         hf_repo="k050506koch/GPT4-dev-177M-1511",
@@ -230,6 +250,7 @@ def _load_registry_from_file(path: Path) -> Iterable[ModelSpec]:
                 device=entry.get("device"),
                 max_context_tokens=entry.get("max_context_tokens"),
                 metadata=metadata,
+                is_instruct=entry.get("is_instruct", False),
             )
         )
     return specs
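The `is_instruct` flag is what the chat router checks before generating, and file-based registry entries default to `False` when the key is absent. A small sketch of the gate, based on the specs defined in this commit:

```python
# Sketch: how the new flag gates /v1/chat/completions (see app/routers/chat.py).
from app.core.model_registry import get_model_spec

assert get_model_spec("GPT4-dev-177M-1511-Instruct").is_instruct  # chat allowed
assert not get_model_spec("GPT4-dev-177M-1511").is_instruct       # completions only
```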
app/routers/chat.py CHANGED
@@ -1,19 +1,165 @@
 """Chat completions endpoint."""
 from __future__ import annotations
 
+import asyncio
+import json
+import time
+import uuid
+from typing import Generator, List
+
 from fastapi import APIRouter
+from fastapi.responses import StreamingResponse
 
-from ..core.errors import feature_not_available
-from ..schemas.chat import ChatCompletionRequest, ChatCompletionResponse
+from ..core import engine
+from ..core.errors import model_not_found, openai_http_error
+from ..core.model_registry import get_model_spec
+from ..schemas.chat import (
+    ChatCompletionChoice,
+    ChatCompletionChunk,
+    ChatCompletionChunkChoice,
+    ChatCompletionChunkChoiceDelta,
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatMessage,
+)
+from ..schemas.common import UsageInfo
 
 router = APIRouter(prefix="/v1", tags=["chat"])
 
 
 @router.post("/chat/completions", response_model=ChatCompletionResponse)
 async def create_chat_completion(payload: ChatCompletionRequest) -> ChatCompletionResponse:
-    """Return a structured error while chat completions are disabled."""
+    """Generate chat completions using instruct-tuned models."""
+    try:
+        spec = get_model_spec(payload.model)
+    except KeyError:
+        raise model_not_found(payload.model)
+
+    if not spec.is_instruct:
+        raise openai_http_error(
+            400,
+            f"Model '{payload.model}' is not an instruct model and cannot be used with chat completions. "
+            "Please use /v1/completions instead, or choose an instruct model like 'GPT4-dev-177M-1511-Instruct'.",
+            error_type="invalid_request_error",
+            param="model",
+        )
+
+    # Convert messages to dict format for apply_chat_template
+    messages_dict = [{"role": m.role, "content": m.content} for m in payload.messages]
+
+    # Apply chat template using tokenizer
+    prompt = engine.apply_chat_template(payload.model, messages_dict)
+
+    stop_sequences = payload.stop if isinstance(payload.stop, list) else (
+        [payload.stop] if payload.stop else []
+    )
+
+    if payload.stream:
+        return _streaming_chat_completion(payload, prompt, stop_sequences)
+
+    try:
+        result = await asyncio.to_thread(
+            engine.generate,
+            payload.model,
+            prompt,
+            temperature=payload.temperature,
+            top_p=payload.top_p,
+            max_tokens=payload.max_tokens,
+            stop=stop_sequences,
+            n=payload.n,
+        )
+    except Exception as exc:
+        raise openai_http_error(
+            500,
+            f"Generation error: {exc}",
+            error_type="server_error",
+            code="generation_error",
+        )
+
+    choices: List[ChatCompletionChoice] = []
+    total_completion_tokens = 0
+    for idx, item in enumerate(result.completions):
+        total_completion_tokens += item.tokens
+        choices.append(
+            ChatCompletionChoice(
+                index=idx,
+                message=ChatMessage(role="assistant", content=item.text.strip()),
+                finish_reason=item.finish_reason,
+            )
+        )
+
+    usage = UsageInfo(
+        prompt_tokens=result.prompt_tokens,
+        completion_tokens=total_completion_tokens,
+        total_tokens=result.prompt_tokens + total_completion_tokens,
+    )
+    return ChatCompletionResponse(model=payload.model, choices=choices, usage=usage)
+
+
+def _streaming_chat_completion(
+    payload: ChatCompletionRequest,
+    prompt: str,
+    stop_sequences: List[str],
+) -> StreamingResponse:
+    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
+
+    def event_stream() -> Generator[bytes, None, None]:
+        stream = engine.create_stream(
+            payload.model,
+            prompt,
+            temperature=payload.temperature,
+            top_p=payload.top_p,
+            max_tokens=payload.max_tokens,
+            stop=stop_sequences,
+        )
+
+        # Send initial role delta
+        initial_chunk = ChatCompletionChunk(
+            id=completion_id,
+            created=int(time.time()),
+            model=payload.model,
+            choices=[
+                ChatCompletionChunkChoice(
+                    index=0,
+                    delta=ChatCompletionChunkChoiceDelta(role="assistant"),
+                    finish_reason=None,
+                )
+            ],
+        )
+        yield f"data: {initial_chunk.model_dump_json()}\n\n".encode()
+
+        for token in stream.iter_tokens():
+            chunk = ChatCompletionChunk(
+                id=completion_id,
+                created=int(time.time()),
+                model=payload.model,
+                choices=[
+                    ChatCompletionChunkChoice(
+                        index=0,
+                        delta=ChatCompletionChunkChoiceDelta(content=token),
+                        finish_reason=None,
+                    )
+                ],
+            )
+            yield f"data: {chunk.model_dump_json()}\n\n".encode()
+
+        # Send final chunk with finish_reason
+        final_chunk = ChatCompletionChunk(
+            id=completion_id,
+            created=int(time.time()),
+            model=payload.model,
+            choices=[
+                ChatCompletionChunkChoice(
+                    index=0,
+                    delta=ChatCompletionChunkChoiceDelta(),
+                    finish_reason=stream.finish_reason,
+                )
+            ],
+        )
+        yield f"data: {final_chunk.model_dump_json()}\n\n".encode()
+        yield b"data: [DONE]\n\n"
 
-    raise feature_not_available(
-        "chat_completions",
-        "Chat completions are currently disabled; please use /v1/completions instead.",
+    return StreamingResponse(
+        event_stream(),
+        media_type="text/event-stream",
     )
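The streaming branch emits OpenAI-style SSE chunks (an initial role delta, per-token content deltas, then a final chunk carrying `finish_reason`) terminated by `data: [DONE]`. A minimal consumer sketch, assuming the `requests` library and a locally running server:

```python
# Sketch: reading the SSE stream produced by _streaming_chat_completion.
import json
import requests

with requests.post(
    "http://localhost:7860/v1/chat/completions",
    json={
        "model": "GPT4-dev-177M-1511-Instruct",
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": True,
    },
    stream=True,
) as resp:
    for line in resp.iter_lines():
        # SSE frames look like: b"data: {...json...}"
        if not line or not line.startswith(b"data: "):
            continue
        data = line[len(b"data: "):]
        if data == b"[DONE]":
            break
        delta = json.loads(data)["choices"][0]["delta"]
        print(delta.get("content", ""), end="", flush=True)
```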
tests/test_openai_compat.py CHANGED
@@ -220,7 +220,17 @@ def test_completions_handles_prompt_list(monkeypatch: pytest.MonkeyPatch) -> None:
     assert body["usage"]["prompt_tokens"] == len("Hello") + len("World")
 
 
-def test_chat_disabled() -> None:
+def test_chat_rejects_non_instruct_model(monkeypatch: pytest.MonkeyPatch) -> None:
+    """Chat completions should reject non-instruct models with a 400 error."""
+    from app.core import model_registry
+
+    # Register a non-instruct model
+    monkeypatch.setattr(
+        model_registry,
+        "_registry",
+        {"GPT3-dev": ModelSpec(name="GPT3-dev", hf_repo="k050506koch/GPT3-dev", is_instruct=False)},
+    )
+
     payload = ChatCompletionRequest.model_validate({
         "model": "GPT3-dev",
         "messages": [
@@ -230,8 +240,8 @@ def test_chat_disabled() -> None:
 
     with pytest.raises(HTTPException) as exc:
         asyncio.run(chat.create_chat_completion(payload))
-    assert exc.value.status_code == 501
-    assert exc.value.detail["code"] == "chat_completions_not_available"
+    assert exc.value.status_code == 400
+    assert "not an instruct model" in exc.value.detail["message"]
 
 
 def test_embeddings_not_implemented() -> None:
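The new test covers the rejection path; a happy-path companion could monkeypatch `engine.apply_chat_template` and `engine.generate` so no model is loaded. A hedged sketch (the fake result only carries the attributes chat.py reads; names beyond those in this diff are illustrative):

```python
# Sketch: happy-path chat test with a stubbed engine.
# Assumes _registry is a plain dict, as the test above implies.
from types import SimpleNamespace

def test_chat_instruct_model(monkeypatch: pytest.MonkeyPatch) -> None:
    from app.core import engine, model_registry

    monkeypatch.setattr(
        model_registry,
        "_registry",
        {"Instruct": ModelSpec(name="Instruct", hf_repo="user/repo", is_instruct=True)},
    )
    monkeypatch.setattr(engine, "apply_chat_template", lambda model, msgs: "prompt")
    fake = SimpleNamespace(
        prompt_tokens=3,
        completions=[SimpleNamespace(text="hi", tokens=1, finish_reason="stop")],
    )
    monkeypatch.setattr(engine, "generate", lambda *args, **kwargs: fake)

    payload = ChatCompletionRequest.model_validate({
        "model": "Instruct",
        "messages": [{"role": "user", "content": "Hello"}],
    })
    response = asyncio.run(chat.create_chat_completion(payload))
    assert response.choices[0].message.content == "hi"
    assert response.usage.total_tokens == 4
```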