File size: 1,834 Bytes
ebc3bf5
 
 
 
 
c4cbe0b
 
 
 
 
 
ebc3bf5
 
 
 
 
 
 
 
 
c4cbe0b
 
ebc3bf5
 
c4cbe0b
 
 
ebc3bf5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
import config
from core.tokenizer_utils import count_tokens
from models.model_loader import get_llm

try:
    import spaces
    _gpu = spaces.GPU
except ImportError:
    _gpu = lambda fn: fn  # no-op when running locally without the spaces package


_PROMPT_TEMPLATE = """You are a lossless compression assistant. Compress the following text to at most {target} tokens.
Preserve all key facts, decisions, and intent. Do not add commentary. Output only the compressed text.

TEXT:
{text}

COMPRESSED:"""


@_gpu
def _generate(prompt: str) -> str:
    """Run the LLM on ``prompt`` and return only the newly generated text.

    The model is moved to CUDA when available (CPU otherwise) on every call,
    decoding is deterministic (``do_sample=False``), and the output is capped
    at ``config.MAX_NEW_TOKENS`` new tokens. The prompt tokens are sliced off
    before decoding, and the result is stripped of surrounding whitespace.
    """
    model, tokenizer = get_llm()
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(target_device)

    encoded = tokenizer(prompt, return_tensors="pt").to(target_device)
    prompt_length = encoded["input_ids"].shape[1]
    generation_kwargs = dict(
        max_new_tokens=config.MAX_NEW_TOKENS,
        do_sample=False,
        # presumably the tokenizer has no dedicated pad token — TODO confirm
        pad_token_id=tokenizer.eos_token_id,
    )

    with torch.no_grad():
        sequences = model.generate(**encoded, **generation_kwargs)

    # Only the first sequence is used; drop the echoed prompt tokens.
    continuation = sequences[0, prompt_length:]
    decoded = tokenizer.decode(continuation, skip_special_tokens=True)
    return decoded.strip()


def compress(text: str, target_tokens: int) -> tuple[str, int, int]:
    """Compress ``text`` to at most ``target_tokens`` tokens using the LLM.

    Args:
        text: The text to compress.
        target_tokens: Token budget for the result; must be non-negative.

    Returns:
        ``(compressed_text, input_token_count, output_token_count)``, where both
        counts come from ``count_tokens``. If ``text`` already fits within the
        budget it is returned unchanged and the model is not called.

    Raises:
        ValueError: If ``target_tokens`` is negative.
    """
    if target_tokens < 0:
        # A negative budget is meaningless, and ``ids[:target_tokens]`` below would
        # silently drop tokens from the *end* instead of enforcing a hard limit.
        raise ValueError(f"target_tokens must be non-negative, got {target_tokens}")

    input_tokens = count_tokens(text)

    if input_tokens <= target_tokens:
        return text, input_tokens, input_tokens

    if target_tokens == 0:
        # The hard trim below would reduce any model output to "", so skip the
        # (expensive) generation call and return that result directly.
        compressed = ""
        return compressed, input_tokens, count_tokens(compressed)

    prompt = _PROMPT_TEMPLATE.format(target=target_tokens, text=text)
    compressed = _generate(prompt)

    # Trim to the hard token limit (using the LLM's own tokenizer) if the model overshoots.
    # NOTE(review): the budget check above and the returned output count use
    # count_tokens, which may not be the LLM tokenizer — if they differ, the
    # reported output count can still exceed target_tokens. Confirm.
    _, tokenizer = get_llm()
    ids = tokenizer.encode(compressed, add_special_tokens=False)
    if len(ids) > target_tokens:
        compressed = tokenizer.decode(ids[:target_tokens], skip_special_tokens=True)

    output_tokens = count_tokens(compressed)
    return compressed, input_tokens, output_tokens