"""Gradio demo for a very small prompt-compression model.

The app compresses a user prompt sentence-by-sentence with a local seq2seq
checkpoint, reports token savings and prompt-level embedding similarity, and
can send both the original and the compressed prompt to a downstream chat model
to compare the two answers.
"""

from __future__ import annotations

import math
import os
import re
import time
from dataclasses import dataclass
from typing import Any, List, Optional, Sequence, Tuple

import gradio as gr
import numpy as np
import torch
from huggingface_hub import InferenceClient
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

COMPRESSION_MODEL_ID = "gravitee-io/very-small-prompt-compression"
DOWNSTREAM_MODEL = "openai/gpt-oss-20b"
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
MAX_NEW_TOKENS = 96

# Characters removed from the end of compressed text. Sentence punctuation is
# re-attached separately from each original segment.
_TRAILING_PUNCTUATION = "?.!,;:"

# CommonMark hard line break (two trailing spaces + newline), so each metric
# renders on its own line inside a gr.Markdown component instead of being
# folded into a single paragraph.
_MD_LINE_BREAK = "  \n"

# A sentence is a run of text ending in terminal punctuation that is followed by
# whitespace or the end of input, or simply ending at the end of input.
# Requiring the whitespace keeps "3.5", "v1.2" or "config.yaml" in one sentence.
_SENTENCE_RE = re.compile(r"\S.*?(?:[.!?]+(?=\s|$)|$)", re.DOTALL)

_INTRO_MARKDOWN = f"""
# Very Small Prompt Compression

Enter a user prompt to see how the [{COMPRESSION_MODEL_ID}](https://huggingface.co/{COMPRESSION_MODEL_ID}) checkpoint trims it down, compares token savings, and checks semantic drift before forwarding to `{DOWNSTREAM_MODEL}`.

Trained using the [gravitee-io/dolly-15k-prompt-compression](https://huggingface.co/datasets/gravitee-io/dolly-15k-prompt-compression) dataset.

**Note:** Provide a Hugging Face Inference API token below (or set `HF_TOKEN`) so the demo can call the downstream model.
"""

# The compression model is loaded once at import time and placed on the GPU
# when one is available.
compression_tokenizer = AutoTokenizer.from_pretrained(COMPRESSION_MODEL_ID, use_fast=True)
_MODEL_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
compression_model = AutoModelForSeq2SeqLM.from_pretrained(COMPRESSION_MODEL_ID).to(_MODEL_DEVICE)
compression_model.eval()


@dataclass
class Segment:
    """One sentence of a prompt, separated from its trailing punctuation."""

    text: str  # sentence content without its trailing ".", "!" or "?" run
    punctuation: str  # the trailing punctuation run (may be empty)


def _split_prompt(prompt: str) -> List[Segment]:
    """Split *prompt* into sentence segments, keeping each sentence's trailing punctuation.

    Terminal punctuation only ends a sentence when it is followed by whitespace
    or the end of the prompt, so decimals, version numbers and dotted names are
    not split apart. If no segment has any content (for example, the prompt is
    only punctuation), the whole stripped prompt is returned as one segment.
    An empty or whitespace-only prompt yields an empty list.
    """
    segments: List[Segment] = []
    for part in _SENTENCE_RE.findall(prompt):
        stripped = part.strip()
        if not stripped:
            continue
        body = stripped.rstrip(".?!")
        punctuation = stripped[len(body):]
        content = body.strip()
        if content:
            segments.append(Segment(text=content, punctuation=punctuation))
    if not segments and prompt.strip():
        segments.append(Segment(text=prompt.strip(), punctuation=""))
    return segments


def _combine_segments(segments: Sequence[Segment]) -> str:
    """Join segments with single spaces, re-attaching each segment's punctuation."""
    pieces = [f"{segment.text.strip()}{segment.punctuation}" for segment in segments]
    return " ".join(piece for piece in pieces if piece).strip()


def _count_tokens(text: str) -> int:
    """Count tokens in *text* with the compression model's tokenizer.

    Special tokens are excluded. The count may differ from what the downstream
    model's own tokenizer would report.
    """
    if not text:
        return 0
    return len(compression_tokenizer.encode(text, add_special_tokens=False))


