Upload handler.py
Browse files- handler.py +170 -93
handler.py
CHANGED
|
@@ -21,9 +21,8 @@ import importlib.util
|
|
| 21 |
from dataclasses import dataclass
|
| 22 |
from pathlib import Path
|
| 23 |
import hashlib
|
| 24 |
-
import random
|
| 25 |
import re
|
| 26 |
-
from typing import Any, Mapping, MutableMapping, Optional, Sequence
|
| 27 |
|
| 28 |
import numpy as np
|
| 29 |
|
|
@@ -380,61 +379,133 @@ class EndpointHandler:
|
|
| 380 |
|
| 381 |
return self._coerce_array(value, node=node)
|
| 382 |
|
| 383 |
-
|
| 384 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 385 |
if self._token_sequence_length <= 0:
|
| 386 |
-
|
| 387 |
-
return array, token_ids
|
| 388 |
|
| 389 |
-
length = min(len(token_ids), self._token_sequence_length)
|
| 390 |
padded = np.full(
|
| 391 |
(1, self._token_sequence_length),
|
| 392 |
fill_value=self._tokenizer.pad_token_id,
|
| 393 |
dtype=self._token_dtype,
|
| 394 |
)
|
| 395 |
-
|
| 396 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 397 |
|
| 398 |
@staticmethod
|
| 399 |
-
def
|
| 400 |
-
|
| 401 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 402 |
|
| 403 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 404 |
self,
|
| 405 |
-
|
| 406 |
*,
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
]
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 438 |
|
| 439 |
@staticmethod
|
| 440 |
def _extract_q_hat(outputs: Sequence[tuple[Any, np.ndarray]]) -> float:
|
|
@@ -470,66 +541,72 @@ class EndpointHandler:
|
|
| 470 |
system_prompt = payload.get("system_prompt")
|
| 471 |
user_prompt = payload.get("user_prompt")
|
| 472 |
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
user_prompt=user_prompt if isinstance(user_prompt, str) else None,
|
| 476 |
-
system_prompt=system_prompt if isinstance(system_prompt, str) else None,
|
| 477 |
-
constraints=state_constraints,
|
| 478 |
-
)
|
| 479 |
-
|
| 480 |
-
best_text: str | None = None
|
| 481 |
-
best_tokens: list[int] = []
|
| 482 |
-
best_outputs: list[tuple[Any, np.ndarray]] | None = None
|
| 483 |
-
best_quality = float("-inf")
|
| 484 |
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
if self.tokens_input is None:
|
| 488 |
-
break
|
| 489 |
-
token_array, token_ids = self._encode_tokens(candidate)
|
| 490 |
-
outputs = self._run_candidate(feed, token_array)
|
| 491 |
-
quality = self._extract_q_hat(outputs)
|
| 492 |
-
if quality > best_quality:
|
| 493 |
-
best_quality = quality
|
| 494 |
-
best_text = candidate
|
| 495 |
-
best_tokens = token_ids
|
| 496 |
-
best_outputs = outputs
|
| 497 |
-
if quality >= decoding.stop_quality:
|
| 498 |
-
break
|
| 499 |
|
| 500 |
-
if
|
| 501 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 502 |
outputs = self.session.run(None, feed)
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
|
|
|
|
|
|
| 508 |
|
| 509 |
formatted = {
|
| 510 |
node.name: self._format_output(node.name, value)
|
| 511 |
-
for node, value in
|
| 512 |
}
|
| 513 |
|
| 514 |
-
if not np.isfinite(
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 519 |
|
| 520 |
response = {
|
| 521 |
-
"text":
|
| 522 |
-
"tokens":
|
| 523 |
-
"quality":
|
| 524 |
-
"q_hat":
|
| 525 |
"provider": _DEFAULT_PROVIDER,
|
| 526 |
"model": _DEFAULT_MODEL,
|
| 527 |
-
"metadata":
|
| 528 |
-
"summary": summary,
|
| 529 |
-
"descriptors": descriptors,
|
| 530 |
-
"constraints": state_constraints or {},
|
| 531 |
-
"decoding": decoding.to_dict(),
|
| 532 |
-
},
|
| 533 |
}
|
| 534 |
response.update(formatted)
|
| 535 |
return response
|
|
|
|
| 21 |
from dataclasses import dataclass
|
| 22 |
from pathlib import Path
|
| 23 |
import hashlib
|
|
|
|
| 24 |
import re
|
| 25 |
+
from typing import Any, Mapping, MutableMapping, Optional, Sequence, Tuple
|
| 26 |
|
| 27 |
import numpy as np
|
| 28 |
|
|
|
|
| 379 |
|
| 380 |
return self._coerce_array(value, node=node)
|
| 381 |
|
| 382 |
+
@staticmethod
|
| 383 |
+
def _candidate_seed(psi: np.ndarray) -> int:
|
| 384 |
+
digest = hashlib.sha1(psi.tobytes()).digest()
|
| 385 |
+
return int.from_bytes(digest[:4], "little", signed=False)
|
| 386 |
+
|
| 387 |
+
def _token_array_from_ids(self, token_ids: Sequence[int]) -> np.ndarray:
|
| 388 |
+
ids = list(token_ids)
|
| 389 |
if self._token_sequence_length <= 0:
|
| 390 |
+
return np.asarray([ids], dtype=self._token_dtype)
|
|
|
|
| 391 |
|
|
|
|
| 392 |
padded = np.full(
|
| 393 |
(1, self._token_sequence_length),
|
| 394 |
fill_value=self._tokenizer.pad_token_id,
|
| 395 |
dtype=self._token_dtype,
|
| 396 |
)
|
| 397 |
+
length = min(len(ids), self._token_sequence_length)
|
| 398 |
+
if length > 0:
|
| 399 |
+
padded[0, :length] = np.asarray(ids[:length], dtype=self._token_dtype)
|
| 400 |
+
return padded
|
| 401 |
+
|
| 402 |
+
def _run_candidate(self, base_feed: Mapping[str, np.ndarray], tokens: Sequence[int]) -> list[tuple[Any, np.ndarray]]:
|
| 403 |
+
feed = {
|
| 404 |
+
name: (value.copy() if isinstance(value, np.ndarray) else value)
|
| 405 |
+
for name, value in base_feed.items()
|
| 406 |
+
}
|
| 407 |
+
if self.tokens_input is not None:
|
| 408 |
+
feed[self.tokens_input] = self._token_array_from_ids(tokens)
|
| 409 |
+
outputs = self.session.run(None, feed)
|
| 410 |
+
return list(zip(self.io.outputs, outputs))
|
| 411 |
|
| 412 |
@staticmethod
|
| 413 |
+
def _extract_logits(outputs: Sequence[tuple[Any, np.ndarray]]) -> Optional[np.ndarray]:
|
| 414 |
+
for node, value in outputs:
|
| 415 |
+
if getattr(node, "name", "").lower() == "logits":
|
| 416 |
+
return np.asarray(value, dtype=np.float32)
|
| 417 |
+
if outputs:
|
| 418 |
+
return np.asarray(outputs[0][1], dtype=np.float32)
|
| 419 |
+
return None
|
| 420 |
|
| 421 |
+
@staticmethod
|
| 422 |
+
def _sample_next_token(
|
| 423 |
+
logits: np.ndarray,
|
| 424 |
+
decoding: _DecodingParams,
|
| 425 |
+
rng: np.random.Generator,
|
| 426 |
+
) -> int:
|
| 427 |
+
vector = np.asarray(logits, dtype=np.float64).reshape(-1)
|
| 428 |
+
temperature = max(float(decoding.temperature), 1e-5)
|
| 429 |
+
top_p = float(decoding.top_p)
|
| 430 |
+
|
| 431 |
+
if temperature <= 1e-5 or not np.isfinite(vector).any():
|
| 432 |
+
return int(int(np.argmax(vector)))
|
| 433 |
+
|
| 434 |
+
stabilized = vector / temperature
|
| 435 |
+
stabilized -= np.max(stabilized)
|
| 436 |
+
probs = np.exp(stabilized)
|
| 437 |
+
probs = np.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
|
| 438 |
+
total = probs.sum()
|
| 439 |
+
if total <= 0.0:
|
| 440 |
+
return int(np.argmax(vector))
|
| 441 |
+
probs /= total
|
| 442 |
+
|
| 443 |
+
if top_p <= 0.0:
|
| 444 |
+
return int(np.argmax(probs))
|
| 445 |
+
|
| 446 |
+
if 0.0 < top_p < 1.0:
|
| 447 |
+
sorted_indices = np.argsort(-probs)
|
| 448 |
+
sorted_probs = probs[sorted_indices]
|
| 449 |
+
cumulative = np.cumsum(sorted_probs)
|
| 450 |
+
mask = cumulative <= top_p
|
| 451 |
+
if mask.size > 0:
|
| 452 |
+
mask[0] = True
|
| 453 |
+
filtered_indices = sorted_indices[mask]
|
| 454 |
+
filtered_probs = sorted_probs[mask]
|
| 455 |
+
filtered_total = filtered_probs.sum()
|
| 456 |
+
if filtered_total <= 0.0:
|
| 457 |
+
filtered_indices = sorted_indices
|
| 458 |
+
filtered_probs = sorted_probs
|
| 459 |
+
filtered_total = filtered_probs.sum()
|
| 460 |
+
filtered_probs = filtered_probs / filtered_total
|
| 461 |
+
choice = rng.choice(len(filtered_indices), p=filtered_probs)
|
| 462 |
+
return int(filtered_indices[int(choice)])
|
| 463 |
+
|
| 464 |
+
choice = rng.choice(len(probs), p=probs)
|
| 465 |
+
return int(choice)
|
| 466 |
+
|
| 467 |
+
def _generate_sequence(
|
| 468 |
self,
|
| 469 |
+
base_feed: Mapping[str, np.ndarray],
|
| 470 |
*,
|
| 471 |
+
decoding: _DecodingParams,
|
| 472 |
+
seed: int,
|
| 473 |
+
) -> Optional[Tuple[str, list[int], float, list[tuple[Any, np.ndarray]], int]]:
|
| 474 |
+
if self.tokens_input is None:
|
| 475 |
+
return None
|
| 476 |
+
|
| 477 |
+
rng = np.random.default_rng(seed)
|
| 478 |
+
token_ids: list[int] = [self._tokenizer.bos_token_id]
|
| 479 |
+
quality = float("-inf")
|
| 480 |
+
formatted_outputs: list[tuple[Any, np.ndarray]] | None = None
|
| 481 |
+
steps = 0
|
| 482 |
+
|
| 483 |
+
max_steps = max(decoding.max_new_tokens, 1)
|
| 484 |
+
for _ in range(max_steps):
|
| 485 |
+
outputs = self._run_candidate(base_feed, token_ids)
|
| 486 |
+
logits = self._extract_logits(outputs)
|
| 487 |
+
if logits is None:
|
| 488 |
+
break
|
| 489 |
+
last_index = min(len(token_ids) - 1, logits.shape[1] - 1)
|
| 490 |
+
next_logits = logits[0, last_index]
|
| 491 |
+
next_token = self._sample_next_token(next_logits, decoding, rng)
|
| 492 |
+
token_ids.append(int(next_token))
|
| 493 |
+
steps += 1
|
| 494 |
+
|
| 495 |
+
outputs = self._run_candidate(base_feed, token_ids)
|
| 496 |
+
formatted_outputs = outputs
|
| 497 |
+
quality = self._extract_q_hat(outputs)
|
| 498 |
+
|
| 499 |
+
if token_ids[-1] == self._tokenizer.eos_token_id:
|
| 500 |
+
break
|
| 501 |
+
if self._token_sequence_length > 0 and len(token_ids) >= self._token_sequence_length:
|
| 502 |
+
break
|
| 503 |
+
|
| 504 |
+
if formatted_outputs is None:
|
| 505 |
+
return None
|
| 506 |
+
|
| 507 |
+
text = self._tokenizer.decode(token_ids)
|
| 508 |
+
return text, token_ids, float(quality), formatted_outputs, steps
|
| 509 |
|
| 510 |
@staticmethod
|
| 511 |
def _extract_q_hat(outputs: Sequence[tuple[Any, np.ndarray]]) -> float:
|
|
|
|
| 541 |
system_prompt = payload.get("system_prompt")
|
| 542 |
user_prompt = payload.get("user_prompt")
|
| 543 |
|
| 544 |
+
descriptors = _summarise_intent(psi_vector)
|
| 545 |
+
summary = ", ".join(descriptors) if descriptors else "balanced intent"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 546 |
|
| 547 |
+
best_candidate: Optional[Tuple[str, list[int], float, list[tuple[Any, np.ndarray]], int]] = None
|
| 548 |
+
seeds: list[int] = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 549 |
|
| 550 |
+
if self.tokens_input is not None:
|
| 551 |
+
beams = max(decoding.beam_size, 1)
|
| 552 |
+
base_seed = self._candidate_seed(psi_vector)
|
| 553 |
+
for beam_idx in range(beams):
|
| 554 |
+
seed = base_seed + beam_idx
|
| 555 |
+
seeds.append(seed)
|
| 556 |
+
candidate = self._generate_sequence(
|
| 557 |
+
feed,
|
| 558 |
+
decoding=decoding,
|
| 559 |
+
seed=seed,
|
| 560 |
+
)
|
| 561 |
+
if candidate is None:
|
| 562 |
+
continue
|
| 563 |
+
text, token_ids, quality, outputs, steps = candidate
|
| 564 |
+
if (
|
| 565 |
+
best_candidate is None
|
| 566 |
+
or quality > best_candidate[2]
|
| 567 |
+
):
|
| 568 |
+
best_candidate = candidate
|
| 569 |
+
if quality >= decoding.stop_quality:
|
| 570 |
+
break
|
| 571 |
+
|
| 572 |
+
if best_candidate is None:
|
| 573 |
outputs = self.session.run(None, feed)
|
| 574 |
+
formatted_outputs = list(zip(self.io.outputs, outputs))
|
| 575 |
+
quality = self._extract_q_hat(formatted_outputs)
|
| 576 |
+
text = f"Symbolic synopsis → {summary}."
|
| 577 |
+
token_ids: list[int] = []
|
| 578 |
+
steps = 0
|
| 579 |
+
else:
|
| 580 |
+
text, token_ids, quality, formatted_outputs, steps = best_candidate
|
| 581 |
|
| 582 |
formatted = {
|
| 583 |
node.name: self._format_output(node.name, value)
|
| 584 |
+
for node, value in formatted_outputs
|
| 585 |
}
|
| 586 |
|
| 587 |
+
if not np.isfinite(quality):
|
| 588 |
+
quality = 0.0
|
| 589 |
+
quality = float(quality)
|
| 590 |
+
|
| 591 |
+
metadata = {
|
| 592 |
+
"summary": summary,
|
| 593 |
+
"descriptors": descriptors,
|
| 594 |
+
"constraints": state_constraints or {},
|
| 595 |
+
"decoding": decoding.to_dict(),
|
| 596 |
+
"seeds": seeds,
|
| 597 |
+
"steps": steps,
|
| 598 |
+
"system_prompt": system_prompt if isinstance(system_prompt, str) else None,
|
| 599 |
+
"user_prompt": user_prompt if isinstance(user_prompt, str) else None,
|
| 600 |
+
}
|
| 601 |
|
| 602 |
response = {
|
| 603 |
+
"text": text,
|
| 604 |
+
"tokens": token_ids,
|
| 605 |
+
"quality": quality,
|
| 606 |
+
"q_hat": quality,
|
| 607 |
"provider": _DEFAULT_PROVIDER,
|
| 608 |
"model": _DEFAULT_MODEL,
|
| 609 |
+
"metadata": metadata,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 610 |
}
|
| 611 |
response.update(formatted)
|
| 612 |
return response
|