from typing import Any, Dict, List, Optional, Tuple
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import traceback
import time


class EndpointHandler:
    """Inference-endpoint handler serving a causal LM as a Bosnian telesales agent.

    ``__call__`` accepts either an OpenAI-style ``{"messages": [...]}`` payload or an
    ``{"inputs": ...}`` payload. It always returns a dict shaped like an OpenAI
    ``chat.completion`` response; failures are reported through
    ``_build_error_response`` instead of being raised.
    """

    # Fixed system prompt. It always comes first in the prompt and can only be
    # extended, never replaced, by a caller-supplied system message.
    _BASE_PROMPT = (
        "Ti si profesionalan i brz Telesales agent za Bosnu. "
        "Govori ISKLJUČIVO bosanski/srpski/hrvatski jezik. "
        "Nikad ne mijenjaj jezik. "
        "Budi direktan, ljubazan i prodajno jak. "
        "Ne pričaj duge priče. \n"
        "Odmah idi na stvar i pokušaj zatvoriti prodaju."
    )

    # Only this many of the most recent chat turns are rendered into the prompt.
    _MAX_HISTORY = 10

    # Speaker labels. If the model emits one, it has started inventing the next
    # turn, so the reply is cut there.
    _STOP_MARKERS = ("Kupac:", "Agent:")

    def __init__(self, model_dir: str):
        """Load tokenizer and model from *model_dir* and set generation defaults.

        Args:
            model_dir: Path or hub id passed to both ``from_pretrained`` calls.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        if self.tokenizer.pad_token is None:
            # No pad token defined: reuse EOS so generate() has a pad id.
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        self.model.eval()
        # Inputs are moved to the device that holds the model's first parameter.
        self.device = next(self.model.parameters()).device

        self.default_max_new_tokens = 200
        self.context_window = 4096  # assumed context length of the model, in tokens
        self.safety_buffer = 50  # tokens kept free between prompt and completion

        self.default_generation_config = {
            "temperature": 0.68,
            "top_p": 0.87,
            "do_sample": True,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
            "repetition_penalty": 1.15,
        }
        if self.tokenizer.bos_token_id is not None:
            self.default_generation_config["bos_token_id"] = self.tokenizer.bos_token_id

    def _build_prompt(
        self, messages: List[Dict[str, str]], user_system: Optional[str] = None
    ) -> str:
        """Render the conversation into a plain-text completion prompt.

        The fixed language/sales instruction always comes first and can never be
        overridden. *user_system*, if non-empty, is only appended as additional
        instructions.

        Only the last ``_MAX_HISTORY`` messages are rendered. ``"user"`` turns
        become ``Kupac:`` lines and every other role becomes an ``Agent:`` line.
        The prompt ends with an open ``Agent:`` turn for the model to complete.
        """
        header = self._BASE_PROMPT
        if user_system:
            header += f"\n\nDodatne instrukcije: {user_system}"

        parts = [header + "\n\n"]
        for msg in messages[-self._MAX_HISTORY:]:
            speaker = "Kupac" if msg.get("role", "user") == "user" else "Agent"
            parts.append(f"{speaker}: {msg.get('content', '')}\n")
        parts.append("Agent:")
        return "".join(parts)

    def _clean_response(self, text: str) -> str:
        """Reduce a raw completion to the agent's single reply.

        Removes a leading ``Agent:`` label (case-insensitive). Everything from the
        first speaker label the model emits onward is then discarded, because that
        is the model starting to invent a new turn.
        """
        text = text.strip()
        if text.lower().startswith("agent:"):
            text = text[len("agent:"):].strip()

        # Cut at the earliest marker in the text, not the first marker in list
        # order. Otherwise a reply containing both labels can still leak an
        # invented turn.
        positions = [pos for pos in (text.find(m) for m in self._STOP_MARKERS) if pos != -1]
        if positions:
            text = text[:min(positions)]
        return text.strip()

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Coerce a message ``content`` value to plain text.

        Handles a string, ``None`` (treated as empty), or an OpenAI-style list of
        content parts; only text parts are kept from such a list. Any other value
        is stringified.
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            texts = []
            for part in content:
                if isinstance(part, dict):
                    if part.get("type", "text") == "text":
                        texts.append(str(part.get("text", "")))
                else:
                    texts.append(str(part))
            return " ".join(t for t in texts if t)
        return str(content)

    def _parse_messages(
        self, data: Dict[str, Any]
    ) -> Tuple[List[Dict[str, str]], Optional[str]]:
        """Extract ``(chat_messages, user_system)`` from the request payload.

        ``data["messages"]`` (a list of ``{"role", "content"}`` dicts) takes
        precedence over ``data["inputs"]``. ``inputs`` may be a string, a list of
        strings/dicts, or any other value, which is stringified.

        In both shapes, ``"system"`` messages are pulled out into *user_system*
        (the last one wins) instead of being rendered as chat turns. Every
        content value is coerced to text.
        """
        if "messages" in data and isinstance(data["messages"], list):
            raw_items = [m for m in data["messages"] if isinstance(m, dict)]
        elif "inputs" in data:
            raw = data["inputs"]
            if isinstance(raw, list):
                raw_items = [
                    item if isinstance(item, dict) else {"role": "user", "content": str(item)}
                    for item in raw
                ]
            else:
                raw_items = [{"role": "user", "content": raw if isinstance(raw, str) else str(raw)}]
        else:
            raw_items = []

        messages: List[Dict[str, str]] = []
        user_system: Optional[str] = None
        for item in raw_items:
            role = item.get("role", "user")
            content = self._content_to_text(item.get("content", ""))
            if role == "system":
                user_system = content
            else:
                messages.append({"role": role, "content": content})
        return messages, user_system

    def _resolve_max_new_tokens(self, data: Dict[str, Any]) -> int:
        """Return the requested completion budget, or the default if it is invalid.

        ``max_tokens`` (OpenAI name) wins over ``max_new_tokens``. Only a
        positive int is accepted.
        """
        value = data.get("max_tokens", data.get("max_new_tokens", self.default_max_new_tokens))
        # bool is a subclass of int; reject it so `true` does not mean 1 token.
        if isinstance(value, bool) or not isinstance(value, int) or value <= 0:
            return self.default_max_new_tokens
        return value

    def _float_param(self, data: Dict[str, Any], key: str) -> float:
        """Read a float sampling parameter; an absent or null value means the default."""
        value = data.get(key)
        if value is None:
            return float(self.default_generation_config[key])
        return float(value)

    def _build_generation_config(self, data: Dict[str, Any], max_new_tokens: int) -> Dict[str, Any]:
        """Assemble the keyword arguments for ``model.generate``.

        A non-positive temperature selects greedy decoding. generate() rejects
        temperature <= 0 while sampling, and temperature 0 conventionally means
        deterministic output.
        """
        temperature = self._float_param(data, "temperature")
        gen_config: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "pad_token_id": self.default_generation_config["pad_token_id"],
            "eos_token_id": self.default_generation_config["eos_token_id"],
            "repetition_penalty": self._float_param(data, "repetition_penalty"),
        }
        if temperature > 0:
            gen_config["do_sample"] = True
            gen_config["temperature"] = temperature
            gen_config["top_p"] = self._float_param(data, "top_p")
        else:
            gen_config["do_sample"] = False
        if "bos_token_id" in self.default_generation_config:
            gen_config["bos_token_id"] = self.default_generation_config["bos_token_id"]
        return gen_config

    def _fit_prompt(
        self,
        messages: List[Dict[str, str]],
        user_system: Optional[str],
        max_input_length: int,
    ) -> Tuple[str, int, bool]:
        """Build a prompt that fits in *max_input_length* tokens where possible.

        Oldest turns are dropped first. This keeps the fixed instructions, the
        newest message and the trailing ``Agent:`` cue. If even a single turn is
        too long, the caller's tokenizer truncation is the last resort.

        Token counts include special tokens, matching the tokenizer call used
        for generation.

        Returns:
            ``(prompt, raw_token_count, was_truncated)``. *raw_token_count* is the
            size of the untrimmed prompt.
        """
        history = list(messages[-self._MAX_HISTORY:])
        prompt = self._build_prompt(history, user_system)
        raw_token_count = len(self.tokenizer.encode(prompt))

        token_count = raw_token_count
        while token_count > max_input_length and len(history) > 1:
            history.pop(0)
            prompt = self._build_prompt(history, user_system)
            token_count = len(self.tokenizer.encode(prompt))

        return prompt, raw_token_count, raw_token_count > max_input_length

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one chat request and return an OpenAI ``chat.completion`` dict.

        Recognised keys:
            * ``messages`` / ``inputs``: the conversation.
            * ``max_tokens`` or ``max_new_tokens``: completion budget.
            * ``temperature``, ``top_p``, ``repetition_penalty``: sampling settings.
            * ``model``: echoed back in the response.

        If the prompt had to be shortened, a ``warning`` key is added. Any
        exception is printed to stdout and converted into a generic apology via
        ``_build_error_response``, so this method does not raise.
        """
        try:
            messages, user_system = self._parse_messages(data)
            if not messages or all(not m["content"].strip() for m in messages):
                return self._build_error_response("Žao mi je, nisam razumio vašu poruku.")

            max_new_tokens = self._resolve_max_new_tokens(data)
            max_input_length = self.context_window - max_new_tokens - self.safety_buffer
            if max_input_length <= 0:
                max_input_length = self.context_window // 2
            gen_config = self._build_generation_config(data, max_new_tokens)

            prompt, raw_token_count, was_truncated = self._fit_prompt(
                messages, user_system, max_input_length
            )
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,  # last-resort safety net; _fit_prompt normally suffices
                max_length=max_input_length,
                return_attention_mask=True,
            ).to(self.device)
            input_length = inputs["input_ids"].shape[1]

            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    **gen_config,
                )

            # generate() returns prompt + completion; keep only the new tokens.
            new_tokens = outputs[0][input_length:]
            completion_length = len(new_tokens)
            response = self._clean_response(
                self.tokenizer.decode(new_tokens, skip_special_tokens=True)
            )

            created = int(time.time())
            result = {
                "id": f"chatcmpl-{created}",
                "object": "chat.completion",
                "created": created,
                "model": data.get("model", "yugogpt-eng"),
                "choices": [
                    {
                        "index": 0,
                        "message": {"role": "assistant", "content": response},
                        "finish_reason": "stop",
                    }
                ],
                "usage": {
                    "prompt_tokens": input_length,
                    "completion_tokens": completion_length,
                    "total_tokens": input_length + completion_length,
                },
            }

            if was_truncated:
                result["warning"] = (
                    f"Kontekst je skracen sa {raw_token_count} na {input_length} tokena. "
                    "Rani dio razgovora mozda nedostaje."
                )
            return result

        except Exception:
            print(traceback.format_exc())
            return self._build_error_response("Žao mi je, došlo je do greške. Možete li ponoviti?")

    def _build_error_response(self, message: str) -> Dict[str, Any]:
        """Wrap *message* in an OpenAI ``chat.completion``-shaped dict with zero usage.

        Used for every failure path so clients always receive the same response
        shape.
        """
        return {
            "id": "chatcmpl-error",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": "yugogpt-eng",
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": message},
                    "finish_reason": "stop",
                }
            ],
            "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
        }