def _call_compression_model(text: str, *, max_new_tokens: int = MAX_NEW_TOKENS) -> str:
    """Compress one segment with the local seq2seq model.

    Input longer than 512 tokens is truncated before generation. The returned
    text has surrounding whitespace and trailing punctuation removed, because the
    caller re-attaches the segment's original punctuation.

    Falls back to the unmodified *text* when generation raises or produces an
    empty result, so a model failure never injects placeholder text into the
    compressed prompt.
    """
    if not text.strip():
        return text
    try:
        encoded = compression_tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        encoded = {key: value.to(_MODEL_DEVICE) for key, value in encoded.items()}
        with torch.no_grad():
            output_ids = compression_model.generate(
                **encoded,
                max_new_tokens=max_new_tokens,
                num_beams=4,
                no_repeat_ngram_size=3,
            )
        compressed = compression_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    except Exception:
        # Best effort: keep the segment uncompressed rather than failing the request.
        return text
    cleaned = compressed.strip().rstrip(_TRAILING_PUNCTUATION).strip()
    return cleaned or text


def _embed(client: InferenceClient, text: str) -> Optional[np.ndarray]:
    """Return a 1-D embedding vector for *text* via the feature-extraction endpoint.

    Any extra leading axes in the payload (for example a batch axis or
    per-token rows) are averaged away, so the result can be compared with
    ``_cosine_similarity``.

    Returns ``None`` when:
      - the text is blank;
      - the request raises;
      - the payload cannot be converted to a float array;
      - the vector is empty or has a zero or non-finite norm.
    """
    if not text.strip():
        return None
    try:
        features = client.feature_extraction(text)
        array = np.asarray(features, dtype=np.float32)
    except Exception:
        return None
    if array.ndim == 0 or array.size == 0:
        return None
    array = array.squeeze()
    if array.ndim > 1:
        # Mean-pool remaining rows (e.g. (tokens, dim) or (1, tokens, dim))
        # down to a single (dim,) vector.
        array = array.reshape(-1, array.shape[-1]).mean(axis=0)
    norm = float(np.linalg.norm(array))
    if not math.isfinite(norm) or norm == 0.0:
        return None
    return array


def _cosine_similarity(vec_a: np.ndarray | None, vec_b: np.ndarray | None) -> Optional[float]:
    """Return the cosine similarity of two vectors.

    Returns ``None`` if either vector is missing, their shapes differ, or a
    norm is zero.
    """
    if vec_a is None or vec_b is None or vec_a.shape != vec_b.shape:
        return None
    denom = float(np.linalg.norm(vec_a) * np.linalg.norm(vec_b))
    if denom == 0.0:
        return None
    return float(np.dot(vec_a, vec_b) / denom)


def _extract_text(payload: Any) -> str:
    """Best-effort extraction of text from a chat-message-like payload.

    Handles the following shapes; anything else yields ``""``:
      - plain strings;
      - dicts with a string ``text`` field;
      - dicts with a ``content`` field that is a string or a list of parts;
      - lists of such items, joined with spaces;
      - objects exposing a ``content`` attribute.
    """
    if payload is None:
        return ""
    if isinstance(payload, str):
        return payload
    if isinstance(payload, dict):
        if isinstance(payload.get("text"), str):
            return payload["text"]
        content = payload.get("content")
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            return " ".join(_extract_text(item) for item in content)
        if content is None:
            return ""
    if isinstance(payload, list):
        return " ".join(_extract_text(item) for item in payload)
    if hasattr(payload, "content"):
        return _extract_text(getattr(payload, "content"))
    return ""


def _chat_completion(client: InferenceClient, prompt: str) -> Tuple[str, Optional[str]]:
    """Send *prompt* to the downstream chat model and return ``(text, error)``.

    Makes up to two attempts. On success returns the stripped response text
    and ``None``. If every attempt fails or yields empty content, returns
    ``""`` and a human-readable description of the last problem seen.
    """
    last_error: Optional[str] = None
    for _attempt in range(2):
        try:
            completion = client.chat_completion(
                messages=[
                    {"role": "system", "content": "You are a helpful assistant. Answer concisely."},
                    {"role": "user", "content": prompt},
                ],
                max_tokens=1024,
                temperature=0.0,
                top_p=1.00,
            )
        except Exception as exc:
            last_error = f"{type(exc).__name__}: {exc}"
            continue
        try:
            choice = completion.choices[0] if completion.choices else None
            if choice is None:
                last_error = "No choices returned by downstream model."
                continue
            finish_reason = getattr(choice, "finish_reason", None)
            # Look for text on the message first, then on a streaming-style
            # `delta`, then directly on the choice.
            content = _extract_text(getattr(choice, "message", None))
            if not content:
                content = _extract_text(getattr(choice, "delta", None))
            if not content:
                content = _extract_text(getattr(choice, "content", None))
            content = content.strip()
            if content:
                return content, None
            last_error = f"Model returned an empty response (finish_reason={finish_reason})."
        except Exception as exc:
            last_error = f"{type(exc).__name__}: {exc}"
    return "", last_error or "No response generated."


