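"""Custom handler for Hugging Face Inference Endpoints.

Serves a StarCoder2-3B model fine-tuned as a Streamlit copilot, loading either
a LoRA adapter on top of bigcode/starcoder2-3b or a fully merged checkpoint.
Requests are accepted in the OpenAI /v1/completions and /v1/chat/completions
formats (used by the Continue VS Code extension) as well as the Hugging Face /
TGI "inputs" + "parameters" format.
"""
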
import json
import logging
import os
import time
import traceback
from datetime import datetime

import psutil
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - [%(funcName)s:%(lineno)d] - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

request_counter = 0


class EndpointHandler:
    def __init__(self, path=""):
        """
        Initialize the handler for Hugging Face Inference Endpoints.

        Args:
            path (str): Path to the model directory
        """
        init_start_time = time.time()
        logger.info("🚀 Initializing Streamlit Copilot Handler")

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Device: {self.device}")

        try:
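            # adapter_config.json is only present when the directory holds a
            # LoRA adapter; otherwise the directory is a fully merged model.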
            adapter_config_path = os.path.join(path, "adapter_config.json")

            if os.path.exists(adapter_config_path):
                logger.info("Loading LoRA adapter model")

                base_model_name = "bigcode/starcoder2-3b"

                self.tokenizer = AutoTokenizer.from_pretrained(
                    base_model_name,
                    trust_remote_code=True,
                    use_fast=True
                )

                self.model = AutoModelForCausalLM.from_pretrained(
                    base_model_name,
                    torch_dtype=torch.float16,
                    device_map="auto",
                    trust_remote_code=True,
                    low_cpu_mem_usage=True
                )

                self.model = PeftModel.from_pretrained(self.model, path)
                logger.info("LoRA adapter loaded")

            else:
                logger.info("Loading merged model")

                self.tokenizer = AutoTokenizer.from_pretrained(
                    path,
                    trust_remote_code=True,
                    use_fast=True
                )

                self.model = AutoModelForCausalLM.from_pretrained(
                    path,
                    torch_dtype=torch.float16,
                    device_map="auto",
                    trust_remote_code=True,
                    low_cpu_mem_usage=True
                )
                logger.info("Merged model loaded")

            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            self.model.eval()

            if self.device == "cuda":
                torch.cuda.empty_cache()

            init_total_time = time.time() - init_start_time
            logger.info(f"✅ Model initialization completed in {init_total_time:.2f}s")

        except Exception as e:
            logger.error(f"❌ Model initialization failed: {str(e)}")
            logger.error(traceback.format_exc())
            raise

    def handle_openai_completions(self, data):
        """
        Handle OpenAI-style /v1/completions requests for the Continue VS Code extension.
        """
        global request_counter
        request_counter += 1
        req_id = f"openai-comp-{request_counter}"
        start_time = time.time()

        logger.info(f"[{req_id}] OpenAI Completions request")

        try:
            prompt = data.get("prompt", "")
            if not prompt:
                logger.warning(f"[{req_id}] No prompt provided")
                return {"error": {"message": "No prompt provided", "type": "invalid_request"}}

            logger.info(f"[{req_id}] Prompt: {len(prompt)} chars - {prompt[:50]}{'...' if len(prompt) > 50 else ''}")
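            # Clamp caller-supplied sampling parameters to safe ranges
            # (max_tokens capped at 512, temperature in [0, 2], top_p in [0, 1]).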
            max_tokens = min(data.get("max_tokens", 100), 512)
            temperature = max(0.0, min(data.get("temperature", 0.2), 2.0))
            top_p = max(0.0, min(data.get("top_p", 1.0), 1.0))
            stop = data.get("stop", [])

            generated_text = self._generate_text_internal(prompt, max_tokens, temperature, top_p, stop, req_id)

            response = {
                "id": f"cmpl-{datetime.now().strftime('%Y%m%d%H%M%S')}",
                "object": "text_completion",
                "created": int(datetime.now().timestamp()),
                "model": "starcoder2-3b-streamlit-copilot",
                "choices": [{
                    "text": generated_text,
                    "index": 0,
                    "logprobs": None,
                    "finish_reason": "stop"
                }]
            }

            total_time = time.time() - start_time
            logger.info(f"[{req_id}] ✅ Completed in {total_time:.2f}s - Generated {len(generated_text)} chars")

            return response

        except Exception as e:
            total_time = time.time() - start_time
            logger.error(f"[{req_id}] ❌ Failed after {total_time:.2f}s: {str(e)}")
            return {"error": {"message": str(e), "type": "server_error"}}

    def handle_openai_chat_completions(self, data):
        """
        Handle OpenAI-style /v1/chat/completions requests.
        """
        global request_counter
        request_counter += 1
        req_id = f"openai-chat-{request_counter}"
        start_time = time.time()

        logger.info(f"[{req_id}] Chat Completions request")

        try:
            messages = data.get("messages", [])
            if not messages:
                logger.warning(f"[{req_id}] No messages provided")
                return {"error": {"message": "No messages provided", "type": "invalid_request"}}

            logger.info(f"[{req_id}] {len(messages)} messages")

            prompt = self._messages_to_prompt(messages)

            max_tokens = min(data.get("max_tokens", 100), 512)
            temperature = max(0.0, min(data.get("temperature", 0.2), 2.0))
            top_p = max(0.0, min(data.get("top_p", 1.0), 1.0))
            stop = data.get("stop", [])

            generated_text = self._generate_text_internal(prompt, max_tokens, temperature, top_p, stop, req_id)

            response = {
                "id": f"chatcmpl-{datetime.now().strftime('%Y%m%d%H%M%S')}",
                "object": "chat.completion",
                "created": int(datetime.now().timestamp()),
                "model": "starcoder2-3b-streamlit-copilot",
                "choices": [{
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": generated_text
                    },
                    "finish_reason": "stop"
                }]
            }

            total_time = time.time() - start_time
            logger.info(f"[{req_id}] ✅ Completed in {total_time:.2f}s - Generated {len(generated_text)} chars")

            return response

        except Exception as e:
            total_time = time.time() - start_time
            logger.error(f"[{req_id}] ❌ Failed after {total_time:.2f}s: {str(e)}")
            return {"error": {"message": str(e), "type": "server_error"}}

    def _messages_to_prompt(self, messages):
        """
        Convert OpenAI chat messages into a single code-completion prompt.

        System messages are prefixed with "# " so they read as code comments;
        user and assistant messages are passed through unchanged, and all
        parts are joined with newlines.
        """
        prompt_parts = []
        for message in messages:
            role = message.get("role", "")
            content = message.get("content", "")

            if role == "system":
                prompt_parts.append(f"# {content}")
            elif role == "user":
                prompt_parts.append(content)
            elif role == "assistant":
                prompt_parts.append(content)

        return "\n".join(prompt_parts)

    def _generate_text_internal(self, prompt, max_tokens, temperature, top_p, stop_sequences, req_id="unknown"):
        """Internal method for text generation."""
        gen_start_time = time.time()
        logger.info(f"[{req_id}] Generating text...")

        try:
            # Tokenize with an explicit attention mask; the pad token is reused
            # from EOS, so generate() cannot infer the mask reliably on its own.
            encoded = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=2048
            )
            input_ids = encoded["input_ids"].to(self.device)
            attention_mask = encoded["attention_mask"].to(self.device)

            input_length = input_ids.shape[1]
            logger.info(f"[{req_id}] Input tokens: {input_length}")
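
            # Greedy decoding when temperature is 0, otherwise nucleus sampling
            # with the clamped temperature/top_p supplied by the caller.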
            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=max_tokens,
                    temperature=temperature if temperature > 0 else 0.1,
                    do_sample=temperature > 0,
                    top_p=top_p,
                    top_k=50,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.1,
                    use_cache=True
                )

            generation_time = time.time() - gen_start_time
            new_tokens = outputs.shape[1] - input_length

            generated_text = self.tokenizer.decode(
                outputs[0][input_length:],
                skip_special_tokens=True
            ).strip()

            logger.info(f"[{req_id}] Generated {new_tokens} tokens ({new_tokens / max(generation_time, 1e-6):.1f} t/s)")

            # Truncate at the first stop sequence requested by the client, if any.
            if stop_sequences:
                for stop_seq in stop_sequences:
                    if stop_seq in generated_text:
                        generated_text = generated_text.split(stop_seq)[0]
                        break

            logger.info(f"[{req_id}] Generated text: {generated_text[:100]}{'...' if len(generated_text) > 100 else ''}")

            return generated_text

        except Exception as e:
            total_gen_time = time.time() - gen_start_time
            logger.error(f"[{req_id}] Generation failed after {total_gen_time:.2f}s: {str(e)}")
            raise

    def __call__(self, data):
        """
        Main inference method - supports both Hugging Face and OpenAI API formats.

        Args:
            data (dict): The object received by the inference server.

                Continue/OpenAI (TGI-style) format:
                {
                    "inputs": "Your code prompt here",
                    "stream": true,
                    "parameters": {
                        "max_new_tokens": 100,
                        "temperature": 0.7,
                        "top_p": 0.9
                    }
                }

                OpenAI Completions format:
                {
                    "prompt": "Your code prompt here",
                    "max_tokens": 100,
                    "temperature": 0.2
                }

                OpenAI Chat format:
                {
                    "messages": [...],
                    "max_tokens": 100
                }

        Returns:
            dict | list: Generated response with metadata; OpenAI-style requests
            return a dict, Hugging Face / TGI-style requests return a list with
            a single "generated_text" entry.
        """
        global request_counter
        request_counter += 1
        req_start_time = time.time()

        logger.info(f"Request #{request_counter} - Keys: {list(data.keys())}")
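
        # Route by payload shape: chat messages first, then OpenAI prompt
        # completions, then Continue's TGI-style "inputs" requests, and finally
        # the legacy Hugging Face "inputs" format as a fallback.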
        if "messages" in data:
            result = self.handle_openai_chat_completions(data)
        elif "prompt" in data:
            result = self.handle_openai_completions(data)
        elif "inputs" in data and ("stream" in data or any(key in data for key in ["parameters", "temperature", "max_tokens"])):
            req_id = f"continue-{request_counter}"
            logger.info(f"[{req_id}] Continue HuggingFace-TGI compatible request")

            try:
                inputs = data.get("inputs", "")
                if not inputs:
                    logger.warning(f"[{req_id}] No inputs provided")
                    return {"error": {"message": "No input text provided", "type": "invalid_request"}}

                logger.info(f"[{req_id}] Input: {len(inputs)} chars - {inputs[:50]}{'...' if len(inputs) > 50 else ''}")

                parameters = data.get("parameters", {})
                max_new_tokens = min(parameters.get("max_new_tokens", data.get("max_tokens", 150)), 512)
                temperature = max(0.0, min(parameters.get("temperature", data.get("temperature", 0.2)), 2.0))
                top_p = max(0.0, min(parameters.get("top_p", data.get("top_p", 1.0)), 1.0))
                stop = data.get("stop", parameters.get("stop", []))

                generated_text = self._generate_text_internal(
                    inputs, max_new_tokens, temperature, top_p, stop, req_id
                )

                result = [{
                    "generated_text": generated_text
                }]

            except Exception as e:
                total_time = time.time() - req_start_time
                logger.error(f"[{req_id}] Failed after {total_time:.2f}s: {str(e)}")
                result = {"error": {"message": str(e), "type": "server_error"}}

        else:
            req_id = f"hf-{request_counter}"
            logger.info(f"[{req_id}] Legacy HF format request")

            try:
                inputs = data.get("inputs", "")
                if not inputs:
                    logger.warning(f"[{req_id}] No inputs provided")
                    return {"error": "No input text provided"}

                logger.info(f"[{req_id}] Input: {len(inputs)} chars - {inputs[:50]}{'...' if len(inputs) > 50 else ''}")

                parameters = data.get("parameters", {})
                max_new_tokens = min(parameters.get("max_new_tokens", 150), 512)
                temperature = max(0.1, min(parameters.get("temperature", 0.7), 1.0))
                top_p = max(0.1, min(parameters.get("top_p", 0.9), 1.0))

                generated_text = self._generate_text_internal(
                    inputs, max_new_tokens, temperature, top_p, [], req_id
                )

                result = [{
                    "generated_text": generated_text
                }]

            except Exception as e:
                total_time = time.time() - req_start_time
                logger.error(f"[{req_id}] Failed after {total_time:.2f}s: {str(e)}")
                result = {"error": f"Generation failed: {str(e)}"}

        total_request_time = time.time() - req_start_time

        if isinstance(result, dict) and "error" in result:
            logger.error(f"Request #{request_counter} ❌ Failed in {total_request_time:.2f}s")
        else:
            logger.info(f"Request #{request_counter} ✅ Completed in {total_request_time:.2f}s")

        return result
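

# ---------------------------------------------------------------------------
# Local smoke test (illustrative sketch only). The Inference Endpoints runtime
# imports EndpointHandler directly and never runs this block; it assumes the
# model weights or LoRA adapter live in the current directory, so adjust the
# path before running locally.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler(path=".")

    # OpenAI-style completion request
    print(handler({"prompt": "import streamlit as st\n\nst.title(", "max_tokens": 32}))

    # OpenAI-style chat request
    print(handler({
        "messages": [
            {"role": "system", "content": "You write Streamlit code."},
            {"role": "user", "content": "Create a slider that goes from 0 to 100."},
        ],
        "max_tokens": 64,
    }))

    # Legacy Hugging Face / TGI-style request
    print(handler({"inputs": "st.sidebar.", "parameters": {"max_new_tokens": 24}}))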