| import io | |
| import time | |
| import contextlib | |
| from pathlib import Path | |
| import sys | |
| import torch | |
# Make the directory one level above this script's folder importable, so the
# project-local imports that follow resolve when the script is run directly.
# (Presumably this is the repository root holding progressive_scaleup,
# unified_workflow and bit_transformer — confirm.) This must execute before
# those imports.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    # Prepend so modules under ROOT are found before same-named ones elsewhere
    # on sys.path.
    sys.path.insert(0, str(ROOT))
| from progressive_scaleup import progressive_scale_up_text | |
| from unified_workflow import run_workflow | |
| from bit_transformer.bit_io import text_to_bits | |
| from bit_transformer.safety import hil_safe_inference | |
def capture_run(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)`` while capturing what it prints to stdout.

    Returns:
        A ``(result, stdout_text, duration_seconds)`` tuple, where ``result``
        is ``func``'s return value, ``stdout_text`` is everything written to
        ``sys.stdout`` during the call, and ``duration_seconds`` is the
        elapsed time of the call in seconds.

    Any exception raised by ``func`` propagates unchanged; ``sys.stdout`` is
    restored by ``redirect_stdout`` in either case (the captured text is
    discarded on error).
    """
    buf = io.StringIO()
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # wall clock and can jump (even go negative) if the clock is adjusted.
    start = time.perf_counter()
    with contextlib.redirect_stdout(buf):
        result = func(*args, **kwargs)
    duration = time.perf_counter() - start
    return result, buf.getvalue(), duration
def _append_section(
    summary: list[str], title: str, log: str, dur: float, *extra: str
) -> None:
    """Append one report section: heading, captured log, extra lines, duration."""
    summary.append(f"### {title}\n")
    summary.append(log.strip())
    summary.extend(extra)
    summary.append(f"Duration: {dur:.2f}s\n")


def _run_unified_workflow(diffusion: bool):
    """Run the small unified-workflow demo configuration with stdout captured.

    Returns the ``(result, log, duration)`` tuple from ``capture_run``.
    """
    return capture_run(
        run_workflow,
        steps=2,
        max_len=32,
        dataset_size=32,
        plateau_steps=1,
        epochs_per_step=1,
        extra_steps=1,
        diffusion=diffusion,
    )


def main() -> None:
    """Run the scale-up and unified-workflow demos and print a markdown report.

    Each run's stdout is captured with ``capture_run`` and reported under its
    own ``###`` heading together with its duration. The first unified-workflow
    model is also run on the prompt ``"hi"`` and its output bits are reported.
    """
    summary: list[str] = []

    # Progressive scale-up, once with forward_kwargs causal=True, once False.
    for causal in (True, False):
        _, log, dur = capture_run(
            progressive_scale_up_text,
            improve_thresh=0.01,
            steps=10,
            width_mult=2.0,
            max_len=64,
            dataset_size=512,
            forward_kwargs={"causal": causal},
        )
        _append_section(summary, f"Progressive Scale-Up (causal={causal})", log, dur)

    # Unified workflow without diffusion; keep the model for a quick inference.
    (model, _), log, dur = _run_unified_workflow(diffusion=False)
    bits = text_to_bits("hi")
    # unsqueeze(0) adds a leading batch dimension of size 1.
    prompt = torch.tensor(bits, dtype=torch.long).unsqueeze(0)
    # NOTE(review): zero floors presumably disable the safety thresholds so the
    # demo always yields output — confirm against hil_safe_inference.
    out_bits, _ = hil_safe_inference(model, prompt, c_floor=0.0, s_floor=0.0)
    _append_section(
        summary,
        "Unified Workflow (causal=True)",
        log,
        dur,
        f"Inference on 'hi': {out_bits.squeeze(0).tolist()}\n",
    )

    # Unified workflow with diffusion enabled; only the log is reported.
    (_, _), log, dur = _run_unified_workflow(diffusion=True)
    _append_section(summary, "Unified Workflow (causal=False / Diffusion)", log, dur)

    print("\n".join(summary))
# Build and print the report only when executed as a script, not on import.
if __name__ == "__main__":
    main()