import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import json
import os
import logging
import time
import traceback
import psutil
from datetime import datetime

# Set up comprehensive logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - [%(funcName)s:%(lineno)d] - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Add request tracking
request_counter = 0


class EndpointHandler:
    def __init__(self, path=""):
        """
        Initialize the handler for Hugging Face Inference Endpoints

        Args:
            path (str): Path to the model directory
        """
        init_start_time = time.time()
        logger.info("🚀 Initializing Streamlit Copilot Handler")

        # Device setup
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Device: {self.device}")

        try:
            # Check if this is a merged model or LoRA adapter
            adapter_config_path = os.path.join(path, "adapter_config.json")

            if os.path.exists(adapter_config_path):
                logger.info("Loading LoRA adapter model")
                # This is a LoRA adapter - load base model and adapter
                base_model_name = "bigcode/starcoder2-3b"

                self.tokenizer = AutoTokenizer.from_pretrained(
                    base_model_name,
                    trust_remote_code=True,
                    use_fast=True
                )

                self.model = AutoModelForCausalLM.from_pretrained(
                    base_model_name,
                    torch_dtype=torch.float16,
                    device_map="auto",
                    trust_remote_code=True,
                    low_cpu_mem_usage=True
                )

                # Load the LoRA adapter
                self.model = PeftModel.from_pretrained(self.model, path)
                logger.info("LoRA adapter loaded")
            else:
                logger.info("Loading merged model")

                self.tokenizer = AutoTokenizer.from_pretrained(
                    path,
                    trust_remote_code=True,
                    use_fast=True
                )

                self.model = AutoModelForCausalLM.from_pretrained(
                    path,
                    torch_dtype=torch.float16,
                    device_map="auto",
                    trust_remote_code=True,
                    low_cpu_mem_usage=True
                )
                logger.info("Merged model loaded")

            # Configure tokenizer
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Set model to evaluation mode
            self.model.eval()

            if self.device == "cuda":
                torch.cuda.empty_cache()

            init_total_time = time.time() - init_start_time
            logger.info(f"✅ Model initialization completed in {init_total_time:.2f}s")

        except Exception as e:
            logger.error(f"❌ Model initialization failed: {str(e)}")
            logger.error(traceback.format_exc())
            raise

    def handle_openai_completions(self, data):
        """
        Handle OpenAI-style /v1/completions requests for Continue VS Code extension
        """
        global request_counter
        request_counter += 1
        req_id = f"openai-comp-{request_counter}"
        start_time = time.time()

        logger.info(f"[{req_id}] OpenAI Completions request")

        try:
            prompt = data.get("prompt", "")
            if not prompt:
                logger.warning(f"[{req_id}] No prompt provided")
                return {"error": {"message": "No prompt provided", "type": "invalid_request"}}

            logger.info(f"[{req_id}] Prompt: {len(prompt)} chars - {prompt[:50]}{'...' if len(prompt) > 50 else ''}")

            max_tokens = min(data.get("max_tokens", 100), 512)
            temperature = max(0.0, min(data.get("temperature", 0.2), 2.0))
            top_p = max(0.0, min(data.get("top_p", 1.0), 1.0))
            stop = data.get("stop", [])

            generated_text = self._generate_text_internal(prompt, max_tokens, temperature, top_p, stop, req_id)

            response = {
                "id": f"cmpl-{datetime.now().strftime('%Y%m%d%H%M%S')}",
                "object": "text_completion",
                "created": int(datetime.now().timestamp()),
                "model": "starcoder2-3b-streamlit-copilot",
                "choices": [{
                    "text": generated_text,
                    "index": 0,
                    "logprobs": None,
                    "finish_reason": "stop"
                }]
            }

            total_time = time.time() - start_time
            logger.info(f"[{req_id}] ✅ Completed in {total_time:.2f}s - Generated {len(generated_text)} chars")
            return response

        except Exception as e:
            total_time = time.time() - start_time
            logger.error(f"[{req_id}] ❌ Failed after {total_time:.2f}s: {str(e)}")
            return {"error": {"message": str(e), "type": "server_error"}}

    def handle_openai_chat_completions(self, data):
        """
        Handle OpenAI-style /v1/chat/completions requests
        """
        global request_counter
        request_counter += 1
        req_id = f"openai-chat-{request_counter}"
        start_time = time.time()

        logger.info(f"[{req_id}] Chat Completions request")

        try:
            messages = data.get("messages", [])
            if not messages:
                logger.warning(f"[{req_id}] No messages provided")
                return {"error": {"message": "No messages provided", "type": "invalid_request"}}

            logger.info(f"[{req_id}] {len(messages)} messages")

            # Convert messages to prompt
            prompt = self._messages_to_prompt(messages)

            max_tokens = min(data.get("max_tokens", 100), 512)
            temperature = max(0.0, min(data.get("temperature", 0.2), 2.0))
            top_p = max(0.0, min(data.get("top_p", 1.0), 1.0))
            stop = data.get("stop", [])

            generated_text = self._generate_text_internal(prompt, max_tokens, temperature, top_p, stop, req_id)

            response = {
                "id": f"chatcmpl-{datetime.now().strftime('%Y%m%d%H%M%S')}",
                "object": "chat.completion",
                "created": int(datetime.now().timestamp()),
                "model": "starcoder2-3b-streamlit-copilot",
                "choices": [{
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": generated_text
                    },
                    "finish_reason": "stop"
                }]
            }

            total_time = time.time() - start_time
            logger.info(f"[{req_id}] ✅ Completed in {total_time:.2f}s - Generated {len(generated_text)} chars")
            return response

        except Exception as e:
            total_time = time.time() - start_time
            logger.error(f"[{req_id}] ❌ Failed after {total_time:.2f}s: {str(e)}")
            return {"error": {"message": str(e), "type": "server_error"}}

    def _messages_to_prompt(self, messages):
        """Convert OpenAI chat messages to a single prompt for code completion"""
        prompt_parts = []

        for message in messages:
            role = message.get("role", "")
            content = message.get("content", "")

            if role == "system":
                prompt_parts.append(f"# {content}")
            elif role == "user":
                prompt_parts.append(content)
            elif role == "assistant":
                prompt_parts.append(content)

        return "\n".join(prompt_parts)

    def _generate_text_internal(self, prompt, max_tokens, temperature, top_p, stop_sequences, req_id="unknown"):
        """Internal method for text generation"""
        gen_start_time = time.time()
        logger.info(f"[{req_id}] Generating text...")

        try:
            # Tokenize input
            input_ids = self.tokenizer.encode(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=2048
            ).to(self.device)

            input_length = input_ids.shape[1]
            logger.info(f"[{req_id}] Input tokens: {input_length}")

            # Generate response
            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids,
                    max_new_tokens=max_tokens,
                    temperature=temperature if temperature > 0 else 0.1,
                    do_sample=temperature > 0,
                    top_p=top_p,
                    top_k=50,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.1,
                    early_stopping=True,
                    use_cache=True
                )

            generation_time = time.time() - gen_start_time
            new_tokens = outputs.shape[1] - input_length

            # Decode the response
            generated_text = self.tokenizer.decode(
                outputs[0][input_ids.shape[1]:],
                skip_special_tokens=True
            ).strip()

            logger.info(f"[{req_id}] Generated {new_tokens} tokens ({new_tokens/generation_time:.1f} t/s)")

            # Apply stop sequences
            if stop_sequences:
                for stop_seq in stop_sequences:
                    if stop_seq in generated_text:
                        generated_text = generated_text.split(stop_seq)[0]
                        break

            logger.info(f"[{req_id}] Generated text: {generated_text[:100]}{'...' if len(generated_text) > 100 else ''}")
            return generated_text

        except Exception as e:
            total_gen_time = time.time() - gen_start_time
            logger.error(f"[{req_id}] Generation failed after {total_gen_time:.2f}s: {str(e)}")
            raise

    def __call__(self, data):
        """
        Main inference method - supports both HuggingFace and OpenAI API formats

        Args:
            data (dict): The object received by the inference server

            Continue/OpenAI format:
            {
                "inputs": "Your code prompt here",
                "stream": true,
                "parameters": {
                    "max_new_tokens": 100,
                    "temperature": 0.7,
                    "top_p": 0.9
                }
            }

            OpenAI Completions format:
            {
                "prompt": "Your code prompt here",
                "max_tokens": 100,
                "temperature": 0.2
            }

            OpenAI Chat format:
            {
                "messages": [...],
                "max_tokens": 100
            }

        Returns:
            dict: Generated response with metadata
        """
        global request_counter
        request_counter += 1
        req_start_time = time.time()

        logger.info(f"Request #{request_counter} - Keys: {list(data.keys())}")

        # Detect request format and route accordingly
        if "messages" in data:
            # OpenAI Chat Completions format
            result = self.handle_openai_chat_completions(data)
        elif "prompt" in data:
            # OpenAI Completions format
            result = self.handle_openai_completions(data)
        elif "inputs" in data and ("stream" in data or any(key in data for key in ["parameters", "temperature", "max_tokens"])):
            # Continue VS Code extension format - return OpenAI format for llama.cpp/openai providers
            req_id = f"continue-{request_counter}"
            logger.info(f"[{req_id}] Continue HuggingFace-TGI compatible request")

            try:
                inputs = data.get("inputs", "")
                if not inputs:
                    logger.warning(f"[{req_id}] No inputs provided")
                    return {"error": {"message": "No input text provided", "type": "invalid_request"}}

                logger.info(f"[{req_id}] Input: {len(inputs)} chars - {inputs[:50]}{'...' if len(inputs) > 50 else ''}")

                # Extract parameters (Continue uses HF-style parameters)
                parameters = data.get("parameters", {})
                max_new_tokens = min(parameters.get("max_new_tokens", data.get("max_tokens", 150)), 512)
                temperature = max(0.0, min(parameters.get("temperature", data.get("temperature", 0.2)), 2.0))
                top_p = max(0.0, min(parameters.get("top_p", data.get("top_p", 1.0)), 1.0))
                stop = data.get("stop", parameters.get("stop", []))

                # Generate text
                generated_text = self._generate_text_internal(
                    inputs, max_new_tokens, temperature, top_p, stop, req_id
                )

                # Return HuggingFace format for Continue huggingface-tgi provider
                result = [{
                    "generated_text": generated_text
                }]

            except Exception as e:
                total_time = time.time() - req_start_time
                logger.error(f"[{req_id}] Failed after {total_time:.2f}s: {str(e)}")
                result = {"error": {"message": str(e), "type": "server_error"}}
        else:
            # Legacy HuggingFace format (pure HF testing)
            req_id = f"hf-{request_counter}"
            logger.info(f"[{req_id}] Legacy HF format request")

            try:
                inputs = data.get("inputs", "")
                if not inputs:
                    logger.warning(f"[{req_id}] No inputs provided")
                    return {"error": "No input text provided"}

                logger.info(f"[{req_id}] Input: {len(inputs)} chars - {inputs[:50]}{'...' if len(inputs) > 50 else ''}")

                parameters = data.get("parameters", {})

                # Validate and set generation parameters
                max_new_tokens = min(parameters.get("max_new_tokens", 150), 512)
                temperature = max(0.1, min(parameters.get("temperature", 0.7), 1.0))
                top_p = max(0.1, min(parameters.get("top_p", 0.9), 1.0))

                # Use internal generation method
                generated_text = self._generate_text_internal(
                    inputs, max_new_tokens, temperature, top_p, [], req_id
                )

                # Return response in HF Inference Endpoint format
                result = [{
                    "generated_text": generated_text
                }]

            except Exception as e:
                total_time = time.time() - req_start_time
                logger.error(f"[{req_id}] Failed after {total_time:.2f}s: {str(e)}")
                result = {"error": f"Generation failed: {str(e)}"}

        # Final request logging
        total_request_time = time.time() - req_start_time
        if isinstance(result, dict) and "error" in result:
            logger.error(f"Request #{request_counter} ❌ Failed in {total_request_time:.2f}s")
        else:
            logger.info(f"Request #{request_counter} ✅ Completed in {total_request_time:.2f}s")

        return result