"""
ai_client.py
Thin wrapper around the Anthropic API with chunked processing and streaming.
Author: algorembrant
"""
from __future__ import annotations
import sys
from typing import Iterator, Optional
import anthropic
from config import DEFAULT_MODEL, MAX_TOKENS, CHUNK_SIZE
# ---------------------------------------------------------------------------
# Module-level client (lazy init, reused across calls)
# ---------------------------------------------------------------------------
# Shared client instance; created on first use and reused afterwards.
_client: Optional[anthropic.Anthropic] = None


def _get_client() -> anthropic.Anthropic:
    """Return the module-wide Anthropic client, constructing it lazily."""
    global _client
    if _client is not None:
        return _client
    _client = anthropic.Anthropic()
    return _client
# ---------------------------------------------------------------------------
# Core helpers
# ---------------------------------------------------------------------------
def complete(
    system: str,
    user: str,
    model: str = DEFAULT_MODEL,
    max_tokens: int = MAX_TOKENS,
    stream: bool = True,
) -> str:
    """
    Run a single completion and return the full response text.

    When `stream` is true, tokens are echoed to stderr as they arrive so
    the user sees progress; the complete text is returned either way.
    """
    client = _get_client()
    # Same payload for both the streaming and the one-shot call.
    request = dict(
        model=model,
        max_tokens=max_tokens,
        system=system,
        messages=[{"role": "user", "content": user}],
    )
    if not stream:
        response = client.messages.create(**request)
        return response.content[0].text
    collected: list[str] = []
    with client.messages.stream(**request) as live:
        for piece in live.text_stream:
            print(piece, end="", flush=True, file=sys.stderr)
            collected.append(piece)
    print(file=sys.stderr)  # newline after stream
    return "".join(collected)
def _split_into_chunks(text: str, chunk_size: int = CHUNK_SIZE) -> list[str]:
    """
    Split `text` into chunks of at most `chunk_size` characters.

    Prefers to break at a paragraph boundary ("\\n\\n"), then at a sentence
    boundary (". "), then at any whitespace; only when none of those exist
    within the window does it hard-split at exactly `chunk_size` characters.

    Args:
        text: The text to split.
        chunk_size: Maximum length of each chunk, in characters.

    Returns:
        A list of non-empty chunks whose concatenation (minus the single
        delimiter character consumed at each soft break) covers `text`.
    """
    if len(text) <= chunk_size:
        return [text]
    chunks: list[str] = []
    start = 0
    length = len(text)
    while start < length:
        end = start + chunk_size
        if end >= length:
            chunks.append(text[start:])
            break
        # Prefer a paragraph boundary, then a sentence, then any space.
        split_at = text.rfind("\n\n", start, end)
        if split_at == -1:
            split_at = text.rfind(". ", start, end)
        if split_at == -1:
            split_at = text.rfind(" ", start, end)
        if split_at == -1:
            # Hard split. BUG FIX: the original used `split_at = end` followed
            # by `text[start : split_at + 1]`, emitting chunk_size + 1
            # characters and violating the "at most chunk_size" contract.
            chunks.append(text[start:end])
            start = end
        else:
            # Include the boundary character in the current chunk.
            chunks.append(text[start : split_at + 1])
            start = split_at + 1
    return chunks
def complete_long(
    system: str,
    user_prefix: str,
    text: str,
    user_suffix: str = "",
    model: str = DEFAULT_MODEL,
    max_tokens: int = MAX_TOKENS,
    merge_system: Optional[str] = None,
    stream: bool = True,
) -> str:
    """
    Process a potentially long text by splitting it into chunks,
    running a completion on each, then optionally merging the results.

    Args:
        system: System prompt.
        user_prefix: Text prepended before each chunk in the user message.
        text: The main content to process (may be chunked).
        user_suffix: Text appended after each chunk in the user message.
        model: Anthropic model identifier.
        max_tokens: Max output tokens per call.
        merge_system: If provided and there are multiple chunks, a final
            merge pass is run with this system prompt.
        stream: Whether to stream tokens to stderr.

    Returns:
        Final processed text (merged if multi-chunk).
    """
    def _compose(body: str) -> str:
        # Assemble prefix + body (+ optional suffix) into one user message.
        msg = f"{user_prefix}\n\n{body}"
        return f"{msg}\n\n{user_suffix}" if user_suffix else msg

    chunks = _split_into_chunks(text)
    total = len(chunks)
    if total == 1:
        return complete(
            system, _compose(chunks[0]), model=model, max_tokens=max_tokens, stream=stream
        )
    # Multi-chunk processing
    print(
        f"[info] Text is large ({len(text):,} chars). Processing in {total} chunks.",
        file=sys.stderr,
    )
    pieces: list[str] = []
    for idx, chunk in enumerate(chunks, 1):
        print(f"\n[chunk {idx}/{total}]", file=sys.stderr)
        labeled = _compose(f"[Part {idx} of {total}]\n\n{chunk}")
        pieces.append(
            complete(system, labeled, model=model, max_tokens=max_tokens, stream=stream)
        )
    combined = "\n\n".join(pieces)
    # Optional merge/synthesis pass
    if merge_system and total > 1:
        print(f"\n[merging {total} chunks into final output]", file=sys.stderr)
        combined = complete(
            merge_system,
            f"Merge and unify the following {total} sections into a single cohesive output:\n\n{combined}",
            model=model,
            max_tokens=max_tokens,
            stream=stream,
        )
    return combined
|