import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import json
import os
import logging
import time
import traceback
from datetime import datetime
# Set up comprehensive logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - [%(funcName)s:%(lineno)d] - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)
# Add request tracking
request_counter = 0
class EndpointHandler:
def __init__(self, path=""):
"""
Initialize the handler for Hugging Face Inference Endpoints
Args:
path (str): Path to the model directory
"""
init_start_time = time.time()
logger.info("🚀 Initializing Streamlit Copilot Handler")
# Device setup
self.device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Device: {self.device}")
try:
# Check if this is a merged model or LoRA adapter
adapter_config_path = os.path.join(path, "adapter_config.json")
if os.path.exists(adapter_config_path):
logger.info("Loading LoRA adapter model")
# This is a LoRA adapter - load base model and adapter
base_model_name = "bigcode/starcoder2-3b"
self.tokenizer = AutoTokenizer.from_pretrained(
base_model_name,
trust_remote_code=True,
use_fast=True
)
self.model = AutoModelForCausalLM.from_pretrained(
base_model_name,
torch_dtype=torch.float16,
device_map="auto",
trust_remote_code=True,
low_cpu_mem_usage=True
)
# Load the LoRA adapter
self.model = PeftModel.from_pretrained(self.model, path)
logger.info("LoRA adapter loaded")
else:
logger.info("Loading merged model")
self.tokenizer = AutoTokenizer.from_pretrained(
path,
trust_remote_code=True,
use_fast=True
)
self.model = AutoModelForCausalLM.from_pretrained(
path,
torch_dtype=torch.float16,
device_map="auto",
trust_remote_code=True,
low_cpu_mem_usage=True
)
logger.info("Merged model loaded")
# Configure tokenizer
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
# Set model to evaluation mode
self.model.eval()
if self.device == "cuda":
torch.cuda.empty_cache()
init_total_time = time.time() - init_start_time
logger.info(f"✅ Model initialization completed in {init_total_time:.2f}s")
except Exception as e:
logger.error(f"❌ Model initialization failed: {str(e)}")
logger.error(traceback.format_exc())
raise
def handle_openai_completions(self, data):
"""
        Handle OpenAI-style /v1/completions requests for the Continue VS Code extension
"""
global request_counter
request_counter += 1
req_id = f"openai-comp-{request_counter}"
start_time = time.time()
logger.info(f"[{req_id}] OpenAI Completions request")
try:
prompt = data.get("prompt", "")
if not prompt:
logger.warning(f"[{req_id}] No prompt provided")
return {"error": {"message": "No prompt provided", "type": "invalid_request"}}
logger.info(f"[{req_id}] Prompt: {len(prompt)} chars - {prompt[:50]}{'...' if len(prompt) > 50 else ''}")
max_tokens = min(data.get("max_tokens", 100), 512)
temperature = max(0.0, min(data.get("temperature", 0.2), 2.0))
top_p = max(0.0, min(data.get("top_p", 1.0), 1.0))
stop = data.get("stop", [])
generated_text = self._generate_text_internal(prompt, max_tokens, temperature, top_p, stop, req_id)
response = {
"id": f"cmpl-{datetime.now().strftime('%Y%m%d%H%M%S')}",
"object": "text_completion",
"created": int(datetime.now().timestamp()),
"model": "starcoder2-3b-streamlit-copilot",
"choices": [{
"text": generated_text,
"index": 0,
"logprobs": None,
"finish_reason": "stop"
}]
}
total_time = time.time() - start_time
logger.info(f"[{req_id}] ✅ Completed in {total_time:.2f}s - Generated {len(generated_text)} chars")
return response
except Exception as e:
total_time = time.time() - start_time
logger.error(f"[{req_id}] ❌ Failed after {total_time:.2f}s: {str(e)}")
return {"error": {"message": str(e), "type": "server_error"}}
def handle_openai_chat_completions(self, data):
"""
Handle OpenAI-style /v1/chat/completions requests
"""
global request_counter
request_counter += 1
req_id = f"openai-chat-{request_counter}"
start_time = time.time()
logger.info(f"[{req_id}] Chat Completions request")
try:
messages = data.get("messages", [])
if not messages:
logger.warning(f"[{req_id}] No messages provided")
return {"error": {"message": "No messages provided", "type": "invalid_request"}}
logger.info(f"[{req_id}] {len(messages)} messages")
# Convert messages to prompt
prompt = self._messages_to_prompt(messages)
max_tokens = min(data.get("max_tokens", 100), 512)
temperature = max(0.0, min(data.get("temperature", 0.2), 2.0))
top_p = max(0.0, min(data.get("top_p", 1.0), 1.0))
stop = data.get("stop", [])
generated_text = self._generate_text_internal(prompt, max_tokens, temperature, top_p, stop, req_id)
response = {
"id": f"chatcmpl-{datetime.now().strftime('%Y%m%d%H%M%S')}",
"object": "chat.completion",
"created": int(datetime.now().timestamp()),
"model": "starcoder2-3b-streamlit-copilot",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": generated_text
},
"finish_reason": "stop"
}]
}
total_time = time.time() - start_time
logger.info(f"[{req_id}] ✅ Completed in {total_time:.2f}s - Generated {len(generated_text)} chars")
return response
except Exception as e:
total_time = time.time() - start_time
logger.error(f"[{req_id}] ❌ Failed after {total_time:.2f}s: {str(e)}")
return {"error": {"message": str(e), "type": "server_error"}}
def _messages_to_prompt(self, messages):
"""Convert OpenAI chat messages to a single prompt for code completion"""
prompt_parts = []
for message in messages:
role = message.get("role", "")
content = message.get("content", "")
if role == "system":
prompt_parts.append(f"# {content}")
elif role == "user":
prompt_parts.append(content)
elif role == "assistant":
prompt_parts.append(content)
return "\n".join(prompt_parts)
def _generate_text_internal(self, prompt, max_tokens, temperature, top_p, stop_sequences, req_id="unknown"):
"""Internal method for text generation"""
gen_start_time = time.time()
logger.info(f"[{req_id}] Generating text...")
try:
# Tokenize input
input_ids = self.tokenizer.encode(
prompt,
return_tensors="pt",
truncation=True,
max_length=2048
).to(self.device)
input_length = input_ids.shape[1]
logger.info(f"[{req_id}] Input tokens: {input_length}")
            # Generate response (sample only when temperature > 0, otherwise decode greedily,
            # so sampling-only parameters are not passed during greedy decoding)
            do_sample = temperature > 0
            generation_kwargs = {
                "max_new_tokens": max_tokens,
                "do_sample": do_sample,
                "pad_token_id": self.tokenizer.eos_token_id,
                "eos_token_id": self.tokenizer.eos_token_id,
                "repetition_penalty": 1.1,
                "use_cache": True
            }
            if do_sample:
                generation_kwargs.update(temperature=temperature, top_p=top_p, top_k=50)
            with torch.no_grad():
                outputs = self.model.generate(input_ids, **generation_kwargs)
generation_time = time.time() - gen_start_time
new_tokens = outputs.shape[1] - input_length
# Decode the response
generated_text = self.tokenizer.decode(
outputs[0][input_ids.shape[1]:],
skip_special_tokens=True
).strip()
logger.info(f"[{req_id}] Generated {new_tokens} tokens ({new_tokens/generation_time:.1f} t/s)")
# Apply stop sequences
if stop_sequences:
for stop_seq in stop_sequences:
if stop_seq in generated_text:
generated_text = generated_text.split(stop_seq)[0]
break
logger.info(f"[{req_id}] Generated text: {generated_text[:100]}{'...' if len(generated_text) > 100 else ''}")
return generated_text
except Exception as e:
total_gen_time = time.time() - gen_start_time
logger.error(f"[{req_id}] Generation failed after {total_gen_time:.2f}s: {str(e)}")
raise
def __call__(self, data):
"""
        Main inference method - routes HuggingFace, Continue (TGI-style), and OpenAI request formats
Args:
data (dict): The object received by the inference server
Continue/OpenAI format:
{
"inputs": "Your code prompt here",
"stream": true,
"parameters": {
"max_new_tokens": 100,
"temperature": 0.7,
"top_p": 0.9
}
}
OpenAI Completions format:
{
"prompt": "Your code prompt here",
"max_tokens": 100,
"temperature": 0.2
}
OpenAI Chat format:
{
"messages": [...],
"max_tokens": 100
}
        Returns:
            dict | list: an OpenAI-style response dict (or an error dict), or a list of
                {"generated_text": ...} dicts for the Continue/legacy HF formats
"""
global request_counter
request_counter += 1
req_start_time = time.time()
logger.info(f"Request #{request_counter} - Keys: {list(data.keys())}")
# Detect request format and route accordingly
if "messages" in data:
# OpenAI Chat Completions format
result = self.handle_openai_chat_completions(data)
elif "prompt" in data:
# OpenAI Completions format
result = self.handle_openai_completions(data)
elif "inputs" in data and ("stream" in data or any(key in data for key in ["parameters", "temperature", "max_tokens"])):
            # Continue VS Code extension format - return HF-style generated_text list for the huggingface-tgi provider
req_id = f"continue-{request_counter}"
logger.info(f"[{req_id}] Continue HuggingFace-TGI compatible request")
try:
inputs = data.get("inputs", "")
if not inputs:
logger.warning(f"[{req_id}] No inputs provided")
return {"error": {"message": "No input text provided", "type": "invalid_request"}}
logger.info(f"[{req_id}] Input: {len(inputs)} chars - {inputs[:50]}{'...' if len(inputs) > 50 else ''}")
# Extract parameters (Continue uses HF-style parameters)
parameters = data.get("parameters", {})
max_new_tokens = min(parameters.get("max_new_tokens", data.get("max_tokens", 150)), 512)
temperature = max(0.0, min(parameters.get("temperature", data.get("temperature", 0.2)), 2.0))
top_p = max(0.0, min(parameters.get("top_p", data.get("top_p", 1.0)), 1.0))
stop = data.get("stop", parameters.get("stop", []))
# Generate text
generated_text = self._generate_text_internal(
inputs, max_new_tokens, temperature, top_p, stop, req_id
)
# Return HuggingFace format for Continue huggingface-tgi provider
result = [{
"generated_text": generated_text
}]
except Exception as e:
total_time = time.time() - req_start_time
logger.error(f"[{req_id}] Failed after {total_time:.2f}s: {str(e)}")
result = {"error": {"message": str(e), "type": "server_error"}}
else:
# Legacy HuggingFace format (pure HF testing)
req_id = f"hf-{request_counter}"
logger.info(f"[{req_id}] Legacy HF format request")
try:
inputs = data.get("inputs", "")
if not inputs:
logger.warning(f"[{req_id}] No inputs provided")
return {"error": "No input text provided"}
logger.info(f"[{req_id}] Input: {len(inputs)} chars - {inputs[:50]}{'...' if len(inputs) > 50 else ''}")
parameters = data.get("parameters", {})
# Validate and set generation parameters
max_new_tokens = min(parameters.get("max_new_tokens", 150), 512)
temperature = max(0.1, min(parameters.get("temperature", 0.7), 1.0))
top_p = max(0.1, min(parameters.get("top_p", 0.9), 1.0))
# Use internal generation method
generated_text = self._generate_text_internal(
inputs, max_new_tokens, temperature, top_p, [], req_id
)
# Return response in HF Inference Endpoint format
result = [{
"generated_text": generated_text
}]
except Exception as e:
total_time = time.time() - req_start_time
logger.error(f"[{req_id}] Failed after {total_time:.2f}s: {str(e)}")
result = {"error": f"Generation failed: {str(e)}"}
# Final request logging
total_request_time = time.time() - req_start_time
if isinstance(result, dict) and "error" in result:
logger.error(f"Request #{request_counter} ❌ Failed in {total_request_time:.2f}s")
else:
logger.info(f"Request #{request_counter} ✅ Completed in {total_request_time:.2f}s")
return result
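# Minimal local smoke test - a sketch for manual debugging, not part of the endpoint contract.
# It assumes the model/adapter files are available at "." (a placeholder path) and simply
# exercises the three request formats that __call__ routes on.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    # Continue / TGI-style format (routed via the "parameters" key)
    print(handler({"inputs": "import streamlit as st\n", "parameters": {"max_new_tokens": 32}}))
    # OpenAI completions format
    print(handler({"prompt": "st.title(", "max_tokens": 32, "temperature": 0.2}))
    # OpenAI chat completions format
    print(handler({"messages": [{"role": "user", "content": "st.sidebar."}], "max_tokens": 32}))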