File size: 9,319 Bytes
1da0626 1123d56 6526995 1da0626 1123d56 6526995 1123d56 1da0626 1123d56 a86aa15 1da0626 6526995 1da0626 6526995 1da0626 70bd972 1123d56 1da0626 1123d56 6526995 1123d56 1da0626 1123d56 6d3a6db 1123d56 1da0626 a86aa15 1da0626 1123d56 1da0626 70bd972 1da0626 1123d56 a86aa15 1da0626 a86aa15 1da0626 276ec1d 1da0626 6526995 1da0626 1123d56 6526995 1da0626 1123d56 1da0626 1123d56 dbcb334 6526995 1da0626 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 | from typing import Any, Dict, List
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import traceback
import time
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for a Bosnian telesales chat agent.

    Loads a causal LM, renders the chat history into a plain-text
    "Kupac:"/"Agent:" transcript behind a fixed system prompt, generates one
    agent reply and returns it in an OpenAI ``chat.completion``-shaped dict.
    """

    # Fixed base prompt (kept verbatim): the language instruction always comes
    # first and a caller-supplied system message can only be appended to it.
    _BASE_PROMPT = (
        "Ti si profesionalan i brz Telesales agent za Bosnu.\n"
        "Govori ISKLJUČIVO bosanski/srpski/hrvatski jezik. Nikad ne mijenjaj jezik.\n"
        "Budi direktan, ljubazan i prodajno jak.\n"
        "Ne pričaj duge priče. Odmah idi na stvar i pokušaj zatvoriti prodaju."
    )
    # Role labels that open a new dialogue turn. A newline-prefixed marker
    # always contains the bare marker, and the kept text is stripped, so the
    # bare forms are sufficient.
    _STOP_MARKERS = ("Kupac:", "Agent:")
    # Only this many of the most recent chat turns are rendered into the prompt.
    history_window = 10

    def __init__(self, model_dir: str):
        """Load tokenizer and model from ``model_dir`` and set generation defaults.

        Args:
            model_dir: Path (or hub id) passed to ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        # Reuse EOS as padding when the tokenizer defines no pad token, so
        # ``generate`` always receives a valid ``pad_token_id``.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        self.model.eval()
        # Inputs are moved to the device holding the first model parameter.
        self.device = next(self.model.parameters()).device
        self.default_max_new_tokens = 200
        self.context_window = 4096
        self.safety_buffer = 50
        self.default_generation_config = {
            "temperature": 0.68,
            "top_p": 0.87,
            "do_sample": True,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
            "repetition_penalty": 1.15,
        }
        if self.tokenizer.bos_token_id is not None:
            self.default_generation_config["bos_token_id"] = self.tokenizer.bos_token_id

    def _build_prompt(self, messages: List[Dict[str, str]], user_system: str = None) -> str:
        """Render the prompt: base prompt, optional extra instructions, transcript.

        The Bosnian language instruction always comes first and is never
        replaced; ``user_system`` is only appended as additional instructions.
        Only the last ``history_window`` messages are rendered. ``user`` turns
        become "Kupac:" lines and every other role becomes "Agent:". The
        prompt ends with a bare "Agent:" cue for the model to continue.
        """
        parts = [self._BASE_PROMPT]
        if user_system:
            parts.append(f"\n\nDodatne instrukcije: {user_system}")
        parts.append("\n\n")
        for msg in messages[-self.history_window:]:
            speaker = "Kupac" if msg.get("role", "user") == "user" else "Agent"
            parts.append(f"{speaker}: {msg.get('content', '')}\n")
        parts.append("Agent:")
        return "".join(parts)

    def _clean_response(self, text: str) -> str:
        """Cut off hallucinated continuations of the dialogue.

        A leading "Agent:" label (matched case-insensitively) is removed first.
        Then everything from the earliest "Kupac:" or "Agent:" marker onward is
        dropped, because it means the model started writing a new turn.
        """
        text = text.strip()
        if text.lower().startswith("agent:"):
            text = text[len("agent:"):].strip()
        # Cut at the EARLIEST marker of any kind. Checking the markers one by
        # one and stopping at the first one found would keep a marker that
        # occurs earlier in the text but is listed later.
        hits = [pos for pos in (text.find(m) for m in self._STOP_MARKERS) if pos != -1]
        if hits:
            text = text[:min(hits)].strip()
        return text

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Coerce an OpenAI-style message ``content`` value to plain text.

        ``None`` becomes ``""``. A list of content parts (plain strings or
        ``{"text": ...}`` dicts) is joined with newlines, skipping parts that
        carry no text. Any other value is passed through ``str``.
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            texts = []
            for part in content:
                if isinstance(part, str):
                    texts.append(part)
                elif isinstance(part, dict) and isinstance(part.get("text"), str):
                    texts.append(part["text"])
            return "\n".join(texts)
        return str(content)

    def _parse_messages(self, data: Dict[str, Any]) -> tuple:
        """Extract chat turns and the optional system instruction from a request.

        Accepts either ``data["messages"]`` (a list of role/content dicts; non-dict
        entries are skipped) or ``data["inputs"]`` (a string, a list of dicts
        and/or plain values, or any other value treated as one user turn).
        System messages from either form are removed from the turns. Their
        non-blank texts are joined with newlines.

        Returns:
            ``(messages, user_system)``. ``messages`` is a list of
            ``{"role", "content"}`` dicts whose content is always a ``str``.
            ``user_system`` is ``None`` when no system text was given.
        """
        from_messages = isinstance(data.get("messages"), list)
        if from_messages:
            raw_items = data["messages"]
        elif "inputs" in data:
            raw = data["inputs"]
            raw_items = raw if isinstance(raw, list) else [raw]
        else:
            raw_items = []

        messages = []
        system_parts = []
        for item in raw_items:
            if not isinstance(item, dict):
                if from_messages:
                    continue
                item = {"role": "user", "content": item}
            role = item.get("role", "user")
            content = self._content_to_text(item.get("content"))
            if role == "system":
                if content.strip():
                    system_parts.append(content)
            else:
                messages.append({"role": role, "content": content})
        return messages, ("\n".join(system_parts) or None)

    def _resolve_token_budget(self, data: Dict[str, Any]) -> tuple:
        """Return ``(max_new_tokens, max_input_length)`` for this request.

        ``max_tokens`` takes precedence over ``max_new_tokens``. Missing,
        ``None``, non-integer, boolean or non-positive values fall back to the
        default. If the requested output would leave no room for the prompt,
        half the context window goes to the prompt and the output is clamped
        to the rest, so prompt + output never exceeds the window.
        """
        requested = data.get("max_tokens")
        if requested is None:
            requested = data.get("max_new_tokens", self.default_max_new_tokens)
        if isinstance(requested, bool) or not isinstance(requested, int) or requested <= 0:
            requested = self.default_max_new_tokens
        max_input_length = self.context_window - requested - self.safety_buffer
        if max_input_length <= 0:
            max_input_length = self.context_window // 2
            requested = max(1, self.context_window - max_input_length - self.safety_buffer)
        return requested, max_input_length

    def _build_generation_config(self, data: Dict[str, Any], max_new_tokens: int) -> Dict[str, Any]:
        """Merge request sampling parameters with the handler defaults.

        An explicit ``None`` counts as "not given". A temperature <= 0 selects
        greedy decoding (``do_sample=False``). transformers rejects
        non-positive temperatures when sampling, and OpenAI clients send
        ``temperature=0`` to ask for deterministic output.
        """
        defaults = self.default_generation_config

        def param(key: str) -> float:
            value = data.get(key)
            return float(defaults[key] if value is None else value)

        config = {
            "max_new_tokens": max_new_tokens,
            "pad_token_id": defaults["pad_token_id"],
            "eos_token_id": defaults["eos_token_id"],
            "repetition_penalty": param("repetition_penalty"),
        }
        temperature = param("temperature")
        if temperature > 0:
            config.update(do_sample=True, temperature=temperature, top_p=param("top_p"))
        else:
            config["do_sample"] = False
        if "bos_token_id" in defaults:
            config["bos_token_id"] = defaults["bos_token_id"]
        return config

    def _count_tokens(self, prompt: str) -> int:
        """Return how many token ids the tokenizer yields for ``prompt`` (with its default special tokens)."""
        return len(self.tokenizer(prompt)["input_ids"])

    def _encode_prompt(self, messages: List[Dict[str, str]], user_system: str, max_input_length: int) -> tuple:
        """Tokenize the prompt so it fits into ``max_input_length`` tokens.

        The oldest chat turns are dropped first, never the newest one. If the
        newest turn alone is still too long, only the last
        ``max_input_length`` tokens are kept, so the closing "Agent:" cue
        survives. Default right-side truncation would instead cut the newest
        message and the cue.

        Returns:
            ``(input_ids, attention_mask, full_token_count)``. The two tensors
            are already on ``self.device``. ``full_token_count`` is the token
            length of the untrimmed prompt.
        """
        history = messages[-self.history_window:]
        prompt = self._build_prompt(history, user_system)
        full_token_count = self._count_tokens(prompt)
        kept_count = full_token_count
        while kept_count > max_input_length and len(history) > 1:
            history = history[1:]
            prompt = self._build_prompt(history, user_system)
            kept_count = self._count_tokens(prompt)

        encoded = self.tokenizer(prompt, return_tensors="pt", return_attention_mask=True)
        input_ids = encoded["input_ids"]
        attention_mask = encoded["attention_mask"]
        if input_ids.shape[1] > max_input_length:
            input_ids = input_ids[:, -max_input_length:]
            attention_mask = attention_mask[:, -max_input_length:]
        return input_ids.to(self.device), attention_mask.to(self.device), full_token_count

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one chat request and return an OpenAI ``chat.completion`` dict.

        Any exception is logged (traceback printed) and turned into a polite
        error reply in the same response format, so the endpoint never raises.
        If the prompt had to be shortened, a ``"warning"`` key is added.
        """
        try:
            # ---- input parsing ----
            messages, user_system = self._parse_messages(data)
            if not any(m["content"].strip() for m in messages):
                return self._build_error_response("Žao mi je, nisam razumio vašu poruku.")

            # ---- parameters ----
            max_new_tokens, max_input_length = self._resolve_token_budget(data)
            gen_config = self._build_generation_config(data, max_new_tokens)

            # ---- prompt & length check ----
            input_ids, attention_mask, full_token_count = self._encode_prompt(
                messages, user_system, max_input_length
            )
            input_length = input_ids.shape[1]

            # ---- generation ----
            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    **gen_config,
                )
            new_tokens = outputs[0][input_length:]
            completion_length = len(new_tokens)
            raw_response = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
            response = self._clean_response(raw_response)

            # ---- OpenAI-style response (+ truncation warning) ----
            created = int(time.time())
            result = {
                "id": f"chatcmpl-{created}",
                "object": "chat.completion",
                "created": created,
                "model": data.get("model", "yugogpt-eng"),
                "choices": [
                    {
                        "index": 0,
                        "message": {"role": "assistant", "content": response},
                        "finish_reason": "stop",
                    }
                ],
                "usage": {
                    "prompt_tokens": input_length,
                    "completion_tokens": completion_length,
                    "total_tokens": input_length + completion_length,
                },
            }
            if full_token_count > max_input_length:
                result["warning"] = (
                    f"Kontekst je skracen sa {full_token_count} na {input_length} tokena. "
                    "Rani dio razgovora mozda nedostaje."
                )
            return result
        except Exception:
            print(traceback.format_exc())
            return self._build_error_response("Žao mi je, došlo je do greške. Možete li ponoviti?")

    def _build_error_response(self, message: str) -> Dict[str, Any]:
        """Wrap ``message`` in a uniform OpenAI-format error reply with zero usage."""
        return {
            "id": "chatcmpl-error",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": "yugogpt-eng",
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": message},
                    "finish_reason": "stop",
                }
            ],
            "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
        }
|