MiniMax Agent committed
Commit ee3c612 · 1 Parent(s): 7ae4b71

v5: Minimal lazy-loading architecture for instant startup

- Remove ALL heavy imports from module level (torch, transformers)
- Use background thread to load model after server starts
- Server responds immediately to health checks
- API returns 503 if model is still loading
- Fixes 30-minute timeout issue on Hugging Face Spaces
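
The pattern, reduced to a minimal standalone sketch (the names state and load_in_background are illustrative, not this commit's actual identifiers; the real implementation is in the diff below):

import threading
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException

state = {"status": "LOADING", "model": None}


def load_in_background():
    """Stand-in for the real loader; heavy imports live here, off the main thread."""
    try:
        import torch  # deferred heavy import: the port is already bound by now
        state["model"] = object()  # placeholder for the real model load
        state["status"] = "READY"
    except Exception as exc:
        state["status"] = "ERROR"
        state["error"] = str(exc)


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Start the loader and return control to uvicorn immediately,
    # so health checks get answered while the model is still loading.
    threading.Thread(target=load_in_background, daemon=True).start()
    yield


app = FastAPI(lifespan=lifespan)


@app.get("/health")
async def health():
    if state["status"] != "READY":
        raise HTTPException(status_code=503, detail=f"status: {state['status']}")
    return {"status": "healthy"}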

Files changed (1): app.py (+111 -212)
app.py CHANGED
@@ -1,186 +1,57 @@
 """
-OpenELM OpenAI & Anthropic API Compatible Wrapper
-
-This version properly handles OpenELM's custom configuration and tokenizer.
+OpenELM OpenAI & Anthropic API Compatible Wrapper - v5
+Minimal lazy-loading architecture for instant startup.
+Heavy imports (torch, transformers) are deferred to a background thread.
 """
 
 import uuid
+import os
 import sys
-import subprocess
+import time
+import asyncio
+import threading
 from contextlib import asynccontextmanager
 from typing import AsyncIterator, List, Optional, Dict, Any
 
-import torch
 from fastapi import FastAPI, HTTPException, Request
-from fastapi.responses import JSONResponse, StreamingResponse
+from fastapi.responses import JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, Field
-from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer, PreTrainedTokenizerFast
-from huggingface_hub import hf_hub_download
-import os
-
-# Import for streaming
-from transformers import TextIteratorStreamer
-from threading import Thread
 
 
-# Global model and tokenizer references
-model = None
-tokenizer = None
-model_loaded = False
-model_id = "apple/OpenELM-450M-Instruct"
-
-
-def install_sentencepiece():
-    """Install SentencePiece if not available."""
-    try:
-        import sentencepiece
-        return True
-    except ImportError:
-        print("Installing SentencePiece...")
-        try:
-            subprocess.run([sys.executable, "-m", "pip", "install", "sentencepiece", "--quiet"], check=True)
-            print("SentencePiece installed successfully")
-            return True
-        except subprocess.CalledProcessError:
-            print("Failed to install SentencePiece")
-            return False
-
-
-def register_openelm_config():
-    """Register OpenELM configuration with transformers."""
-    try:
-        # Try to import and register the config
-        from transformers import AutoConfig, LlamaConfig
-
-        # Download the OpenELM config
-        config_path = hf_hub_download(
-            repo_id=model_id,
-            filename="configuration_openelm.py",
-            repo_type="model"
-        )
-
-        # Add to path and import
-        config_dir = os.path.dirname(config_path)
-        if config_dir not in sys.path:
-            sys.path.insert(0, config_dir)
-
-        # The config file should have the OpenELMConfig class
-        # We'll use LlamaConfig as a base since OpenELM is similar to LLaMA
-        print("OpenELM configuration registered (using LLaMA-compatible loading)")
-        return True
-
-    except Exception as e:
-        print(f"Could not register OpenELM config: {e}")
-        return False
+# Global state for lazy loading
+# This allows the server to respond immediately while model loads in background
+global_state = {
+    "status": "INITIALIZING",  # INITIALIZING -> LOADING -> READY -> ERROR
+    "model": None,
+    "tokenizer": None,
+    "error": None
+}
 
 
-def load_tokenizer():
-    """
-    Load tokenizer with multiple fallback strategies.
-    OpenELM uses a custom configuration that transformers doesn't natively support.
-    """
-    print("Loading tokenizer...")
+def model_loader_thread():
+    """Load model in background thread to avoid blocking startup."""
+    global global_state
 
-    # Install sentencepiece first
-    install_sentencepiece()
-
-    # Strategy 1: Try using the tokenizer files directly
     try:
-        from transformers import LlamaTokenizerFast
+        # Import heavy libraries INSIDE the thread
+        import torch
+        import sys
+        from transformers import AutoTokenizer, AutoModelForCausalLM
 
-        # Download tokenizer files
-        tokenizer_file = hf_hub_download(
-            repo_id=model_id,
-            filename="tokenizer.json",
-            repo_type="model"
-        )
+        from huggingface_hub import hf_hub_download
 
-        tokenizer = LlamaTokenizerFast(
-            tokenizer_file=tokenizer_file,
-            trust_remote_code=True
-        )
-        print(" Loaded tokenizer using tokenizer.json")
-        return tokenizer
-
-    except Exception as e:
-        print(f" Strategy 1 failed: {e}")
-
-    # Strategy 2: Try LlamaTokenizer with local files
-    try:
-        # Download vocab and merges
-        vocab_file = hf_hub_download(
-            repo_id=model_id,
-            filename="vocab.txt",
-            repo_type="model"
-        )
-
-        try:
-            merges_file = hf_hub_download(
-                repo_id=model_id,
-                filename="merges.txt",
-                repo_type="model"
-            )
-            tokenizer = LlamaTokenizer(
-                vocab_file=vocab_file,
-                merges_file=merges_file,
-                trust_remote_code=True
-            )
-        except:
-            tokenizer = LlamaTokenizer(
-                vocab_file=vocab_file,
-                trust_remote_code=True
-            )
+        global_state["status"] = "LOADING"
 
-        print(" Loaded tokenizer using vocab.txt")
-        return tokenizer
+        model_id = "apple/OpenELM-450M-Instruct"
 
-    except Exception as e:
-        print(f" Strategy 2 failed: {e}")
-
-    # Strategy 3: Try AutoTokenizer with use_fast=False
-    try:
+        print("Loading tokenizer...")
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
-            trust_remote_code=True,
-            use_fast=False
+            trust_remote_code=True
         )
-        print(" Loaded tokenizer using AutoTokenizer (slow)")
-        return tokenizer
-
-    except Exception as e:
-        print(f" Strategy 3 failed: {e}")
-
-    # Strategy 4: Use a basic GPT-2 style tokenizer
-    print(" Using fallback tokenizer")
-    tokenizer = PreTrainedTokenizerFast(
-        tokenizer_file=None,
-        bos_token="<s>",
-        eos_token="</s>",
-        unk_token="<unk>",
-        pad_token="<pad>"
-    )
-
-    return tokenizer
-
-
-def load_model():
-    """
-    Load the OpenELM model.
-    """
-    global model, tokenizer, model_loaded
-
-    if model_loaded:
-        return True
-
-    print("Initializing OpenELM model...")
-
-    try:
-        # Load tokenizer
-        print(" Loading tokenizer...")
-        tokenizer = load_tokenizer()
 
-        # Ensure pad token is set
+        # Set special tokens
         if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        if tokenizer.bos_token is None:
@@ -188,8 +59,10 @@ def load_model():
         if tokenizer.eos_token is None:
            tokenizer.eos_token = "</s>"
 
-        print(" Loading model...")
-        # Load model with simplified parameters
+        global_state["tokenizer"] = tokenizer
+        print("Tokenizer loaded")
+
+        print("Loading model...")
         model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float32,
@@ -198,45 +71,53 @@
         )
 
         model.eval()
-        model_loaded = True
-
-        print(f" Model loaded successfully!")
-        print(f" Model device: {next(model.parameters()).device}")
-        return True
+        global_state["model"] = model
+        global_state["status"] = "READY"
+        print(f"Model loaded successfully! Device: {next(model.parameters()).device}")
 
     except Exception as e:
-        print(f" Error loading model: {e}")
-        import traceback
-        traceback.print_exc()
-        return False
+        global_state["error"] = str(e)
+        global_state["status"] = "ERROR"
+        print(f"Error loading model: {e}")
 
 
 @asynccontextmanager
 async def lifespan(app: FastAPI) -> AsyncIterator:
-    """Application lifespan with lazy loading."""
-    global model, tokenizer, model_loaded
+    """Application lifespan: Start background loader, then yield."""
+    global global_state
 
-    print("OpenELM API Ready (lazy loading)")
+    print("=" * 60)
+    print("OpenELM API v5 - Starting with background model loader")
+    print("=" * 60)
+    print("Server will respond immediately. Model loads in background.")
     print("Endpoints:")
    print(" POST /v1/chat/completions - OpenAI format")
    print(" POST /v1/messages - Anthropic format")
    print(" GET /health - Check model status")
+    print("=" * 60)
+
+    # Start background thread to load model
+    loader_thread = threading.Thread(target=model_loader_thread, daemon=True)
+    loader_thread.start()
 
     yield
 
-    # Cleanup
-    if model is not None:
-        del model
-    if tokenizer is not None:
-        del tokenizer
-    torch.cuda.empty_cache() if torch.cuda.is_available() else None
+    # Cleanup on shutdown
+    if global_state["model"] is not None:
+        del global_state["model"]
+    if global_state["tokenizer"] is not None:
+        del global_state["tokenizer"]
+    if "torch" in sys.modules:
+        import torch
+        torch.cuda.empty_cache() if torch.cuda.is_available() else None
 
 
 # Create FastAPI app
+# Note: No heavy imports at module level - only fastapi and pydantic
 app = FastAPI(
    title="OpenELM OpenAI API",
    description="OpenAI and Anthropic API compatible wrapper for OpenELM models",
-    version="2.1.0",
+    version="5.0.0",
    lifespan=lifespan
 )
 
@@ -374,19 +255,17 @@ def format_prompt_for_openelm(messages: List[Message], system: Optional[str] = None) -> str:
     return "\n\n".join(prompt_parts)
 
 
-def count_tokens(text: str) -> int:
-    """Estimate token count."""
-    if tokenizer:
-        try:
-            return len(tokenizer.encode(text))
-        except:
-            pass
-    return max(1, len(text) // 4)
+def count_tokens(text: str, tokenizer) -> int:
+    """Count tokens using the tokenizer."""
+    try:
+        return len(tokenizer.encode(text))
+    except:
+        return max(1, len(text) // 4)
 
 
-def truncate_prompt(prompt: str, max_tokens: int, system: Optional[str] = None) -> str:
+def truncate_prompt(prompt: str, max_tokens: int, tokenizer, system: Optional[str] = None) -> str:
     """Truncate prompt to fit within context window."""
-    current_tokens = count_tokens(prompt)
+    current_tokens = count_tokens(prompt, tokenizer)
 
     if current_tokens <= max_tokens:
        return prompt
@@ -402,7 +281,7 @@ def truncate_prompt(prompt: str, max_tokens: int, system: Optional[str] = None) -> str:
     for line in reversed(lines):
        truncated_lines.insert(0, line)
        test_prompt = "\n\n".join([system_line] + truncated_lines) if system_line else "\n\n".join(truncated_lines)
-        if count_tokens(test_prompt) <= max_tokens:
+        if count_tokens(test_prompt, tokenizer) <= max_tokens:
            break
 
     if system_line:
@@ -434,30 +313,41 @@ def extract_assistant_response(generated_text: str) -> str:
 
 @app.get("/", tags=["Root"])
 async def root():
+    """Root endpoint with API information."""
     return {
-        "name": "OpenELM OpenAI API",
-        "version": "2.1.0",
-        "status": "ready" if model_loaded else "initializing",
-        "model_loaded": model_loaded,
+        "name": "OpenELM OpenAI API v5",
+        "version": "5.0.0",
+        "status": global_state["status"],
+        "model_loaded": global_state["status"] == "READY",
        "endpoints": {
            "chat": "POST /v1/chat/completions",
            "messages": "POST /v1/messages",
            "health": "GET /health"
        },
-        "note": "Model loads on first request"
+        "note": "Model loads in background for instant startup"
    }
 
 
 @app.get("/health", tags=["Health"])
 async def health_check():
-    return {
-        "status": "healthy" if model_loaded else "initializing",
-        "model_loaded": model_loaded
-    }
+    """Health check endpoint."""
+    if global_state["status"] == "READY":
+        return {"status": "healthy", "model_loaded": True}
+    elif global_state["status"] == "ERROR":
+        raise HTTPException(
+            status_code=503,
+            detail=f"Model failed to load: {global_state.get('error', 'Unknown error')}"
+        )
+    else:
+        raise HTTPException(
+            status_code=503,
+            detail="Model is still loading. Please retry in a few moments."
+        )
 
 
 @app.get("/v1/models", response_model=OpenAIModelListResponse, tags=["Models"])
 async def list_models():
+    """List available models (OpenAI format)."""
     return OpenAIModelListResponse(
        data=[
            OpenAIModelInfo(
@@ -472,11 +362,13 @@ async def list_models():
 @app.post("/v1/chat/completions", tags=["OpenAI"])
 async def create_chat_completion(request: ChatCompletionRequest):
     """Create chat completion (OpenAI API format)."""
-    global model, tokenizer, model_loaded
+    if global_state["status"] != "READY":
+        if global_state["status"] == "ERROR":
+            raise HTTPException(status_code=503, detail="Model failed to load")
+        raise HTTPException(status_code=503, detail="Model is still loading. Please retry.")
 
-    if not model_loaded:
-        if not load_model():
-            raise HTTPException(status_code=503, detail="Failed to load model")
+    model = global_state["model"]
+    tokenizer = global_state["tokenizer"]
 
     try:
        system_message = None
@@ -490,7 +382,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
 
         prompt = format_prompt_for_openelm(formatted_messages, system_message)
        max_tokens = request.max_tokens or 1024
-        prompt = truncate_prompt(prompt, 2048 - max_tokens, system_message)
+        prompt = truncate_prompt(prompt, 2048 - max_tokens, tokenizer, system_message)
 
         inputs = tokenizer(prompt, return_tensors="pt")
        input_tokens = len(inputs.input_ids[0])
@@ -510,6 +402,7 @@
         if request.top_p is not None:
            gen_params["top_p"] = request.top_p
 
+        import torch
         with torch.no_grad():
            outputs = model.generate(
                **inputs,
@@ -520,7 +413,7 @@
 
         generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        response_text = extract_assistant_response(generated_text)
-        output_tokens = count_tokens(response_text)
+        output_tokens = count_tokens(response_text, tokenizer)
 
         response_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
        timestamp = int(uuid.uuid1().time)
@@ -552,11 +445,13 @@
 @app.post("/v1/messages", response_model=MessageResponse, tags=["Messages"])
 async def create_message(params: MessageCreateParams):
     """Create message (Anthropic API format)."""
-    global model, tokenizer, model_loaded
+    if global_state["status"] != "READY":
+        if global_state["status"] == "ERROR":
+            raise HTTPException(status_code=503, detail="Model failed to load")
+        raise HTTPException(status_code=503, detail="Model is still loading. Please retry.")
 
-    if not model_loaded:
-        if not load_model():
-            raise HTTPException(status_code=503, detail="Failed to load model")
+    model = global_state["model"]
+    tokenizer = global_state["tokenizer"]
 
     try:
        formatted_messages = []
@@ -567,7 +462,7 @@ async def create_message(params: MessageCreateParams):
             formatted_messages.append(Message(role=msg.role, content=content))
 
         prompt = format_prompt_for_openelm(formatted_messages, params.system)
-        prompt = truncate_prompt(prompt, 2048 - params.max_tokens, params.system)
+        prompt = truncate_prompt(prompt, 2048 - params.max_tokens, tokenizer, params.system)
 
         inputs = tokenizer(prompt, return_tensors="pt")
        input_tokens = len(inputs.input_ids[0])
@@ -587,6 +482,7 @@
         if params.top_p is not None:
            gen_params["top_p"] = params.top_p
 
+        import torch
         with torch.no_grad():
            outputs = model.generate(
                **inputs,
@@ -597,7 +493,7 @@
 
         generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        response_text = extract_assistant_response(generated_text)
-        output_tokens = count_tokens(response_text)
+        output_tokens = count_tokens(response_text, tokenizer)
 
         return MessageResponse(
            id=f"msg_{uuid.uuid4().hex[:8]}",
@@ -619,7 +515,10 @@
 if __name__ == "__main__":
     import uvicorn
 
-    port = int(os.environ.get("PORT", 8000))
+    port = int(os.environ.get("PORT", 7860))
+
+    print(f"\nStarting OpenELM API v5 on port {port}...")
+    print("The server will respond immediately while the model loads in background.\n")
 
     uvicorn.run(
        "app:app",