def _get_client(model_id: str, token: Optional[str]) -> InferenceClient:
    """Create an ``InferenceClient`` bound to *model_id*, authenticated with *token*."""
    return InferenceClient(model=model_id, token=token)


def _resolve_token(hf_token: Optional[str]) -> Optional[str]:
    """Pick the Hugging Face token to use.

    Prefers the non-blank token typed into the UI. Otherwise falls back to the
    ``HF_TOKEN`` or ``HUGGINGFACEHUB_API_TOKEN`` environment variable, and
    returns ``None`` if neither is set.
    """
    token = (hf_token or "").strip()
    if token:
        return token
    return os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")


def _prompt_metrics_lines(
    client: InferenceClient, original_prompt: str, compressed_prompt: str
) -> List[str]:
    """Build Markdown lines describing token savings and prompt embedding similarity.

    The similarity line is omitted when either prompt cannot be embedded.
    """
    original_tokens = _count_tokens(original_prompt)
    compressed_tokens = _count_tokens(compressed_prompt)
    token_delta = original_tokens - compressed_tokens
    savings_pct = (token_delta / original_tokens * 100) if original_tokens else 0.0
    lines = [
        f"**Original tokens:** {original_tokens}",
        f"**Compressed tokens:** {compressed_tokens}",
        f"**Token savings:** {token_delta} ({savings_pct:.1f}%)",
    ]
    similarity = _cosine_similarity(
        _embed(client, original_prompt),
        _embed(client, compressed_prompt),
    )
    if similarity is not None:
        lines.append(f"**Prompt cosine similarity:** {similarity:.3f}")
    return lines


def compress_prompt_action(prompt: str, hf_token: Optional[str]) -> Tuple[str, str, str, str, str, str]:
    """Gradio handler for **Compress Prompt**.

    Splits the prompt into sentences, compresses each one locally, and reports
    token, similarity and per-segment latency metrics.

    Returns a 6-tuple for the UI outputs, in this order:
      1. original prompt;
      2. compressed prompt;
      3. prompt metrics Markdown;
      4. original-response placeholder;
      5. compressed-response placeholder;
      6. response-metrics text.
    """
    token = _resolve_token(hf_token)
    prompt = (prompt or "").strip()
    if not prompt:
        message = "Please enter a prompt to compress."
        placeholder = "_Run **Compare Responses** after compression to see downstream outputs._"
        return ("", "", message, placeholder, placeholder, "")

    embedding_client = _get_client(EMBEDDING_MODEL, token)

    compressed_segments: List[Segment] = []
    segment_timings: List[float] = []
    for segment in _split_prompt(prompt):
        started = time.perf_counter()
        compressed_text = _call_compression_model(segment.text)
        segment_timings.append(time.perf_counter() - started)
        compressed_segments.append(Segment(text=compressed_text, punctuation=segment.punctuation))
    compressed_prompt = _combine_segments(compressed_segments).rstrip(_TRAILING_PUNCTUATION)

    prompt_metrics_lines = _prompt_metrics_lines(embedding_client, prompt, compressed_prompt)
    if segment_timings:
        min_ms = min(segment_timings) * 1000.0
        max_ms = max(segment_timings) * 1000.0
        mean_ms = (sum(segment_timings) / len(segment_timings)) * 1000.0
        prompt_metrics_lines.append(
            f"**Segments:** {len(segment_timings)} • **Latency (ms):** min {min_ms:.1f} / mean {mean_ms:.1f} / max {max_ms:.1f}"
        )
    prompt_metrics = _MD_LINE_BREAK.join(prompt_metrics_lines)

    placeholder_response = "_Run **Compare Responses** to query the downstream model._"
    response_metrics = "Press **Compare Responses** to evaluate downstream behavior."
    return (
        prompt,
        compressed_prompt,
        prompt_metrics,
        placeholder_response,
        placeholder_response,
        response_metrics,
    )


