Trouter-Library committed
Commit 7d95305 · verified · 1 Parent(s): a2975bd

Create server.py

Files changed (1):
  1. server.py +419 -0
server.py ADDED
@@ -0,0 +1,419 @@
"""
Helion-V1.5 Production API Server
FastAPI server with OpenAI-compatible endpoints, streaming, and monitoring
"""

import os
import time
import logging
from typing import List, Dict, Optional, AsyncIterator
from contextlib import asynccontextmanager

import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Global model instance
MODEL = None
TOKENIZER = None
SAFEGUARDS = None

class Message(BaseModel):
    """Chat message."""
    role: str = Field(..., description="Message role (system/user/assistant)")
    content: str = Field(..., description="Message content")


class ChatCompletionRequest(BaseModel):
    """OpenAI-compatible chat completion request."""
    model: str = Field(default="DeepXR/Helion-V1.5")
    messages: List[Message]
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    top_p: float = Field(default=0.9, ge=0.0, le=1.0)
    max_tokens: int = Field(default=512, ge=1, le=4096)
    stream: bool = Field(default=False)
    n: int = Field(default=1, ge=1, le=1)
    stop: Optional[List[str]] = None
    presence_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
    frequency_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)


class ChatCompletionResponse(BaseModel):
    """OpenAI-compatible chat completion response."""
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[Dict]
    usage: Dict[str, int]


class CompletionRequest(BaseModel):
    """Text completion request."""
    prompt: str
    max_tokens: int = Field(default=512, ge=1, le=4096)
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    top_p: float = Field(default=0.9, ge=0.0, le=1.0)
    stream: bool = Field(default=False)

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context manager for model loading."""
    global MODEL, TOKENIZER, SAFEGUARDS

    logger.info("Loading Helion-V1.5...")

    from transformers import AutoTokenizer, AutoModelForCausalLM
    from safeguards_v15 import HelionSafeguardSystem, create_safeguard_config

    model_name = os.getenv("MODEL_NAME", "DeepXR/Helion-V1.5")

    TOKENIZER = AutoTokenizer.from_pretrained(model_name)
    if TOKENIZER.pad_token_id is None:
        # Many causal-LM tokenizers ship without a pad token; fall back to
        # EOS so generate() has a valid pad_token_id
        TOKENIZER.pad_token_id = TOKENIZER.eos_token_id

    MODEL = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )
    MODEL.eval()

    # Initialize safeguards
    safeguard_mode = os.getenv("SAFEGUARD_MODE", "moderate")
    config = create_safeguard_config(mode=safeguard_mode)
    SAFEGUARDS = HelionSafeguardSystem(config)

    logger.info("Model loaded successfully")

    yield

    logger.info("Shutting down...")
    del MODEL
    del TOKENIZER
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Create FastAPI app
app = FastAPI(
    title="Helion-V1.5 API",
    description="OpenAI-compatible API for Helion-V1.5",
    version="1.5.0",
    lifespan=lifespan
)

# CORS middleware
# NOTE: wide-open origins are convenient for testing; restrict
# allow_origins before exposing this server publicly
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Request tracking middleware
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log all requests."""
    start_time = time.time()
    response = await call_next(request)
    duration = time.time() - start_time

    logger.info(
        f"{request.method} {request.url.path} "
        f"completed in {duration:.2f}s with status {response.status_code}"
    )

    return response

def generate_response(
    messages: List[Dict[str, str]],
    max_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    use_safeguards: bool = True
) -> Dict:
    """Generate a response from messages (synchronous; blocks the event loop)."""

    user_msg = messages[-1]["content"]

    if use_safeguards:
        # Check input with safeguards
        context = " ".join([m["content"] for m in messages[:-1]])

        allowed, response = SAFEGUARDS.filter_message(user_msg, context)

        if not allowed:
            return {
                "text": response,
                "blocked": True,
                "finish_reason": "content_filter"
            }

    # Apply chat template
    input_ids = TOKENIZER.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(MODEL.device)

    # Generate; fall back to greedy decoding when temperature is 0,
    # since sampling with temperature=0 raises an error
    with torch.no_grad():
        output = MODEL.generate(
            input_ids,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=temperature > 0,
            pad_token_id=TOKENIZER.pad_token_id,
            eos_token_id=TOKENIZER.eos_token_id
        )

    # Decode only the newly generated tokens
    response_text = TOKENIZER.decode(
        output[0][input_ids.shape[1]:],
        skip_special_tokens=True
    )

    # Check output with safeguards
    if use_safeguards:
        output_safe, reason = SAFEGUARDS.check_output(response_text, user_msg)
        if not output_safe:
            return {
                "text": SAFEGUARDS.get_refusal_message("default"),
                "blocked": True,
                "finish_reason": "content_filter"
            }

    return {
        "text": response_text.strip(),
        "blocked": False,
        "finish_reason": "stop",
        "prompt_tokens": input_ids.shape[1],
        "completion_tokens": output.shape[1] - input_ids.shape[1],
        "total_tokens": output.shape[1]
    }

async def stream_response(
    messages: List[Dict[str, str]],
    max_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9
) -> AsyncIterator[str]:
    """Stream response tokens as OpenAI-style server-sent events.

    NOTE: safeguard filtering is applied only in the non-streaming path.
    """
    import json

    # Apply chat template
    input_ids = TOKENIZER.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(MODEL.device)

    # Stream generation
    from transformers import TextIteratorStreamer
    from threading import Thread

    streamer = TextIteratorStreamer(
        TOKENIZER,
        skip_prompt=True,
        skip_special_tokens=True
    )

    generation_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=temperature > 0,
        streamer=streamer,
        pad_token_id=TOKENIZER.pad_token_id,
        eos_token_id=TOKENIZER.eos_token_id
    )

    # Run generation in a background thread so tokens can be yielded
    # as they arrive
    thread = Thread(target=MODEL.generate, kwargs=generation_kwargs)
    thread.start()

    # Stream tokens
    for text in streamer:
        chunk = {
            "id": f"chatcmpl-{int(time.time())}",
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": "DeepXR/Helion-V1.5",
            "choices": [{
                "index": 0,
                "delta": {"content": text},
                "finish_reason": None
            }]
        }
        yield f"data: {json.dumps(chunk)}\n\n"

    # Final chunk
    final_chunk = {
        "id": f"chatcmpl-{int(time.time())}",
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": "DeepXR/Helion-V1.5",
        "choices": [{
            "index": 0,
            "delta": {},
            "finish_reason": "stop"
        }]
    }
    yield f"data: {json.dumps(final_chunk)}\n\n"
    yield "data: [DONE]\n\n"

@app.get("/")
async def root():
    """Root endpoint."""
    return {
        "name": "Helion-V1.5 API",
        "version": "1.5.0",
        "status": "online",
        "model": "DeepXR/Helion-V1.5"
    }


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {
        "status": "healthy",
        "model_loaded": MODEL is not None,
        "device": str(MODEL.device) if MODEL is not None else None,
        "safeguards_enabled": SAFEGUARDS is not None
    }


@app.get("/v1/models")
async def list_models():
    """List available models."""
    return {
        "object": "list",
        "data": [{
            "id": "DeepXR/Helion-V1.5",
            "object": "model",
            "created": int(time.time()),
            "owned_by": "deepxr"
        }]
    }

@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """OpenAI-compatible chat completions endpoint."""

    if MODEL is None or TOKENIZER is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Convert messages
    messages = [{"role": m.role, "content": m.content} for m in request.messages]

    # Streaming response
    if request.stream:
        return StreamingResponse(
            stream_response(
                messages,
                max_tokens=request.max_tokens,
                temperature=request.temperature,
                top_p=request.top_p
            ),
            media_type="text/event-stream"
        )

    # Non-streaming response
    result = generate_response(
        messages,
        max_tokens=request.max_tokens,
        temperature=request.temperature,
        top_p=request.top_p
    )

    response = ChatCompletionResponse(
        id=f"chatcmpl-{int(time.time())}",
        created=int(time.time()),
        model=request.model,
        choices=[{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": result["text"]
            },
            "finish_reason": result["finish_reason"]
        }],
        usage={
            "prompt_tokens": result.get("prompt_tokens", 0),
            "completion_tokens": result.get("completion_tokens", 0),
            "total_tokens": result.get("total_tokens", 0)
        }
    )

    return response

@app.post("/v1/completions")
async def completions(request: CompletionRequest):
    """Text completion endpoint."""

    if MODEL is None or TOKENIZER is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    messages = [{"role": "user", "content": request.prompt}]

    result = generate_response(
        messages,
        max_tokens=request.max_tokens,
        temperature=request.temperature,
        top_p=request.top_p
    )

    return {
        "id": f"cmpl-{int(time.time())}",
        "object": "text_completion",
        "created": int(time.time()),
        "model": "DeepXR/Helion-V1.5",
        "choices": [{
            "text": result["text"],
            "index": 0,
            "finish_reason": result["finish_reason"]
        }],
        "usage": {
            "prompt_tokens": result.get("prompt_tokens", 0),
            "completion_tokens": result.get("completion_tokens", 0),
            "total_tokens": result.get("total_tokens", 0)
        }
    }

def main():
    """Run the server."""
    import argparse

    parser = argparse.ArgumentParser(description="Helion-V1.5 API Server")
    parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
    parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
    parser.add_argument("--reload", action="store_true", help="Enable auto-reload")

    args = parser.parse_args()

    uvicorn.run(
        "server:app",
        host=args.host,
        port=args.port,
        reload=args.reload,
        log_level="info"
    )


if __name__ == "__main__":
    main()
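
Once the server is up (e.g. `python server.py --port 8000`), the chat endpoint can be exercised with any HTTP client. Below is a minimal client sketch using `requests`; the localhost base URL and the prompt are illustrative assumptions, not part of the commit.

```python
import requests

BASE_URL = "http://localhost:8000"  # assumed local deployment

payload = {
    "model": "DeepXR/Helion-V1.5",
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Explain what a health-check endpoint is for."}
    ],
    "temperature": 0.7,
    "max_tokens": 256
}

# Non-streaming request: the server replies with a single
# OpenAI-shaped chat.completion object
resp = requests.post(f"{BASE_URL}/v1/chat/completions", json=payload, timeout=120)
resp.raise_for_status()

data = resp.json()
print(data["choices"][0]["message"]["content"])
print(data["usage"])  # prompt_tokens / completion_tokens / total_tokens
```

Because the request and response shapes mirror OpenAI's, the official `openai` Python client should also work when pointed at the server with `base_url="http://localhost:8000/v1"` and a dummy API key.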
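
For `stream=True`, the endpoint emits OpenAI-style server-sent events and terminates with a `data: [DONE]` sentinel, as implemented in `stream_response`. A sketch of consuming that stream, under the same localhost assumption:

```python
import json
import requests

payload = {
    "model": "DeepXR/Helion-V1.5",
    "messages": [{"role": "user", "content": "Write a haiku about servers."}],
    "stream": True
}

with requests.post(
    "http://localhost:8000/v1/chat/completions",
    json=payload,
    stream=True,
    timeout=120,
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue  # skip the blank separator lines between SSE events
        data = line[len("data: "):]
        if data == "[DONE]":
            break  # sentinel sent after the final chunk
        chunk = json.loads(data)
        delta = chunk["choices"][0]["delta"]
        print(delta.get("content", ""), end="", flush=True)
print()
```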