File size: 4,531 Bytes
68ee47a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
from typing import Dict, List, Any
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from fastapi.responses import StreamingResponse
import uuid
import time
import json
from threading import Thread
class EndpointHandler:
    """OpenAI-compatible chat-completion handler around a causal LM.

    ``__call__`` accepts an OpenAI-style request dict (``messages``,
    ``model``, optional ``stream``) and returns either a complete
    ``chat.completion`` payload or a ``StreamingResponse`` of
    ``chat.completion.chunk`` Server-Sent Events.
    """

    def __init__(self, path: str = "openai/gpt-oss-20b"):
        """Load tokenizer and model from *path*; use CUDA when available.

        Args:
            path: Hugging Face model id or local checkpoint directory.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(path)
        self.model.eval()
        # Prefer GPU for generation; fall back to CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    @staticmethod
    def openai_id(prefix: str) -> str:
        """Return an OpenAI-style id, e.g. ``chatcmpl-<24 hex chars>``.

        BUG FIX: the original ``def openai_id(prefix)`` had neither ``self``
        nor ``@staticmethod``, so every ``self.openai_id(...)`` call raised
        ``TypeError`` (two positional args passed to a one-arg function).
        """
        return f"{prefix}-{uuid.uuid4().hex[:24]}"

    def format_non_stream(self, model: str, text: str, prompt_length: int,
                          completion_length: int, total_tokens: int) -> Dict[str, Any]:
        """Build a complete OpenAI ``chat.completion`` payload.

        Args:
            model: Model name echoed back to the client.
            text: Full decoded completion text.
            prompt_length: Number of prompt tokens.
            completion_length: Number of generated tokens.
            total_tokens: prompt_length + completion_length.
        """
        return {
            "id": self.openai_id("chatcmpl"),
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model,
            "choices": [{
                "index": 0,
                "message": {"role": "assistant", "content": text},
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": prompt_length,
                "completion_tokens": completion_length,
                "total_tokens": total_tokens
            }
        }

    def format_stream(self, model: str, token: str, usage,
                      finish_reason=None) -> bytes:
        """Encode one ``chat.completion.chunk`` as an SSE ``data:`` frame.

        Args:
            model: Model name echoed back to the client.
            token: Text delta for this chunk (may be "" for the final chunk).
            usage: Usage dict for the final chunk, else ``None``.
            finish_reason: ``"stop"`` on the terminal chunk; defaults to
                ``None`` so existing call sites are unaffected.
        """
        payload = {
            "id": self.openai_id("chatcmpl"),
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model,
            "choices": [{
                "index": 0,
                "delta": {
                    "content": token,
                    "function_call": None,
                    "refusal": None,
                    "role": None,
                    "tool_calls": None
                },
                "finish_reason": finish_reason,
                "logprobs": None
            }],
            "usage": usage
        }
        return f"data: {json.dumps(payload)}\n\n".encode('utf-8')

    def generate(self, messages, model: str,
                 max_new_tokens: int = 2048) -> Dict[str, Any]:
        """Run a full (non-streaming) generation and return an OpenAI payload.

        NOTE(review): ``messages`` is fed straight to the tokenizer, so it
        must already be a prompt string; OpenAI-style message lists would
        need ``tokenizer.apply_chat_template`` first — confirm with callers.

        Args:
            messages: Prompt input accepted by the tokenizer.
            model: Model name echoed back in the payload.
            max_new_tokens: Generation budget (previously hard-coded to 2048).
        """
        model_inputs = self.tokenizer(messages, return_tensors="pt").to(self.device)
        # Inference only — no autograd graph needed.
        with torch.inference_mode():
            full_output = self.model.generate(**model_inputs, max_new_tokens=max_new_tokens)
        # Strip the echoed prompt tokens from each returned sequence.
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, full_output)
        ]
        # skip_special_tokens=False kept from the original — presumably so
        # control/harmony tokens survive for downstream parsing; confirm,
        # since the streaming path decodes with skip_special_tokens=True.
        text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=False)[0]
        input_length = model_inputs.input_ids.shape[1]   # prompt tokens
        output_length = full_output.shape[1]             # prompt + completion
        completion_tokens = output_length - input_length
        return self.format_non_stream(model, text, input_length, completion_tokens, output_length)

    def stream(self, messages, model, max_new_tokens: int = 2048):
        """Yield SSE chunks for a streaming generation.

        Generation runs on a background thread feeding a
        ``TextIteratorStreamer``; each decoded text piece is emitted as one
        chunk. The terminal chunk carries usage and ``finish_reason="stop"``,
        followed by the OpenAI ``data: [DONE]`` sentinel (both were missing
        before, so OpenAI clients could not detect end-of-stream).
        """
        model_inputs = self.tokenizer(messages, return_tensors="pt").to(self.device)
        input_len = model_inputs.input_ids.shape[1]
        streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True
        )
        generation_kwargs = dict(
            **model_inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens
        )
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()
        completion_tokens = 0
        try:
            for token in streamer:
                # Re-encode each text piece to count tokens. Approximate:
                # re-tokenization may not match the model's own segmentation.
                completion_tokens += len(
                    self.tokenizer.encode(token, add_special_tokens=False)
                )
                yield self.format_stream(model, token, None)
        finally:
            # Don't leak the generation thread (it finishes once the
            # streamer is exhausted, so this join is cheap).
            thread.join()
        # Terminal chunk: usage + stop reason, then the [DONE] sentinel.
        yield self.format_stream(model, "", {
            "prompt_tokens": input_len,
            "completion_tokens": completion_tokens,
            "total_tokens": input_len + completion_tokens
        }, finish_reason="stop")
        yield b"data: [DONE]\n\n"

    def __call__(self, data: Dict[str, Any]):
        """Entry point: dispatch to streaming or non-streaming generation.

        Args:
            data: Request dict with ``messages``, ``model`` and an optional
                boolean ``stream`` flag (default False).
        """
        messages = data.get("messages")
        model = data.get("model")
        # Treat any falsy value (False, None, 0, "") as non-streaming; the
        # original `stream is False` sent None/0 down the streaming path.
        if not data.get("stream", False):
            return self.generate(messages, model)
        return StreamingResponse(
            self.stream(messages, model),
            media_type="text/event-stream"
        )