"""
OpenELM OpenAI & Anthropic API Compatible Wrapper - v5
Minimal lazy-loading architecture for instant startup.
Heavy imports (torch, transformers) are deferred to a background thread.
"""

import uuid
import os
import sys
import time
import threading
from contextlib import asynccontextmanager
from typing import AsyncIterator, List, Optional, Any

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field


# Global state for lazy loading
# This allows the server to respond immediately while model loads in background
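# The loader thread writes these keys as it progresses; request handlers only read
# them, checking "status" before touching "model" or "tokenizer".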
global_state = {
    "status": "INITIALIZING",  # INITIALIZING -> LOADING -> READY -> ERROR
    "model": None,
    "tokenizer": None,
    "error": None
}


def model_loader_thread():
    """Load model in background thread to avoid blocking startup."""
    global global_state
    
    try:
        # Import heavy libraries INSIDE the thread so module import stays lightweight
        import torch
        from transformers import AutoTokenizer, AutoModelForCausalLM
        
        global_state["status"] = "LOADING"
        
        model_id = "apple/OpenELM-450M-Instruct"
        
        print("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=True
        )
        
        # Set special tokens
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        if tokenizer.bos_token is None:
            tokenizer.bos_token = "<s>"
        if tokenizer.eos_token is None:
            tokenizer.eos_token = "</s>"
        
        global_state["tokenizer"] = tokenizer
        print("Tokenizer loaded")
        
        print("Loading model...")
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float32,
            use_safetensors=True,
            trust_remote_code=True
        )
        
        model.eval()
        global_state["model"] = model
        global_state["status"] = "READY"
        print(f"Model loaded successfully! Device: {next(model.parameters()).device}")
        
    except Exception as e:
        global_state["error"] = str(e)
        global_state["status"] = "ERROR"
        print(f"Error loading model: {e}")


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator:
    """Application lifespan: Start background loader, then yield."""
    global global_state
    
    print("=" * 60)
    print("OpenELM API v5 - Starting with background model loader")
    print("=" * 60)
    print("Server will respond immediately. Model loads in background.")
    print("Endpoints:")
    print("  POST /v1/chat/completions - OpenAI format")
    print("  POST /v1/messages - Anthropic format")
    print("  GET /health - Check model status")
    print("=" * 60)
    
    # Start background thread to load model
    loader_thread = threading.Thread(target=model_loader_thread, daemon=True)
    loader_thread.start()
    
    yield
    
    # Cleanup on shutdown
    if global_state["model"] is not None:
        del global_state["model"]
    if global_state["tokenizer"] is not None:
        del global_state["tokenizer"]
    if "torch" in sys.modules:
        import torch
        torch.cuda.empty_cache() if torch.cuda.is_available() else None


# Create FastAPI app
# Note: No heavy imports at module level - only fastapi and pydantic
app = FastAPI(
    title="OpenELM OpenAI API",
    description="OpenAI and Anthropic API compatible wrapper for OpenELM models",
    version="5.0.0",
    lifespan=lifespan
)

# Add CORS
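# allow_origins=["*"] keeps the demo easy to call from anywhere; tighten this for
# any deployment that is not purely local.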
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ==================== Pydantic Models ====================

class MessageContent(BaseModel):
    type: str = "text"
    text: str


class Message(BaseModel):
    role: str
    content: str | List[MessageContent]
    name: Optional[str] = None


class Usage(BaseModel):
    input_tokens: int = 0
    output_tokens: int = 0
    total_tokens: int = 0


class ContentBlock(BaseModel):
    type: str = "text"
    text: str


class MessageResponse(BaseModel):
    id: str
    type: str = "message"
    role: str = "assistant"
    content: List[ContentBlock]
    model: str
    stop_reason: Optional[str] = None
    stop_sequence: Optional[str] = None
    usage: Usage


class MessageCreateParams(BaseModel):
    model: str = "openelm-450m-instruct"
    messages: List[Message]
    system: Optional[str] = None
    max_tokens: int = Field(default=1024, ge=1, le=4096)
    temperature: Optional[float] = Field(default=None, ge=0.0, le=1.0)
    top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)
    stream: Optional[bool] = False
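    # NOTE: `stream` is accepted for client compatibility (here and in
    # ChatCompletionRequest below) but ignored; responses are always returned whole.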


class ChatMessage(BaseModel):
    role: str
    content: str
    name: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    model: str = "openelm-450m-instruct"
    messages: List[ChatMessage]
    temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)
    top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)
    max_tokens: Optional[int] = Field(default=None, ge=1, le=4096)
    stream: Optional[bool] = False


class ChatCompletionChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Optional[str] = None


class ChatCompletionUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatCompletionResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[ChatCompletionChoice]
    usage: ChatCompletionUsage


class OpenAIModelInfo(BaseModel):
    id: str
    object: str = "model"
    created: int = 0
    owned_by: str = "openelm"
    permission: List[Any] = []


class OpenAIModelListResponse(BaseModel):
    object: str = "list"
    data: List[OpenAIModelInfo]


# ==================== Helper Functions ====================

def format_prompt_for_openelm(messages: List[Message], system: Optional[str] = None) -> str:
    """Format messages into a prompt suitable for OpenELM."""
    prompt_parts = []
    
    if system:
        prompt_parts.append(f"[System: {system}]")
    
    for msg in messages:
        role = msg.role.lower()
        content = msg.content
        
        if isinstance(content, list):
            text_parts = [b.text for b in content if hasattr(b, 'text')]
            content = ''.join(text_parts)
        elif not isinstance(content, str):
            content = str(content)
        
        if role == "user":
            prompt_parts.append(f"User: {content}")
        elif role == "assistant":
            prompt_parts.append(f"Assistant: {content}")
        else:
            prompt_parts.append(f"{role}: {content}")
    
    prompt_parts.append("Assistant:")
    return "\n\n".join(prompt_parts)


def count_tokens(text: str, tokenizer) -> int:
    """Count tokens using the tokenizer."""
    try:
        return len(tokenizer.encode(text))
    except Exception:
        # Rough fallback: assume about four characters per token
        return max(1, len(text) // 4)


def truncate_prompt(prompt: str, max_tokens: int, tokenizer, system: Optional[str] = None) -> str:
    """Truncate prompt to fit within context window."""
    current_tokens = count_tokens(prompt, tokenizer)
    
    if current_tokens <= max_tokens:
        return prompt
    
    lines = prompt.split("\n\n")
    
    system_line = None
    if lines and lines[0].startswith("[System:"):
        system_line = lines[0]
        lines = lines[1:]
    
    # Re-add blocks from newest to oldest, keeping as many as still fit
    truncated_lines = []
    for line in reversed(lines):
        candidate = [line] + truncated_lines
        test_prompt = "\n\n".join([system_line] + candidate) if system_line else "\n\n".join(candidate)
        if count_tokens(test_prompt, tokenizer) > max_tokens:
            break
        truncated_lines = candidate
    
    if system_line:
        return "\n\n".join([system_line] + truncated_lines)
    return "\n\n".join(truncated_lines)


def extract_assistant_response(generated_text: str) -> str:
    """Extract the assistant's reply from the full generated text.

    The prompt always ends with "Assistant:", so the reply is whatever follows
    the last occurrence of that marker; if the marker is somehow absent, fall
    back to returning the raw text.
    """
    if "Assistant:" in generated_text:
        return generated_text.split("Assistant:")[-1].strip()
    return generated_text.strip()


# ==================== API Endpoints ====================

@app.get("/", tags=["Root"])
async def root():
    """Root endpoint with API information."""
    return {
        "name": "OpenELM OpenAI API v5",
        "version": "5.0.0",
        "status": global_state["status"],
        "model_loaded": global_state["status"] == "READY",
        "endpoints": {
            "chat": "POST /v1/chat/completions",
            "messages": "POST /v1/messages",
            "health": "GET /health"
        },
        "note": "Model loads in background for instant startup"
    }


@app.get("/health", tags=["Health"])
async def health_check():
    """Health check endpoint."""
    if global_state["status"] == "READY":
        return {"status": "healthy", "model_loaded": True}
    elif global_state["status"] == "ERROR":
        raise HTTPException(
            status_code=503,
            detail=f"Model failed to load: {global_state.get('error', 'Unknown error')}"
        )
    else:
        raise HTTPException(
            status_code=503,
            detail="Model is still loading. Please retry in a few moments."
        )


@app.get("/v1/models", response_model=OpenAIModelListResponse, tags=["Models"])
async def list_models():
    """List available models (OpenAI format)."""
    return OpenAIModelListResponse(
        data=[
            OpenAIModelInfo(
                id="openelm-450m-instruct",
                owned_by="apple",
                created=int(time.time())
            )
        ]
    )


@app.post("/v1/chat/completions", tags=["OpenAI"])
async def create_chat_completion(request: ChatCompletionRequest):
    """Create chat completion (OpenAI API format)."""
    if global_state["status"] != "READY":
        if global_state["status"] == "ERROR":
            raise HTTPException(status_code=503, detail="Model failed to load")
        raise HTTPException(status_code=503, detail="Model is still loading. Please retry.")
    
    model = global_state["model"]
    tokenizer = global_state["tokenizer"]
    
    try:
        system_message = None
        formatted_messages = []
        
        for msg in request.messages:
            if msg.role == "system" and system_message is None:
                system_message = msg.content
            else:
                formatted_messages.append(Message(role=msg.role, content=msg.content))
        
        prompt = format_prompt_for_openelm(formatted_messages, system_message)
        max_tokens = request.max_tokens or 1024
        prompt = truncate_prompt(prompt, 2048 - max_tokens, tokenizer, system_message)
        
        inputs = tokenizer(prompt, return_tensors="pt")
        input_tokens = len(inputs.input_ids[0])
        
        if hasattr(model, 'device'):
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
        
        gen_params = {"max_new_tokens": max_tokens}
        
        if request.temperature is not None:
            if request.temperature == 0:
                gen_params["do_sample"] = False
            else:
                gen_params["temperature"] = request.temperature
                gen_params["do_sample"] = True
        
        if request.top_p is not None:
            gen_params["top_p"] = request.top_p
        
        import torch
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                **gen_params,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
        
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        response_text = extract_assistant_response(generated_text)
        output_tokens = count_tokens(response_text, tokenizer)
        
        response_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
        timestamp = int(time.time())
        
        return ChatCompletionResponse(
            id=response_id,
            created=timestamp,
            model="openelm-450m-instruct",
            choices=[
                ChatCompletionChoice(
                    index=0,
                    message=ChatMessage(role="assistant", content=response_text),
                    finish_reason="stop"
                )
            ],
            usage=ChatCompletionUsage(
                prompt_tokens=input_tokens,
                completion_tokens=output_tokens,
                total_tokens=input_tokens + output_tokens
            )
        )
        
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")


@app.post("/v1/messages", response_model=MessageResponse, tags=["Messages"])
async def create_message(params: MessageCreateParams):
    """Create message (Anthropic API format)."""
    if global_state["status"] != "READY":
        if global_state["status"] == "ERROR":
            raise HTTPException(status_code=503, detail="Model failed to load")
        raise HTTPException(status_code=503, detail="Model is still loading. Please retry.")
    
    model = global_state["model"]
    tokenizer = global_state["tokenizer"]
    
    try:
        formatted_messages = []
        for msg in params.messages:
            content = msg.content
            if isinstance(content, list):
                content = ''.join(b.text for b in content if hasattr(b, 'text'))
            formatted_messages.append(Message(role=msg.role, content=content))
        
        prompt = format_prompt_for_openelm(formatted_messages, params.system)
        prompt = truncate_prompt(prompt, 2048 - params.max_tokens, tokenizer, params.system)
        
        inputs = tokenizer(prompt, return_tensors="pt")
        input_tokens = len(inputs.input_ids[0])
        
        if hasattr(model, 'device'):
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
        
        gen_params = {"max_new_tokens": params.max_tokens}
        
        if params.temperature is not None:
            if params.temperature == 0:
                gen_params["do_sample"] = False
            else:
                gen_params["temperature"] = params.temperature
                gen_params["do_sample"] = True
        
        if params.top_p is not None:
            gen_params["top_p"] = params.top_p
        
        import torch
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                **gen_params,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
        
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        response_text = extract_assistant_response(generated_text)
        output_tokens = count_tokens(response_text, tokenizer)
        
        return MessageResponse(
            id=f"msg_{uuid.uuid4().hex[:8]}",
            role="assistant",
            content=[ContentBlock(type="text", text=response_text)],
            model="openelm-450m-instruct",
            stop_reason="end_turn",
            usage=Usage(input_tokens=input_tokens, output_tokens=output_tokens)
        )
        
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")


# ==================== Main Entry Point ====================
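
# Example requests once the model has finished loading (GET /health returns 200);
# illustrative only, assuming the default port 7860 and this file saved as app.py:
#
#   curl http://localhost:7860/health
#
#   curl http://localhost:7860/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "openelm-450m-instruct", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 64}'
#
#   curl http://localhost:7860/v1/messages \
#     -H "Content-Type: application/json" \
#     -d '{"model": "openelm-450m-instruct", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 64}'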

if __name__ == "__main__":
    import uvicorn
    
    port = int(os.environ.get("PORT", 7860))
    
    print(f"\nStarting OpenELM API v5 on port {port}...")
    print("The server will respond immediately while the model loads in background.\n")
    
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=port,
        reload=False,
        workers=1
    )