|
|
import traceback |
|
|
import json |
|
|
import sys |
|
|
from typing import Dict, Any, List |
|
|
|
|
|
import torch |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
|
|
|
|
|
def log(*args):
    """Send logs to HuggingFace endpoint logs."""
    # flush=True pushes each message out immediately so it appears in the
    # endpoint's live log stream instead of sitting in the stdout buffer.
    print("[DEBUG]", *args, flush=True)
|
|
|
|
|
|
|
|
class EndpointHandler:
    """HuggingFace Inference Endpoint handler for a causal language model.

    Loads a tokenizer/model pair from the repository path supplied by the
    endpoint runtime and serves completion requests in an OpenAI-style
    ``text_completion`` response shape.
    """

    def __init__(self, path=""):
        """Load tokenizer and model from *path*.

        Args:
            path: Local directory or hub id of the model repository
                (injected by the endpoint runtime).

        Raises:
            Exception: any loading failure is logged with a traceback and
                re-raised so the endpoint reports the startup error.
        """
        log("[init] Initializing handler...")
        log("Model path:", path)

        try:
            self.model_id = path

            self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
            log("Tokenizer loaded.")

            self.model = AutoModelForCausalLM.from_pretrained(
                path,
                trust_remote_code=True,
                # fp16 only pays off on GPU; CPU kernels want fp32.
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto",
            )
            log("Model loaded on device:", self.model.device)
        except Exception as e:
            log("[error] Error during initialization:", str(e))
            log(traceback.format_exc())
            # Bare raise preserves the original traceback for the runtime.
            raise

        log("[ok] Initialization complete.")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one completion request.

        Args:
            data: Request payload. Recognized keys: ``prompt`` (or
                ``inputs``) as the prompt string, ``max_tokens`` (default
                200), ``temperature`` (default 0.1; 0 disables sampling),
                and ``stop`` — an optional list of stop strings.

        Returns:
            An OpenAI-style ``text_completion`` dict. On failure, a dict
            of the same shape with ``finish_reason`` "error" and the
            exception message in ``text`` (the endpoint stays up).
        """
        log("----------------------------------------------------")
        log("[req] Incoming Request:", json.dumps(data, indent=2))

        try:
            prompt = data.get("prompt") or data.get("inputs") or ""
            max_tokens = data.get("max_tokens", 200)
            temperature = data.get("temperature", 0.1)
            stop_tokens = data.get("stop", None)

            log("Prompt length:", len(prompt))
            log("Max tokens:", max_tokens)
            log("Temperature:", temperature)
            log("Stop tokens:", stop_tokens)

            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            log("Tokenized input shape:", {k: v.shape for k, v in inputs.items()})

            gen_kwargs = {
                "max_new_tokens": max_tokens,
                "do_sample": temperature > 0,
                "top_p": 0.95,
                "pad_token_id": self.tokenizer.eos_token_id,
            }
            # Only pass temperature when sampling: transformers warns when
            # temperature is set with do_sample=False, and rejects 0.0.
            if temperature > 0:
                gen_kwargs["temperature"] = temperature

            outputs = self.model.generate(**inputs, **gen_kwargs)

            # Slice off the prompt at the TOKEN level. Decoding the full
            # sequence and dropping len(prompt) characters is wrong:
            # decode(skip_special_tokens=True) does not round-trip the
            # prompt text exactly, so a character slice can cut mid-word
            # or leave prompt residue in the completion.
            prompt_len = inputs["input_ids"].shape[1]
            new_tokens = outputs[0][prompt_len:]
            output_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
            log("Raw model output:", repr(output_text[:300]))

            # Report "length" when the generation budget was exhausted;
            # a stop-string hit (below) overrides this back to "stop".
            finish_reason = "length" if len(new_tokens) >= max_tokens else "stop"

            if stop_tokens:
                for s in stop_tokens:
                    if s in output_text:
                        output_text = output_text.split(s)[0]
                        finish_reason = "stop"
                        log(f"Applied stop token: {s}")

            output_text = output_text.strip()
            log("Final output:", repr(output_text))

            response = {
                "id": "cmpl-local",
                "object": "text_completion",
                "model": self.model_id,
                "choices": [
                    {
                        "text": output_text,
                        "index": 0,
                        "finish_reason": finish_reason,
                    }
                ],
            }

            log("[resp] Response:", json.dumps(response, indent=2))
            log("----------------------------------------------------")
            return response

        except Exception as e:
            log("[error] Exception during inference:", str(e))
            log(traceback.format_exc())

            # Return the error in-band rather than crashing the worker, so
            # callers get a well-formed completion payload.
            return {
                "id": "cmpl-error",
                "object": "text_completion",
                "model": self.model_id,
                "choices": [
                    {
                        "text": f"ERROR: {str(e)}",
                        "index": 0,
                        "finish_reason": "error",
                    }
                ],
            }
|
|
|