import json
import sys
import traceback
from typing import Any, Dict

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def log(*args):
    """Send logs to HuggingFace endpoint logs."""
    print("[DEBUG]", *args)
    sys.stdout.flush()


class EndpointHandler:
    def __init__(self, path=""):
        log("📌 Initializing handler...")
        log("Model path:", path)

        try:
            self.model_id = path

            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
            log("Tokenizer loaded.")

            # Load model (fp16 on GPU, fp32 on CPU)
            self.model = AutoModelForCausalLM.from_pretrained(
                path,
                trust_remote_code=True,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto",
            )
            log("Model loaded on device:", self.model.device)
        except Exception as e:
            log("❌ Error during initialization:", str(e))
            log(traceback.format_exc())
            raise  # bare raise preserves the original traceback

        log("✅ Initialization complete.")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        log("----------------------------------------------------")
        log("📥 Incoming Request:", json.dumps(data, indent=2))

        try:
            prompt = data.get("prompt") or data.get("inputs") or ""
            max_tokens = data.get("max_tokens", 200)
            temperature = data.get("temperature", 0.1)
            stop_tokens = data.get("stop", None)

            # Clients may send "stop" as a bare string; normalize to a list
            # so the loop below doesn't iterate over individual characters.
            if isinstance(stop_tokens, str):
                stop_tokens = [stop_tokens]

            log("Prompt length:", len(prompt))
            log("Max tokens:", max_tokens)
            log("Temperature:", temperature)
            log("Stop tokens:", stop_tokens)

            # Tokenize
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            log("Tokenized input shape:", {k: v.shape for k, v in inputs.items()})

            # Generate
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                do_sample=temperature > 0,
                temperature=temperature,
                top_p=0.95,
                pad_token_id=self.tokenizer.eos_token_id,
            )

            # Decode only the newly generated tokens. Slicing the decoded
            # string by len(prompt) is unreliable: tokenization does not
            # round-trip whitespace and special characters exactly, so the
            # decoded prompt may differ in length from the original string.
            new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
            output_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
            log("Raw model output:", repr(output_text[:300]))

            # Apply stop tokens
            if stop_tokens:
                for s in stop_tokens:
                    if s in output_text:
                        output_text = output_text.split(s)[0]
                        log(f"Applied stop token: {s}")

            output_text = output_text.strip()
            log("Final output:", repr(output_text))

            # Return OpenAI-compatible JSON (required by Continue)
            response = {
                "id": "cmpl-local",
                "object": "text_completion",
                "model": self.model_id,
                "choices": [
                    {
                        "text": output_text,
                        "index": 0,
                        "finish_reason": "stop",
                    }
                ],
            }

            log("📤 Response:", json.dumps(response, indent=2))
            log("----------------------------------------------------")
            return response

        except Exception as e:
            log("❌ Exception during inference:", str(e))
            log(traceback.format_exc())
            return {
                "id": "cmpl-error",
                "object": "text_completion",
                "model": self.model_id,
                "choices": [
                    {
                        "text": f"ERROR: {str(e)}",
                        "index": 0,
                        "finish_reason": "error",
                    }
                ],
            }
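
# Local smoke test: a minimal sketch of how the handler can be exercised
# outside the endpoint. The model id "sshleifer/tiny-gpt2" is an assumption
# chosen only because it is tiny and public; substitute the checkpoint the
# endpoint is actually deployed with. With temperature 0.0 the handler falls
# back to greedy decoding (do_sample=False).
if __name__ == "__main__":
    handler = EndpointHandler(path="sshleifer/tiny-gpt2")
    result = handler(
        {
            "prompt": "def fibonacci(n):",
            "max_tokens": 32,
            "temperature": 0.0,
            "stop": ["\n\n"],
        }
    )
    print(json.dumps(result, indent=2))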