from __future__ import annotations

import math
import os
import re
import time
from dataclasses import dataclass
from typing import Any, List, Optional, Sequence, Tuple

import gradio as gr
import numpy as np
import torch
from huggingface_hub import InferenceClient
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

COMPRESSION_MODEL_ID = "gravitee-io/very-small-prompt-compression"
DOWNSTREAM_MODEL = "openai/gpt-oss-20b"
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
MAX_NEW_TOKENS = 96

compression_tokenizer = AutoTokenizer.from_pretrained(COMPRESSION_MODEL_ID, use_fast=True)
_MODEL_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
compression_model = AutoModelForSeq2SeqLM.from_pretrained(COMPRESSION_MODEL_ID).to(_MODEL_DEVICE)
compression_model.eval()


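# Prompts are compressed sentence-by-sentence: each sentence is sent through the
# compressor on its own, and its trailing punctuation is kept so the compressed
# pieces can be stitched back together afterwards.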
@dataclass
class Segment:
    text: str
    punctuation: str


def _split_prompt(prompt: str) -> List[Segment]:
    """Split a prompt into sentence segments while retaining trailing punctuation."""
    parts = re.findall(r"[^.!?]+[.!?]*", prompt)
    segments: List[Segment] = []
    for part in parts:
        stripped = part.strip()
        if not stripped:
            continue
        punct_len = len(stripped) - len(stripped.rstrip(".?!"))
        punctuation = stripped[-punct_len:] if punct_len else ""
        content = stripped[:-punct_len].strip() if punct_len else stripped
        if content:
            segments.append(Segment(text=content, punctuation=punctuation))
    if not segments and prompt.strip():
        segments.append(Segment(text=prompt.strip(), punctuation=""))
    return segments


def _combine_segments(segments: Sequence[Segment]) -> str:
    """Reattach punctuation and join the segments back into a single prompt string."""
    pieces = []
    for segment in segments:
        piece = segment.text.strip()
        if segment.punctuation:
            piece = f"{piece}{segment.punctuation}"
        pieces.append(piece)
    return " ".join(piece for piece in pieces if piece).strip()


def _count_tokens(text: str) -> int:
    """Count tokens using the compression model's tokenizer (special tokens excluded)."""
    if not text:
        return 0
    return len(compression_tokenizer.encode(text, add_special_tokens=False))


def _call_compression_model(text: str, *, max_new_tokens: int = MAX_NEW_TOKENS) -> str:
    """Compress a single segment, falling back to the original text if generation fails."""
    try:
        encoded = compression_tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        encoded = {k: v.to(_MODEL_DEVICE) for k, v in encoded.items()}
        with torch.no_grad():
            output_ids = compression_model.generate(
                **encoded,
                max_new_tokens=max_new_tokens,
                num_beams=4,
                no_repeat_ngram_size=3,
            )
        compressed = compression_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    except Exception:
        # If generation fails for any reason, keep the segment unchanged.
        return text
    cleaned = compressed.strip().rstrip("?.!,;:")
    return cleaned or text


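# Semantic drift is estimated with cosine similarity between sentence embeddings
# of the original and compressed prompts (and later between the two downstream
# responses), fetched from the Inference API.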
def _embed(client: InferenceClient, text: str) -> Optional[np.ndarray]:
    """Embed `text` via the Inference API; return a 1-D vector, or None if unusable."""
    if not text.strip():
        return None
    try:
        features = client.feature_extraction(text)
    except Exception:
        return None
    if isinstance(features, list):
        array = np.array(features[0] if features and isinstance(features[0], list) else features, dtype=np.float32)
    else:
        array = np.array(features, dtype=np.float32)
    if array.ndim == 0:
        return None
    if array.ndim > 1:
        array = array.squeeze()
    norm = np.linalg.norm(array)
    if not math.isfinite(norm) or norm == 0.0:
        return None
    return array


def _cosine_similarity(vec_a: np.ndarray | None, vec_b: np.ndarray | None) -> Optional[float]:
    if vec_a is None or vec_b is None:
        return None
    denom = float(np.linalg.norm(vec_a) * np.linalg.norm(vec_b))
    if denom == 0.0:
        return None
    return float(np.dot(vec_a, vec_b) / denom)


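# Downstream-model helpers: the chat API can return content in several shapes
# (plain strings, dicts, lists of parts, objects with a .content attribute), so
# responses are normalized to plain text before being compared.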
def _extract_text(payload: Any) -> str:
    """Extract plain text from the assorted payload shapes the chat completion API can return."""
    if payload is None:
        return ""
    if isinstance(payload, str):
        return payload
    if isinstance(payload, dict):
        if "text" in payload and isinstance(payload["text"], str):
            return payload["text"]
        content = payload.get("content")
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            return " ".join(_extract_text(item) for item in content)
        if content is None:
            return ""
    if isinstance(payload, list):
        return " ".join(_extract_text(item) for item in payload)
    if hasattr(payload, "content"):
        return _extract_text(getattr(payload, "content"))
    return ""


def _chat_completion(client: InferenceClient, prompt: str) -> Tuple[str, Optional[str]]:
    """Query the downstream chat model, retrying once; return (content, error_message)."""
    last_error: Optional[str] = None
    for _ in range(2):  # one retry on transient failures
        try:
            completion = client.chat_completion(
                messages=[
                    {"role": "system", "content": "You are a helpful assistant. Answer concisely."},
                    {"role": "user", "content": prompt},
                ],
                max_tokens=1024,
                temperature=0.0,
                top_p=1.0,
            )
        except Exception as exc:
            last_error = f"{type(exc).__name__}: {exc}"
            continue

        try:
            choice = completion.choices[0] if completion.choices else None
            if choice is None:
                last_error = "No choices returned by downstream model."
                continue
            finish_reason = getattr(choice, "finish_reason", None)
            message = getattr(choice, "message", None)
            content = _extract_text(message)
            if not content:
                delta = getattr(choice, "delta", None)
                content = _extract_text(delta)
            if not content:
                raw_choice = getattr(choice, "content", None)
                content = _extract_text(raw_choice)
            content = content.strip()
            if content:
                return content, None
            last_error = f"Model returned an empty response (finish_reason={finish_reason})."
        except Exception as exc:
            last_error = f"{type(exc).__name__}: {exc}"
    return "", last_error or "No response generated."


def _get_client(model_id: str, token: Optional[str]) -> InferenceClient:
    return InferenceClient(model=model_id, token=token)


def _resolve_token(hf_token: Optional[str]) -> Optional[str]:
    """Prefer an explicitly supplied token, otherwise fall back to the standard environment variables."""
    token = (hf_token or "").strip()
    if token:
        return token
    return os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")


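# Gradio callbacks. Each returns a value for every output component so a click
# refreshes the full set of prompt, response, and metrics panels.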
def compress_prompt_action(prompt: str, hf_token: Optional[str]) -> Tuple[str, str, str, str, str, str]:
    """Compress the prompt sentence-by-sentence and report token and similarity metrics."""
    token = _resolve_token(hf_token)
    prompt = prompt.strip()
    if not prompt:
        message = "Please enter a prompt to compress."
        placeholder = "_Run **Compare Responses** after compression to see downstream outputs._"
        return ("", "", message, placeholder, placeholder, "")

    embedding_client = _get_client(EMBEDDING_MODEL, token)

    segments = _split_prompt(prompt)
    compressed_segments: List[Segment] = []
    segment_timings: List[float] = []

    for segment in segments:
        start = time.perf_counter()
        compressed_text = _call_compression_model(segment.text)
        segment_timings.append(time.perf_counter() - start)
        compressed_segments.append(Segment(text=compressed_text, punctuation=segment.punctuation))

    compressed_prompt = _combine_segments(compressed_segments).rstrip("?.!,;:")

    original_tokens = _count_tokens(prompt)
    compressed_tokens = _count_tokens(compressed_prompt)
    token_delta = original_tokens - compressed_tokens
    savings_pct = (token_delta / original_tokens * 100) if original_tokens else 0.0

    prompt_embedding_original = _embed(embedding_client, prompt)
    prompt_embedding_compressed = _embed(embedding_client, compressed_prompt)
    prompt_similarity = _cosine_similarity(prompt_embedding_original, prompt_embedding_compressed)

    prompt_metrics_lines = [
        f"**Original tokens:** {original_tokens}",
        f"**Compressed tokens:** {compressed_tokens}",
        f"**Token savings:** {token_delta} ({savings_pct:.1f}%)",
    ]
    if prompt_similarity is not None:
        prompt_metrics_lines.append(f"**Prompt cosine similarity:** {prompt_similarity:.3f}")
    if segment_timings:
        min_ms = min(segment_timings) * 1000.0
        max_ms = max(segment_timings) * 1000.0
        mean_ms = (sum(segment_timings) / len(segment_timings)) * 1000.0
        prompt_metrics_lines.append(
            f"**Segments:** {len(segment_timings)} • **Latency (ms):** min {min_ms:.1f} / mean {mean_ms:.1f} / max {max_ms:.1f}"
        )
    prompt_metrics = "<br>".join(prompt_metrics_lines)

    placeholder_response = "_Run **Compare Responses** to query the downstream model._"
    response_metrics = "Press **Compare Responses** to evaluate downstream behavior."

    return (
        prompt,
        compressed_prompt,
        prompt_metrics,
        placeholder_response,
        placeholder_response,
        response_metrics,
    )


def compare_responses_action(
    original_prompt: str, compressed_prompt: str, hf_token: Optional[str]
) -> Tuple[str, str, str, str, str, str]:
    """Send the original and compressed prompts to the downstream model and compare the responses."""
    token = _resolve_token(hf_token)
    original_prompt = (original_prompt or "").strip()
    compressed_prompt = (compressed_prompt or "").strip()

    if not original_prompt or not compressed_prompt:
        message = "Please compress a prompt before comparing responses."
        placeholder = "_No response generated._"
        return (
            original_prompt,
            compressed_prompt,
            message,
            placeholder,
            placeholder,
            "Responses unavailable.",
        )

    embedding_client = _get_client(EMBEDDING_MODEL, token)
    llm_client = _get_client(DOWNSTREAM_MODEL, token)

    original_tokens = _count_tokens(original_prompt)
    compressed_tokens = _count_tokens(compressed_prompt)
    token_delta = original_tokens - compressed_tokens
    savings_pct = (token_delta / original_tokens * 100) if original_tokens else 0.0

    prompt_embedding_original = _embed(embedding_client, original_prompt)
    prompt_embedding_compressed = _embed(embedding_client, compressed_prompt)
    prompt_similarity = _cosine_similarity(prompt_embedding_original, prompt_embedding_compressed)

    prompt_metrics_lines = [
        f"**Original tokens:** {original_tokens}",
        f"**Compressed tokens:** {compressed_tokens}",
        f"**Token savings:** {token_delta} ({savings_pct:.1f}%)",
    ]
    if prompt_similarity is not None:
        prompt_metrics_lines.append(f"**Prompt cosine similarity:** {prompt_similarity:.3f}")
    prompt_metrics = "<br>".join(prompt_metrics_lines)

    original_response, original_response_error = _chat_completion(llm_client, original_prompt)
    compressed_response, compressed_response_error = _chat_completion(llm_client, compressed_prompt)

    response_embedding_original = _embed(embedding_client, original_response)
    response_embedding_compressed = _embed(embedding_client, compressed_response)
    response_similarity = _cosine_similarity(response_embedding_original, response_embedding_compressed)

    response_metrics_lines = []
    if response_similarity is not None:
        response_metrics_lines.append(f"**Response cosine similarity:** {response_similarity:.3f}")

    original_response_display = original_response or "_No response generated for the original prompt._"
    compressed_response_display = compressed_response or "_No response generated for the compressed prompt._"
    if original_response_error:
        original_response_display += f"\n\n> {original_response_error}"
        response_metrics_lines.append("⚠️ Downstream model issue on original prompt.")
    if compressed_response_error:
        compressed_response_display += f"\n\n> {compressed_response_error}"
        response_metrics_lines.append("⚠️ Downstream model issue on compressed prompt.")

    if not response_metrics_lines:
        response_metrics_lines.append("Responses unavailable.")

    response_metrics = "<br>".join(response_metrics_lines)

    return (
        original_prompt,
        compressed_prompt,
        prompt_metrics,
        original_response_display,
        compressed_response_display,
        response_metrics,
    )


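# Gradio UI: compress the prompt first, then optionally compare the downstream
# responses for the original and compressed prompts.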
with gr.Blocks(fill_height=True, css=".gradio-container {max-width: 900px;}") as demo:
    gr.Markdown(
        """
        # Very Small Prompt Compression

        Enter a user prompt to see how the [gravitee-io/very-small-prompt-compression](https://huggingface.co/gravitee-io/very-small-prompt-compression) checkpoint trims it down;
        the demo then compares token savings and checks semantic drift before forwarding both prompts to `openai/gpt-oss-20b`.

        Trained using the [gravitee-io/dolly-15k-prompt-compression](https://huggingface.co/datasets/gravitee-io/dolly-15k-prompt-compression) dataset.

        **Note:** Provide a Hugging Face Inference API token below (or set `HF_TOKEN`) so the demo can call the downstream model.
        """
    )

    token_input = gr.Textbox(
        label="Hugging Face token (optional)",
        type="password",
        placeholder="Paste an access token to use your own Inference quota",
    )

    prompt_input = gr.Textbox(
        label="User prompt",
        placeholder="Describe how to configure a rate limit policy in Gravitee API Management...",
        lines=4,
    )

    with gr.Row():
        compress_btn = gr.Button("Compress Prompt", variant="primary")
        compare_btn = gr.Button("Compare Responses", variant="secondary")

    original_prompt_output = gr.Textbox(label="Original prompt", lines=4, interactive=False)
    compressed_output = gr.Textbox(label="Compressed prompt", lines=4, interactive=False)
    prompt_metrics_output = gr.Markdown()

    with gr.Row():
        original_response_output = gr.Markdown(label="Response to original prompt")
        compressed_response_output = gr.Markdown(label="Response to compressed prompt")

    response_metrics_output = gr.Markdown()

    compress_btn.click(
        fn=compress_prompt_action,
        inputs=[prompt_input, token_input],
        outputs=[
            original_prompt_output,
            compressed_output,
            prompt_metrics_output,
            original_response_output,
            compressed_response_output,
            response_metrics_output,
        ],
    )

    compare_btn.click(
        fn=compare_responses_action,
        inputs=[original_prompt_output, compressed_output, token_input],
        outputs=[
            original_prompt_output,
            compressed_output,
            prompt_metrics_output,
            original_response_output,
            compressed_response_output,
            response_metrics_output,
        ],
    )


if __name__ == "__main__":
    demo.launch()