import logging
import time

from fastapi import FastAPI, Request
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams
from vllm.utils import random_uuid

from chat_template import format_chat

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()


def model_fn(model_dir):
    # Note: "fp8" is not a valid `dtype` for vLLM (valid values are "auto",
    # "half", "float16", "bfloat16", "float", and "float32"); fp8 weight
    # quantization is requested via the separate `quantization` argument.
    model = LLM(
        model=model_dir,
        trust_remote_code=True,
        dtype="auto",
        gpu_memory_utilization=0.9,
    )
    return model
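

# A minimal sketch of an fp8-quantized load, assuming a vLLM release with fp8
# support and fp8-capable hardware (e.g. Hopper GPUs); `model_fn_fp8` is a
# hypothetical alternative loader, not part of the original handler.
def model_fn_fp8(model_dir):
    return LLM(
        model=model_dir,
        trust_remote_code=True,
        quantization="fp8",
        gpu_memory_utilization=0.9,
    )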


# Loaded once at startup and shared across requests.
model = None


@app.on_event("startup")
async def startup_event():
    global model
    model = model_fn("/opt/ml/model")
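

# @app.on_event is deprecated in recent FastAPI releases; a lifespan handler
# is the current equivalent (a sketch, assuming FastAPI >= 0.93 — adopting it
# also requires constructing the app as FastAPI(lifespan=lifespan)):
#
#     from contextlib import asynccontextmanager
#
#     @asynccontextmanager
#     async def lifespan(app: FastAPI):
#         global model
#         model = model_fn("/opt/ml/model")
#         yield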


@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    try:
        data = await request.json()

        messages = data.get("messages", [])
        formatted_prompt = format_chat(messages)

        # SamplingParams has no HF-style `do_sample` or `max_new_tokens`:
        # greedy decoding is temperature=0, and the generation budget is
        # `max_tokens`. `length_penalty` only affects beam search and is
        # omitted here.
        do_sample = data.get("do_sample", True)
        sampling_params = SamplingParams(
            temperature=data.get("temperature", 0.7) if do_sample else 0.0,
            top_p=data.get("top_p", 0.9),
            top_k=data.get("top_k", -1),
            max_tokens=data.get("max_new_tokens", 512),
            repetition_penalty=data.get("repetition_penalty", 1.0),
            stop_token_ids=data.get("stop_token_ids", None),
            skip_special_tokens=data.get("skip_special_tokens", True),
        )

        # Recent vLLM takes guided-decoding options as a GuidedDecodingParams
        # object rather than as attributes set directly on SamplingParams
        # (assumption: vLLM >= 0.6; this API is version-dependent).
        guided_params = data.get("guided_params", None)
        if guided_params:
            sampling_params.guided_decoding = GuidedDecodingParams(
                choice=guided_params.get("guided_choice"),
                json=guided_params.get("guided_json"),
                regex=guided_params.get("guided_regex"),
            )

        outputs = model.generate(formatted_prompt, sampling_params)
        completion = outputs[0].outputs[0]
        generated_text = completion.text

        response = {
            "id": f"chatcmpl-{random_uuid()}",
            "object": "chat.completion",
            # torch.cuda.current_timestamp() does not exist; use wall-clock
            # time for the OpenAI-style `created` field.
            "created": int(time.time()),
            "model": "qwen-72b",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": generated_text,
                },
                "finish_reason": completion.finish_reason,
            }],
            # Token counts come from vLLM's token IDs; len() on the strings
            # would count characters, not tokens.
            "usage": {
                "prompt_tokens": len(outputs[0].prompt_token_ids),
                "completion_tokens": len(completion.token_ids),
                "total_tokens": len(outputs[0].prompt_token_ids)
                + len(completion.token_ids),
            },
        }

        return response

    except Exception as e:
        logger.exception("Exception during prediction")
        return {"error": str(e), "details": repr(e)}


@app.get("/ping")
def ping():
    logger.info("Ping request received")
    return {"status": "healthy"}
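

# Minimal local entrypoint (a sketch; assumes uvicorn is installed in the
# serving image and that the container should listen on port 8080).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8080)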