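"""Custom inference handler for a Hugging Face Inference Endpoint.

Loads a causal language model and serves completions in an
OpenAI-compatible "text_completion" shape, which clients such as
Continue expect.
"""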

import json
import sys
import traceback
from typing import Any, Dict

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def log(*args):
    """Send logs to the Hugging Face endpoint log stream."""
    print("[DEBUG]", *args)
    sys.stdout.flush()


class EndpointHandler:
    def __init__(self, path=""):
        log("📌 Initializing handler...")
        log("Model path:", path)
        try:
            self.model_id = path

            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
            log("Tokenizer loaded.")

            # Load model (half precision on GPU, full precision on CPU)
            self.model = AutoModelForCausalLM.from_pretrained(
                path,
                trust_remote_code=True,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto",
            )
            log("Model loaded on device:", self.model.device)
        except Exception as e:
            log("❌ Error during initialization:", str(e))
            log(traceback.format_exc())
            raise  # re-raise with the original traceback intact
        log("✅ Initialization complete.")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        log("----------------------------------------------------")
        log("📥 Incoming Request:", json.dumps(data, indent=2))
        try:
            prompt = data.get("prompt") or data.get("inputs") or ""
            max_tokens = data.get("max_tokens", 200)
            temperature = data.get("temperature", 0.1)
            stop_tokens = data.get("stop", None)
            log("Prompt length:", len(prompt))
            log("Max tokens:", max_tokens)
            log("Temperature:", temperature)
            log("Stop tokens:", stop_tokens)

            # Tokenize
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            log("Tokenized input shape:", {k: v.shape for k, v in inputs.items()})

            # Generate. Sampling parameters are only passed when sampling is
            # enabled; transformers warns if temperature/top_p are supplied
            # while do_sample is False.
            do_sample = temperature > 0
            generate_kwargs = {
                "max_new_tokens": max_tokens,
                "do_sample": do_sample,
                "pad_token_id": self.tokenizer.eos_token_id,
            }
            if do_sample:
                generate_kwargs["temperature"] = temperature
                generate_kwargs["top_p"] = 0.95
            outputs = self.model.generate(**inputs, **generate_kwargs)

            # Decode only the newly generated tokens. Slicing the decoded
            # string by len(prompt) is fragile because decoding does not
            # always round-trip the prompt text exactly.
            prompt_len = inputs["input_ids"].shape[-1]
            output_text = self.tokenizer.decode(
                outputs[0][prompt_len:], skip_special_tokens=True
            )
            log("Raw model output:", repr(output_text[:300]))

            # Truncate at the first stop token, if any was supplied
            if stop_tokens:
                for s in stop_tokens:
                    if s in output_text:
                        output_text = output_text.split(s)[0]
                        log(f"Applied stop token: {s}")
            output_text = output_text.strip()
            log("Final output:", repr(output_text))

            # Return OpenAI-compatible JSON (required by Continue)
            response = {
                "id": "cmpl-local",
                "object": "text_completion",
                "model": self.model_id,
                "choices": [
                    {
                        "text": output_text,
                        "index": 0,
                        "finish_reason": "stop",
                    }
                ],
            }
            log("📤 Response:", json.dumps(response, indent=2))
            log("----------------------------------------------------")
            return response
        except Exception as e:
            log("❌ Exception during inference:", str(e))
            log(traceback.format_exc())
            # Surface the error in the same completion schema so clients
            # can display it instead of failing to parse the response
            return {
                "id": "cmpl-error",
                "object": "text_completion",
                "model": self.model_id,
                "choices": [
                    {
                        "text": f"ERROR: {str(e)}",
                        "index": 0,
                        "finish_reason": "error",
                    }
                ],
            }
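

# --- Local smoke test (an illustrative sketch, not part of the endpoint contract) ---
# Hugging Face Inference Endpoints construct EndpointHandler(path) once at
# startup and then invoke it with the parsed request JSON. The path and
# prompt below are placeholder assumptions for running this file directly.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")  # assumes model files sit next to this script
    result = handler({"inputs": "def fibonacci(n):", "max_tokens": 64, "temperature": 0.0})
    print(result["choices"][0]["text"])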