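"""Custom Hugging Face Inference Endpoints handler that serves
openai/gpt-oss-20b behind an OpenAI-compatible chat completions API,
supporting both non-streaming responses and streaming server-sent events."""
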
from typing import Any, Dict
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from fastapi.responses import StreamingResponse
import uuid
import time
import json
from threading import Thread

class EndpointHandler:
    def __init__(self, path: str = "openai/gpt-oss-20b"):
        # Load the tokenizer and model weights. torch_dtype="auto" uses the
        # dtype stored in the checkpoint (e.g. bfloat16) rather than float32.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(path, torch_dtype="auto")
        self.model.eval()

        # Determine the computation device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def openai_id(self, prefix: str) -> str:
        # OpenAI-style identifier, e.g. "chatcmpl-<24 hex chars>".
        return f"{prefix}-{uuid.uuid4().hex[:24]}"

    def format_non_stream(self, model: str, text: str, prompt_length: int, completion_length: int, total_tokens: int) -> Dict[str, Any]:
        # Build an OpenAI-compatible chat.completion payload.
        return {
            "id": self.openai_id("chatcmpl"),
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model,
            "choices": [{
                "index": 0,
                "message": {"role": "assistant", "content": text},
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": prompt_length,
                "completion_tokens": completion_length,
                "total_tokens": total_tokens
            }
        }

    def format_stream(self, model: str, token: str, usage, finish_reason=None) -> bytes:
        # Build one OpenAI-compatible chat.completion.chunk, encoded as a
        # server-sent event ("data: {...}\n\n").
        payload = {
            "id": self.openai_id("chatcmpl"),
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model,
            "choices": [{
                "index": 0,
                "delta": {
                    "content": token,
                    "function_call": None,
                    "refusal": None,
                    "role": None,
                    "tool_calls": None
                },
                "finish_reason": None,
                "logprobs": None
            }],
            "usage": usage
        }

        return f"data: {json.dumps(payload)}\n\n".encode('utf-8')

    def generate(self, messages, model: str):
        # Render the OpenAI-style message list with the model's chat template,
        # then tokenize; return_dict=True gives input_ids and attention_mask.
        model_inputs = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt", return_dict=True
        ).to(self.device)
        full_output = self.model.generate(**model_inputs, max_new_tokens=2048)
        # Keep only the newly generated tokens; drop the echoed prompt.
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, full_output)
        ]
        # skip_special_tokens=True keeps control tokens out of the response
        # content, matching the streaming path below.
        text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

        input_length = model_inputs.input_ids.shape[1]  # Prompt tokens
        output_length = full_output.shape[1]  # Total tokens (prompt + completion)
        completion_tokens = output_length - input_length

        return self.format_non_stream(model, text, input_length, completion_tokens, output_length)

    def stream(self, messages, model):
        model_inputs = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt", return_dict=True
        ).to(self.device)
        input_len = model_inputs.input_ids.shape[1]
        streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True
        )

        generation_kwargs = dict(
            **model_inputs,
            streamer=streamer,
            max_new_tokens=2048
        )

        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        completion_tokens = 0
        for token in streamer:
            # Approximate count: re-encoding the decoded text may not match
            # the original token boundaries exactly.
            completion_tokens += len(self.tokenizer.encode(token, add_special_tokens=False))

            yield self.format_stream(model, token, None)

        # Final chunk carries the stop reason and token counts.
        yield self.format_stream(model, "", {
            "prompt_tokens": input_len,
            "completion_tokens": completion_tokens,
            "total_tokens": input_len + completion_tokens
        }, finish_reason="stop")

        # Terminate the stream per the OpenAI SSE convention.
        yield b"data: [DONE]\n\n"

    def __call__(self, data: Dict[str, Any]):
        # Endpoint entry point: expects an OpenAI-style request body with
        # "messages", "model", and an optional "stream" flag.
        messages = data.get("messages")
        model = data.get("model")
        stream = data.get("stream", False)

        if not stream:
            return self.generate(messages, model)
        return StreamingResponse(
            self.stream(messages, model),
            media_type="text/event-stream"
        )
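
# A minimal local smoke test, assuming the checkpoint is available and fits in
# memory; not part of the endpoint contract, which only invokes
# EndpointHandler.__call__ with the parsed request body.
if __name__ == "__main__":
    handler = EndpointHandler()

    request = {
        "model": "openai/gpt-oss-20b",
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
        "stream": False,
    }

    # Non-streaming: returns a chat.completion dict.
    print(handler(request))

    # Streaming: iterate the generator directly, bypassing FastAPI's
    # StreamingResponse wrapper used in production.
    for event in handler.stream(request["messages"], request["model"]):
        print(event.decode("utf-8"), end="")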