gnai-creator committed on
Commit
e9b5ee1
·
verified ·
1 Parent(s): 6a09de4

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +170 -93
handler.py CHANGED
@@ -21,9 +21,8 @@ import importlib.util
21
  from dataclasses import dataclass
22
  from pathlib import Path
23
  import hashlib
24
- import random
25
  import re
26
- from typing import Any, Mapping, MutableMapping, Optional, Sequence
27
 
28
  import numpy as np
29
 
@@ -380,61 +379,133 @@ class EndpointHandler:
380
 
381
  return self._coerce_array(value, node=node)
382
 
383
- def _encode_tokens(self, text: str) -> tuple[np.ndarray, list[int]]:
384
- token_ids = self._tokenizer.encode(text)
 
 
 
 
 
385
  if self._token_sequence_length <= 0:
386
- array = np.asarray([token_ids], dtype=self._token_dtype)
387
- return array, token_ids
388
 
389
- length = min(len(token_ids), self._token_sequence_length)
390
  padded = np.full(
391
  (1, self._token_sequence_length),
392
  fill_value=self._tokenizer.pad_token_id,
393
  dtype=self._token_dtype,
394
  )
395
- padded[0, :length] = np.asarray(token_ids[:length], dtype=self._token_dtype)
396
- return padded, token_ids[:length]
 
 
 
 
 
 
 
 
 
 
 
 
397
 
398
  @staticmethod
399
- def _candidate_seed(psi: np.ndarray) -> int:
400
- digest = hashlib.sha1(psi.tobytes()).digest()
401
- return int.from_bytes(digest[:4], "little", signed=False)
 
 
 
 
402
 
403
- def _build_candidates(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
404
  self,
405
- psi_vector: np.ndarray,
406
  *,
407
- user_prompt: str | None,
408
- system_prompt: str | None,
409
- constraints: Mapping[str, Any] | None,
410
- ) -> tuple[list[str], str, list[str]]:
411
- descriptors = _summarise_intent(psi_vector)
412
- summary = ", ".join(descriptors) if descriptors else "balanced intent"
413
- observations = [
414
- f"Interpretation: the symbolic intent emphasises {summary}.",
415
- f"Symbolic synopsis → {summary}.",
416
- ]
417
- if user_prompt:
418
- observations.append(f"{user_prompt.strip()}\nInsight: {summary}.")
419
- if system_prompt:
420
- observations.append(f"{system_prompt.strip()}\nDirective: honour {summary}.")
421
- if constraints:
422
- formatted = ", ".join(f"{key}={value}" for key, value in constraints.items())
423
- observations.append(f"Constraints observed: {formatted}.")
424
-
425
- seed = self._candidate_seed(psi_vector.astype(np.float32, copy=False))
426
- rng = random.Random(seed)
427
- rng.shuffle(observations)
428
- if not observations:
429
- observations = [f"Symbolic synopsis → {summary}."]
430
- return observations, summary, descriptors
431
-
432
- def _run_candidate(self, base_feed: Mapping[str, np.ndarray], tokens: np.ndarray) -> list[tuple[Any, np.ndarray]]:
433
- feed = {name: value for name, value in base_feed.items()}
434
- if self.tokens_input is not None:
435
- feed[self.tokens_input] = tokens
436
- outputs = self.session.run(None, feed)
437
- return list(zip(self.io.outputs, outputs))
 
 
 
 
 
 
 
438
 
439
  @staticmethod
440
  def _extract_q_hat(outputs: Sequence[tuple[Any, np.ndarray]]) -> float:
@@ -470,66 +541,72 @@ class EndpointHandler:
470
  system_prompt = payload.get("system_prompt")
471
  user_prompt = payload.get("user_prompt")
472
 
473
- candidates, summary, descriptors = self._build_candidates(
474
- psi_vector,
475
- user_prompt=user_prompt if isinstance(user_prompt, str) else None,
476
- system_prompt=system_prompt if isinstance(system_prompt, str) else None,
477
- constraints=state_constraints,
478
- )
479
-
480
- best_text: str | None = None
481
- best_tokens: list[int] = []
482
- best_outputs: list[tuple[Any, np.ndarray]] | None = None
483
- best_quality = float("-inf")
484
 
485
- limit = min(len(candidates), max(decoding.beam_size, 1))
486
- for candidate in candidates[:limit]:
487
- if self.tokens_input is None:
488
- break
489
- token_array, token_ids = self._encode_tokens(candidate)
490
- outputs = self._run_candidate(feed, token_array)
491
- quality = self._extract_q_hat(outputs)
492
- if quality > best_quality:
493
- best_quality = quality
494
- best_text = candidate
495
- best_tokens = token_ids
496
- best_outputs = outputs
497
- if quality >= decoding.stop_quality:
498
- break
499
 
500
- if best_outputs is None:
501
- # Fall back to a single pass using the prepared feed.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
502
  outputs = self.session.run(None, feed)