def compare_responses_action(
    original_prompt: str, compressed_prompt: str, hf_token: Optional[str]
) -> Tuple[str, str, str, str, str, str]:
    """Gradio handler for **Compare Responses**.

    Sends both prompts to ``DOWNSTREAM_MODEL``. Reports the prompt metrics,
    both responses (with any downstream error appended), and the cosine
    similarity between the two responses.

    Returns the same 6-tuple layout as ``compress_prompt_action``.
    """
    token = _resolve_token(hf_token)
    original_prompt = (original_prompt or "").strip()
    compressed_prompt = (compressed_prompt or "").strip()
    if not original_prompt or not compressed_prompt:
        message = "Please compress a prompt before comparing responses."
        placeholder = "_No response generated._"
        return (
            original_prompt,
            compressed_prompt,
            message,
            placeholder,
            placeholder,
            "Responses unavailable.",
        )

    embedding_client = _get_client(EMBEDDING_MODEL, token)
    llm_client = _get_client(DOWNSTREAM_MODEL, token)

    prompt_metrics = _MD_LINE_BREAK.join(
        _prompt_metrics_lines(embedding_client, original_prompt, compressed_prompt)
    )

    original_response, original_response_error = _chat_completion(llm_client, original_prompt)
    compressed_response, compressed_response_error = _chat_completion(llm_client, compressed_prompt)
    response_similarity = _cosine_similarity(
        _embed(embedding_client, original_response),
        _embed(embedding_client, compressed_response),
    )

    response_metrics_lines: List[str] = []
    if response_similarity is not None:
        response_metrics_lines.append(f"**Response cosine similarity:** {response_similarity:.3f}")
    elif original_response and compressed_response:
        # Both answers exist but could not be embedded. Report that, rather
        # than claiming the responses themselves are unavailable.
        response_metrics_lines.append("Response cosine similarity unavailable.")

    original_response_display = original_response or "_No response generated for the original prompt._"
    compressed_response_display = compressed_response or "_No response generated for the compressed prompt._"
    if original_response_error:
        original_response_display += f"\n\n> {original_response_error}"
        response_metrics_lines.append("⚠️ Downstream model issue on original prompt.")
    if compressed_response_error:
        compressed_response_display += f"\n\n> {compressed_response_error}"
        response_metrics_lines.append("⚠️ Downstream model issue on compressed prompt.")
    if not response_metrics_lines:
        response_metrics_lines.append("Responses unavailable.")
    response_metrics = _MD_LINE_BREAK.join(response_metrics_lines)

    return (
        original_prompt,
        compressed_prompt,
        prompt_metrics,
        original_response_display,
        compressed_response_display,
        response_metrics,
    )


with gr.Blocks(fill_height=True, css=".gradio-container {max-width: 900px;}") as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    token_input = gr.Textbox(
        label="Hugging Face token (optional)",
        type="password",
        placeholder="Paste an access token to use your own Inference quota",
    )
    prompt_input = gr.Textbox(
        label="User prompt",
        placeholder="Describe how to configure a rate limit policy in Gravitee API Management...",
        lines=4,
    )

    with gr.Row():
        compress_btn = gr.Button("Compress Prompt", variant="primary")
        compare_btn = gr.Button("Compare Responses", variant="secondary")

    original_prompt_output = gr.Textbox(label="Original prompt", lines=4, interactive=False)
    compressed_output = gr.Textbox(label="Compressed prompt", lines=4, interactive=False)
    prompt_metrics_output = gr.Markdown()

    with gr.Row():
        original_response_output = gr.Markdown(label="Response to original prompt")
        compressed_response_output = gr.Markdown(label="Response to compressed prompt")

    response_metrics_output = gr.Markdown()

    # Both handlers return the same 6-tuple, so they share one output list.
    _RESULT_OUTPUTS = [
        original_prompt_output,
        compressed_output,
        prompt_metrics_output,
        original_response_output,
        compressed_response_output,
        response_metrics_output,
    ]

    compress_btn.click(
        fn=compress_prompt_action,
        inputs=[prompt_input, token_input],
        outputs=_RESULT_OUTPUTS,
    )
    compare_btn.click(
        fn=compare_responses_action,
        inputs=[original_prompt_output, compressed_output, token_input],
        outputs=_RESULT_OUTPUTS,
    )


if __name__ == "__main__":
    demo.launch()