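"""Custom handler for a Hugging Face Inference Endpoint.

Loads a causal language model with transformers and returns completions in
an OpenAI-compatible ``text_completion`` shape (required by clients such as
Continue).
"""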
import traceback
import json
import sys
from typing import Any, Dict

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def log(*args):
    """Send logs to HuggingFace endpoint logs."""
    print("[DEBUG]", *args)
    sys.stdout.flush()


class EndpointHandler:
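    """Entry point discovered by the Inference Endpoints runtime.

    The runtime instantiates this class with the local model directory and
    invokes the instance once per request with the parsed JSON body.
    """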
    def __init__(self, path=""):
        log("📌 Initializing handler...")
        log("Model path:", path)

        try:
            self.model_id = path

            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
            log("Tokenizer loaded.")

            # Load model
            self.model = AutoModelForCausalLM.from_pretrained(
                path,
                trust_remote_code=True,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto",
            )
            log("Model loaded on device:", self.model.device)

        except Exception as e:
            log("❌ Error during initialization:", str(e))
            log(traceback.format_exc())
            raise

        log("✅ Initialization complete.")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
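        """Handle one request and return an OpenAI-style completion dict."""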
        log("----------------------------------------------------")
        log("📥 Incoming Request:", json.dumps(data, indent=2))

        try:
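            # Endpoints send {"inputs": ...} by default, while OpenAI-style
            # clients such as Continue send {"prompt": ...}; accept either.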
            prompt = data.get("prompt") or data.get("inputs") or ""
            max_tokens = data.get("max_tokens", 200)
            temperature = data.get("temperature", 0.1)
            stop_tokens = data.get("stop", None)

            log("Prompt length:", len(prompt))
            log("Max tokens:", max_tokens)
            log("Temperature:", temperature)
            log("Stop tokens:", stop_tokens)

            # Tokenize
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            log("Tokenized input shape:", {k: v.shape for k, v in inputs.items()})

            # Generate: greedy decoding when temperature == 0, sampling
            # otherwise. Sampling-only parameters are passed only when
            # do_sample is True, to avoid transformers warnings.
            do_sample = temperature > 0
            gen_kwargs = {
                "max_new_tokens": max_tokens,
                "do_sample": do_sample,
                "pad_token_id": self.tokenizer.eos_token_id,
            }
            if do_sample:
                gen_kwargs["temperature"] = temperature
                gen_kwargs["top_p"] = 0.95
            outputs = self.model.generate(**inputs, **gen_kwargs)

            # Decode only the newly generated tokens. Slicing the decoded
            # string by len(prompt) can misalign when the tokenizer does not
            # round-trip the prompt character-for-character.
            prompt_length = inputs["input_ids"].shape[1]
            output_text = self.tokenizer.decode(
                outputs[0][prompt_length:], skip_special_tokens=True
            )
            log("Raw model output:", repr(output_text[:300]))

            # Apply stop tokens
            if stop_tokens:
                for s in stop_tokens:
                    if s in output_text:
                        output_text = output_text.split(s)[0]
                        log(f"Applied stop token: {s}")

            output_text = output_text.strip()
            log("Final output:", repr(output_text))

            # Return OpenAI-compatible JSON (required by Continue)
            response = {
                "id": "cmpl-local",
                "object": "text_completion",
                "model": self.model_id,
                "choices": [
                    {
                        "text": output_text,
                        "index": 0,
                        "finish_reason": "stop",
                    }
                ],
            }

            log("📤 Response:", json.dumps(response, indent=2))
            log("----------------------------------------------------")
            return response

        except Exception as e:
            log("❌ Exception during inference:", str(e))
            log(traceback.format_exc())

            return {
                "id": "cmpl-error",
                "object": "text_completion",
                "model": self.model_id,
                "choices": [
                    {
                        "text": f"ERROR: {str(e)}",
                        "index": 0,
                        "finish_reason": "error",
                    }
                ],
            }
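

# Minimal local smoke test: a sketch, not part of the Endpoints contract.
# Inference Endpoints construct EndpointHandler(path_to_model) and call the
# instance with the parsed request body, so the same flow can be exercised
# locally. "gpt2" below is an assumption; substitute any causal LM id/path.
if __name__ == "__main__":
    handler = EndpointHandler("gpt2")
    result = handler(
        {"prompt": "def fibonacci(n):", "max_tokens": 64, "stop": ["\n\n"]}
    )
    print(result["choices"][0]["text"])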