503
- best_outputs = list(zip(self.io.outputs, outputs))
504
- if best_text is None:
505
- best_text = f"Symbolic synopsis → {summary}."
506
- if best_quality == float("-inf"):
507
- best_quality = self._extract_q_hat(best_outputs)
 
 
508
 
509
  formatted = {
510
  node.name: self._format_output(node.name, value)
511
- for node, value in best_outputs
512
  }
513
 
514
- if not np.isfinite(best_quality):
515
- best_quality = 0.0
516
- best_quality = float(best_quality)
517
- if best_text is None:
518
- best_text = f"Symbolic synopsis → {summary}."
 
 
 
 
 
 
 
 
 
519
 
520
  response = {
521
- "text": best_text,
522
- "tokens": best_tokens,
523
- "quality": best_quality,
524
- "q_hat": best_quality,
525
  "provider": _DEFAULT_PROVIDER,
526
  "model": _DEFAULT_MODEL,
527
- "metadata": {
528
- "summary": summary,
529
- "descriptors": descriptors,
530
- "constraints": state_constraints or {},
531
- "decoding": decoding.to_dict(),
532
- },
533
  }
534
  response.update(formatted)
535
  return response
 
21
  from dataclasses import dataclass
22
  from pathlib import Path
23
  import hashlib
 
24
  import re
25
+ from typing import Any, Mapping, MutableMapping, Optional, Sequence, Tuple
26
 
27
  import numpy as np
28
 
 
379
 
380
  return self._coerce_array(value, node=node)
381
 
382
+ @staticmethod
383
+ def _candidate_seed(psi: np.ndarray) -> int:
384
+ digest = hashlib.sha1(psi.tobytes()).digest()
385
+ return int.from_bytes(digest[:4], "little", signed=False)
386
+
387
+ def _token_array_from_ids(self, token_ids: Sequence[int]) -> np.ndarray:
388
+ ids = list(token_ids)
389
  if self._token_sequence_length <= 0:
390
+ return np.asarray([ids], dtype=self._token_dtype)
 
391
 
 
392
  padded = np.full(
393
  (1, self._token_sequence_length),
394
  fill_value=self._tokenizer.pad_token_id,
395
  dtype=self._token_dtype,
396
  )
397
+ length = min(len(ids), self._token_sequence_length)
398
+ if length > 0:
399
+ padded[0, :length] = np.asarray(ids[:length], dtype=self._token_dtype)
400
+ return padded
401
+
402
+ def _run_candidate(self, base_feed: Mapping[str, np.ndarray], tokens: Sequence[int]) -> list[tuple[Any, np.ndarray]]:
403
+ feed = {
404
+ name: (value.copy() if isinstance(value, np.ndarray) else value)
405
+ for name, value in base_feed.items()
406
+ }
407
+ if self.tokens_input is not None:
408
+ feed[self.tokens_input] = self._token_array_from_ids(tokens)
409
+ outputs = self.session.run(None, feed)
410
+ return list(zip(self.io.outputs, outputs))
411
 
412
  @staticmethod
413
+ def _extract_logits(outputs: Sequence[tuple[Any, np.ndarray]]) -> Optional[np.ndarray]:
414
+ for node, value in outputs:
415
+ if getattr(node, "name", "").lower() == "logits":
416
+ return np.asarray(value, dtype=np.float32)
417
+ if outputs:
418
+ return np.asarray(outputs[0][1], dtype=np.float32)
419
+ return None
420
 
