from typing import Dict, Any, List
import time
import uuid

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

# Fallback system prompt injected when the caller supplies no "system" message.
DEFAULT_SYSTEM_PROMPT = (
    "You are an expert Minecraft Forge mod developer for version 1.21.11. "
    "Write clean, efficient, and well-structured Java code."
)


class EndpointHandler:
    """Inference Endpoints handler serving a 4-bit quantized DeepSeek-Coder
    base model with a PEFT (LoRA) adapter layered on top.

    Accepts two request shapes:
      * OpenAI chat format: ``{"messages": [...], "max_tokens": ..., ...}``
      * Simple format: ``{"inputs": "...", "parameters": {...}}``
    """

    def __init__(self, path: str = ""):
        """Load tokenizer, quantized base model, and the adapter at *path*.

        Args:
            path: Local directory of the deployed repo; it holds both the
                tokenizer files and the PEFT adapter weights.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Reported back to clients in the OpenAI-style response envelope.
        self.model_id = "hwding/forge-coder-v1.21.11"

        # NF4 double-quantized 4-bit loading keeps the 6.7B base model within
        # a single-GPU memory budget; compute runs in bfloat16.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )

        base_model_id = "deepseek-ai/deepseek-coder-6.7b-instruct"
        # Tokenizer comes from the adapter repo (may carry added tokens);
        # the base weights come from the hub.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
        )
        # Attach the fine-tuned LoRA adapter on top of the quantized base.
        self.model = PeftModel.from_pretrained(self.model, path)
        self.model.eval()

    def _format_messages(self, messages: List[Dict[str, str]]) -> str:
        """Render OpenAI-style chat messages into the training prompt format.

        Unknown roles are silently dropped. If no system message is present,
        DEFAULT_SYSTEM_PROMPT is prepended. The prompt always ends with an
        open "### Assistant:" turn for the model to complete.
        """
        prompt_parts = []
        has_system = False
        for msg in messages:
            role = msg.get("role", "")
            content = msg.get("content", "")
            if role == "system":
                prompt_parts.append(f"### System:\n{content}")
                has_system = True
            elif role == "user":
                prompt_parts.append(f"### User:\n{content}")
            elif role == "assistant":
                prompt_parts.append(f"### Assistant:\n{content}")
        if not has_system:
            prompt_parts.insert(0, f"### System:\n{DEFAULT_SYSTEM_PROMPT}")
        prompt_parts.append("### Assistant:\n")
        return "\n\n".join(prompt_parts)

    def _generate(self, prompt: str, max_new_tokens: int,
                  temperature: float, top_p: float) -> str:
        """Run generation and return only the newly generated completion.

        Greedy decoding is used when ``temperature <= 0``; otherwise
        nucleus sampling with the given temperature/top_p.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        prompt_length = inputs["input_ids"].shape[1]
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                # generate() rejects temperature=0 when sampling; pass a
                # neutral value and disable sampling instead.
                temperature=temperature if temperature > 0 else 1.0,
                top_p=top_p,
                do_sample=temperature > 0,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        # Decode only the tokens generated beyond the prompt. The previous
        # approach decoded the full sequence and split on "### Assistant:",
        # which truncates the reply if the model itself emits that marker.
        completion = self.tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )
        return completion.strip()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Dispatch on payload shape: chat format if "messages" is present
        and non-empty, otherwise the simple "inputs" format."""
        messages = data.get("messages")
        if messages:
            return self._handle_openai_format(data)
        return self._handle_simple_format(data)

    def _handle_openai_format(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Serve an OpenAI chat-completions style request and response."""
        messages = data.get("messages", [])
        max_tokens = data.get("max_tokens", 512)
        temperature = data.get("temperature", 0.7)
        top_p = data.get("top_p", 0.95)

        prompt = self._format_messages(messages)
        generated_text = self._generate(prompt, max_tokens, temperature, top_p)

        # Token accounting for the "usage" field; re-encoding is approximate
        # but cheap relative to generation.
        prompt_tokens = len(self.tokenizer.encode(prompt))
        completion_tokens = len(self.tokenizer.encode(generated_text))

        return {
            "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": self.model_id,
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": generated_text,
                },
                "finish_reason": "stop",
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            }
        }

    def _handle_simple_format(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Serve a plain ``{"inputs": ..., "parameters": ...}`` request."""
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})
        max_new_tokens = parameters.get("max_new_tokens", 512)
        temperature = parameters.get("temperature", 0.7)
        top_p = parameters.get("top_p", 0.95)

        # Raw strings are wrapped in the training prompt template unless the
        # caller already supplied a fully formatted prompt.
        if not inputs.startswith("### System:"):
            prompt = (
                f"### System:\n{DEFAULT_SYSTEM_PROMPT}\n\n"
                f"### User:\n{inputs}\n\n### Assistant:\n"
            )
        else:
            prompt = inputs

        generated_text = self._generate(prompt, max_new_tokens, temperature, top_p)
        return {"generated_text": generated_text}