421
+ @staticmethod
422
+ def _sample_next_token(
423
+ logits: np.ndarray,
424
+ decoding: _DecodingParams,
425
+ rng: np.random.Generator,
426
+ ) -> int:
427
+ vector = np.asarray(logits, dtype=np.float64).reshape(-1)
428
+ temperature = max(float(decoding.temperature), 1e-5)
429
+ top_p = float(decoding.top_p)
430
+
431
+ if temperature <= 1e-5 or not np.isfinite(vector).any():
432
+ return int(int(np.argmax(vector)))
433
+
434
+ stabilized = vector / temperature
435
+ stabilized -= np.max(stabilized)
436
+ probs = np.exp(stabilized)
437
+ probs = np.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
438
+ total = probs.sum()
439
+ if total <= 0.0:
440
+ return int(np.argmax(vector))
441
+ probs /= total
442
+
443
+ if top_p <= 0.0:
444
+ return int(np.argmax(probs))
445
+
446
+ if 0.0 < top_p < 1.0:
447
+ sorted_indices = np.argsort(-probs)
448
+ sorted_probs = probs[sorted_indices]
449
+ cumulative = np.cumsum(sorted_probs)
450
+ mask = cumulative <= top_p
451
+ if mask.size > 0:
452
+ mask[0] = True
453
+ filtered_indices = sorted_indices[mask]
454
+ filtered_probs = sorted_probs[mask]
455
+ filtered_total = filtered_probs.sum()
456
+ if filtered_total <= 0.0:
457
+ filtered_indices = sorted_indices
458
+ filtered_probs = sorted_probs
459
+ filtered_total = filtered_probs.sum()
460
+ filtered_probs = filtered_probs / filtered_total
461
+ choice = rng.choice(len(filtered_indices), p=filtered_probs)
462
+ return int(filtered_indices[int(choice)])
463
+
464
+ choice = rng.choice(len(probs), p=probs)
465
+ return int(choice)
466
+
467
+ def _generate_sequence(
468
  self,
469
+ base_feed: Mapping[str, np.ndarray],
470
  *,
471
+ decoding: _DecodingParams,
472
+ seed: int,
473
+ ) -> Optional[Tuple[str, list[int], float, list[tuple[Any, np.ndarray]], int]]:
474
+ if self.tokens_input is None:
475
+ return None
476
+
477
+ rng = np.random.default_rng(seed)
478
+ token_ids: list[int] = [self._tokenizer.bos_token_id]
479
+ quality = float("-inf")
480
+ formatted_outputs: list[tuple[Any, np.ndarray]] | None = None
481
+ steps = 0
482
+
483
+ max_steps = max(decoding.max_new_tokens, 1)
484
+ for _ in range(max_steps):
485
+ outputs = self._run_candidate(base_feed, token_ids)
486
+ logits = self._extract_logits(outputs)
487
+ if logits is None:
488
+ break
489
+ last_index = min(len(token_ids) - 1, logits.shape[1] - 1)
490
+ next_logits = logits[0, last_index]
491
+ next_token = self._sample_next_token(next_logits, decoding, rng)
492
+ token_ids.append(int(next_token))
493
+ steps += 1
494
+
495
+ outputs = self._run_candidate(base_feed, token_ids)
496
+ formatted_outputs = outputs
497
+ quality = self._extract_q_hat(outputs)
498
+
499
+ if token_ids[-1] == self._tokenizer.eos_token_id:
500
+ break
501
+ if self._token_sequence_length > 0 and len(token_ids) >= self._token_sequence_length:
502
+ break
503
+
504
+ if formatted_outputs is None:
505
+ return None
506
+
507
+ text = self._tokenizer.decode(token_ids)
508
+ return text, token_ids, float(quality), formatted_outputs, steps
509
 
510
  @staticmethod
511
  def _extract_q_hat(outputs: Sequence[tuple[Any, np.ndarray]]) -> float:
 
541
  system_prompt = payload.get("system_prompt")
542
  user_prompt = payload.get("user_prompt")
543
 
544
+ descriptors = _summarise_intent(psi_vector)
545
+ summary = ", ".join(descriptors) if descriptors else "balanced intent"
 
 
 
 
 
 
 
 
 
546
 
547
+ best_candidate: Optional[Tuple[str, list[int], float, list[tuple[Any, np.ndarray]], int]] = None
548
+ seeds: list[int] = []
 
 
 
 
 
 
 
 
 
 
 
 
549
 
550
+ if self.tokens_input is not None:
551
+ beams = max(decoding.beam_size, 1)
552
+ base_seed = self._candidate_seed(psi_vector)
553
+ for beam_idx in range(beams):
554
+ seed = base_seed + beam_idx
555
+ seeds.append(seed)
556
+ candidate = self._generate_sequence(
557
+ feed,
558
+ decoding=decoding,
559
+ seed=seed,
560
+ )
561
+ if candidate is None:
562
+ continue
563
+ text, token_ids, quality, outputs, steps = candidate
564
+ if (
565
+ best_candidate is None
566
+ or quality > best_candidate[2]
567
+ ):
568
+ best_candidate = candidate
569
+ if quality >= decoding.stop_quality:
570
+ break
571
+
572
+ if best_candidate is None:
573
  outputs = self.session.run(None, feed)
574
+ formatted_outputs = list(zip(self.io.outputs, outputs))
575
+ quality = self._extract_q_hat(formatted_outputs)
576
+ text = f"Symbolic synopsis → {summary}."
577
+ token_ids: list[int] = []
578
+ steps = 0
579
+ else:
580
+ text, token_ids, quality, formatted_outputs, steps = best_candidate
581
 
582
  formatted = {
583
  node.name: self._format_output(node.name, value)
584
+ for node, value in formatted_outputs
585
  }
586
 
587
+ if not np.isfinite(quality):
588
+ quality = 0.0
589
+ quality = float(quality)
590
+
591
+ metadata = {
592
+ "summary": summary,
593
+ "descriptors": descriptors,
594
+ "constraints": state_constraints or {},
595
+ "decoding": decoding.to_dict(),
596
+ "seeds": seeds,
597
+ "steps": steps,
598
+ "system_prompt": system_prompt if isinstance(system_prompt, str) else None,
599
+ "user_prompt": user_prompt if isinstance(user_prompt, str) else None,
600
+ }
601
 
602
  response = {
603
+ "text": text,
604
+ "tokens": token_ids,
605
+ "quality": quality,
606
+ "q_hat": quality,
607
  "provider": _DEFAULT_PROVIDER,
608
  "model": _DEFAULT_MODEL,
609
+ "metadata": metadata,
 
 
 
 
 
610
  }
611
  response.update(formatted)
612
  return response