# Last commit: 3272ef5 (verified) by dotslashderek
# "chore: update app to better handle optional hugging face token"
from __future__ import annotations
import math
import os
import re
import time
from dataclasses import dataclass
from typing import Any, List, Optional, Sequence, Tuple
import gradio as gr
import numpy as np
import torch
from huggingface_hub import InferenceClient
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Model identifiers used throughout the demo.
COMPRESSION_MODEL_ID = "gravitee-io/very-small-prompt-compression"  # run locally below
DOWNSTREAM_MODEL = "openai/gpt-oss-20b"  # chat model reached through InferenceClient
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"  # used for cosine-similarity metrics
# Generation budget for each compressed segment.
MAX_NEW_TOKENS = 96

# Prefer the GPU when torch can see one; otherwise run the compressor on the CPU.
_MODEL_DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
compression_tokenizer = AutoTokenizer.from_pretrained(COMPRESSION_MODEL_ID, use_fast=True)
# Loaded once at import time, moved to the chosen device and put in evaluation mode.
compression_model = (
    AutoModelForSeq2SeqLM.from_pretrained(COMPRESSION_MODEL_ID)
    .to(_MODEL_DEVICE)
    .eval()
)
@dataclass
class Segment:
    """One sentence-like chunk of a prompt.

    The sentence body and its trailing punctuation are kept apart so the body can
    be compressed on its own and the punctuation re-attached afterwards.
    """

    # Sentence body, whitespace-trimmed, without trailing punctuation.
    text: str
    # Trailing run of ".", "!" or "?" characters; empty when there is none.
    punctuation: str
def _split_prompt(prompt: str) -> List[Segment]:
    """Split *prompt* into sentence-like segments, keeping each one's trailing punctuation.

    A segment is a run of characters other than ``.``, ``!`` and ``?`` followed by
    any run of those characters. The trailing punctuation is stored separately
    from the whitespace-trimmed text, and segments with no text (e.g. a stray
    ``...``) are dropped. If nothing survives but the prompt is not blank, the
    whole stripped prompt becomes a single segment with empty punctuation.

    Note: every ``.`` counts as a sentence boundary, so decimals ("3.14"),
    abbreviations ("e.g.") and URLs are split as well.
    """
    found: List[Segment] = []
    for match in re.findall(r"[^.!?]+[.!?]*", prompt):
        chunk = match.strip()
        if not chunk:
            continue
        body = chunk.rstrip(".?!")
        # Whatever rstrip removed is exactly the trailing punctuation run.
        trailing = chunk[len(body):]
        body = body.strip()
        if body:
            found.append(Segment(text=body, punctuation=trailing))

    whole = prompt.strip()
    if not found and whole:
        # e.g. a punctuation-only prompt: keep it rather than returning nothing.
        found.append(Segment(text=whole, punctuation=""))
    return found
def _combine_segments(segments: Sequence[Segment]) -> str:
pieces = []
for segment in segments:
piece = segment.text.strip()
if segment.punctuation:
piece = f"{piece}{segment.punctuation}"
pieces.append(piece)
return " ".join(piece for piece in pieces if piece).strip()
def _count_tokens(text: str) -> int:
    """Count *text*'s tokens under the compression model's tokenizer.

    Special tokens are not counted, and empty input counts as 0 without calling
    the tokenizer.
    """
    return len(compression_tokenizer.encode(text, add_special_tokens=False)) if text else 0
def _call_compression_model(text: str, *, max_new_tokens: int = MAX_NEW_TOKENS) -> str:
    """Compress one prompt segment with the local seq2seq compression model.

    Args:
        text: Segment to compress. Input beyond 512 tokenizer tokens is truncated
            before generation.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The generated text with surrounding whitespace and trailing punctuation
        removed; the caller re-attaches the segment's original punctuation, so
        leaving it here would double it (e.g. "??"). Falls back to *text*
        unchanged when generation raises or produces nothing usable.
    """
    try:
        encoded = compression_tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        encoded = {key: tensor.to(_MODEL_DEVICE) for key, tensor in encoded.items()}
        with torch.no_grad():
            output_ids = compression_model.generate(
                **encoded,
                max_new_tokens=max_new_tokens,
                num_beams=4,
                no_repeat_ngram_size=3,
            )
        compressed = compression_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    except Exception:
        # Best effort: a failed segment must not inject placeholder text into the
        # user's prompt, so log the failure and keep the segment uncompressed.
        import logging

        logging.getLogger(__name__).exception("Prompt compression failed; keeping original segment.")
        return text
    cleaned = compressed.strip().rstrip("?.!,;:").rstrip()
    return cleaned or text
def _embed(client: InferenceClient, text: str) -> Optional[np.ndarray]:
if not text.strip():
return None
try:
features = client.feature_extraction(text)
except Exception:
return None
if isinstance(features, list):
array = np.array(features[0] if features and isinstance(features[0], list) else features, dtype=np.float32)
else:
array = np.array(features, dtype=np.float32)
if array.ndim == 0:
return None
if array.ndim > 1:
array = array.squeeze()
norm = np.linalg.norm(array)
if not math.isfinite(norm) or norm == 0.0:
return None
return array
def _cosine_similarity(vec_a: np.ndarray | None, vec_b: np.ndarray | None) -> Optional[float]:
if vec_a is None or vec_b is None:
return None
denom = float(np.linalg.norm(vec_a) * np.linalg.norm(vec_b))
if denom == 0.0:
return None
return float(np.dot(vec_a, vec_b) / denom)
def _extract_text(payload: Any) -> str:
if payload is None:
return ""
if isinstance(payload, str):
return payload
if isinstance(payload, dict):
if "text" in payload and isinstance(payload["text"], str):
return payload["text"]
content = payload.get("content")
if isinstance(content, str):
return content
if isinstance(content, list):
return " ".join(_extract_text(item) for item in content)
if content is None:
return ""
if isinstance(payload, list):
return " ".join(_extract_text(item) for item in payload)
if hasattr(payload, "content"):
return _extract_text(getattr(payload, "content"))
return ""
def _chat_completion(client: InferenceClient, prompt: str) -> Tuple[str, Optional[str]]:
    """Ask the downstream chat model to answer *prompt*.

    Makes up to two attempts, the second immediately after the first. An attempt
    fails when the request raises, when no choices come back, when reading the
    choice raises, or when the extracted reply is blank.

    Returns:
        ``(reply, None)`` on success, or ``("", error_description)`` once both
        attempts have failed.
    """
    failure: Optional[str] = None
    for _ in range(2):
        try:
            completion = client.chat_completion(
                messages=[
                    {"role": "system", "content": "You are a helpful assistant. Answer concisely."},
                    {"role": "user", "content": prompt},
                ],
                max_tokens=1024,
                # Temperature 0 keeps answers to both prompts as comparable as possible.
                temperature=0.0,
                top_p=1.00,
            )
        except Exception as exc:
            failure = f"{type(exc).__name__}: {exc}"
            continue

        try:
            choice = completion.choices[0] if completion.choices else None
            if choice is None:
                failure = "No choices returned by downstream model."
                continue
            finish_reason = getattr(choice, "finish_reason", None)
            # Reply text may sit on .message, .delta or .content depending on the
            # response shape; use the first field that yields any text.
            reply = ""
            for field in ("message", "delta", "content"):
                reply = _extract_text(getattr(choice, field, None))
                if reply:
                    break
            reply = reply.strip()
            if reply:
                return reply, None
            failure = f"Model returned an empty response (finish_reason={finish_reason})."
        except Exception as exc:
            failure = f"{type(exc).__name__}: {exc}"

    return "", failure or "No response generated."
def _get_client(model_id: str, token: Optional[str]) -> InferenceClient:
    """Build an ``InferenceClient`` bound to *model_id*.

    *token* is forwarded unchanged; None means no explicit token is supplied.
    """
    client = InferenceClient(token=token, model=model_id)
    return client
def _resolve_token(hf_token: Optional[str]) -> Optional[str]:
token = (hf_token or "").strip()
if token:
return token
return os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
def compress_prompt_action(prompt: str, hf_token: Optional[str]) -> Tuple[str, str, str, str, str, str]:
    """Compress *prompt* sentence by sentence and report token/similarity metrics.

    Args:
        prompt: Raw user prompt from the UI; None is treated as empty.
        hf_token: Token typed into the UI, resolved through ``_resolve_token``.

    Returns:
        A 6-tuple in UI output order:
        - the original prompt;
        - the compressed prompt;
        - the prompt-metrics markdown;
        - the original-response placeholder;
        - the compressed-response placeholder;
        - the response-metrics text.
    """
    # Tolerate None the same way compare_responses_action does, and bail out
    # before resolving credentials or building clients for an empty prompt.
    prompt = (prompt or "").strip()
    if not prompt:
        message = "Please enter a prompt to compress."
        placeholder = "_Run **Compare Responses** after compression to see downstream outputs._"
        return ("", "", message, placeholder, placeholder, "")

    token = _resolve_token(hf_token)
    embedding_client = _get_client(EMBEDDING_MODEL, token)

    # Compress each sentence independently, timing every model call.
    compressed_segments: List[Segment] = []
    segment_timings: List[float] = []
    for segment in _split_prompt(prompt):
        start = time.perf_counter()
        compressed_text = _call_compression_model(segment.text)
        segment_timings.append(time.perf_counter() - start)
        compressed_segments.append(Segment(text=compressed_text, punctuation=segment.punctuation))
    # Trailing punctuation is stripped from the final compressed prompt.
    compressed_prompt = _combine_segments(compressed_segments).rstrip("?.!,;:")

    original_tokens = _count_tokens(prompt)
    compressed_tokens = _count_tokens(compressed_prompt)
    token_delta = original_tokens - compressed_tokens
    savings_pct = (token_delta / original_tokens * 100) if original_tokens else 0.0

    prompt_embedding_original = _embed(embedding_client, prompt)
    prompt_embedding_compressed = _embed(embedding_client, compressed_prompt)
    prompt_similarity = _cosine_similarity(prompt_embedding_original, prompt_embedding_compressed)

    prompt_metrics_lines = [
        f"**Original tokens:** {original_tokens}",
        f"**Compressed tokens:** {compressed_tokens}",
        f"**Token savings:** {token_delta} ({savings_pct:.1f}%)",
    ]
    if prompt_similarity is not None:
        prompt_metrics_lines.append(f"**Prompt cosine similarity:** {prompt_similarity:.3f}")
    if segment_timings:
        min_ms = min(segment_timings) * 1000.0
        max_ms = max(segment_timings) * 1000.0
        mean_ms = (sum(segment_timings) / len(segment_timings)) * 1000.0
        prompt_metrics_lines.append(
            f"**Segments:** {len(segment_timings)} • **Latency (ms):** min {min_ms:.1f} / mean {mean_ms:.1f} / max {max_ms:.1f}"
        )
    prompt_metrics = "<br>".join(prompt_metrics_lines)

    placeholder_response = "_Run **Compare Responses** to query the downstream model._"
    response_metrics = "Press **Compare Responses** to evaluate downstream behavior."
    return (
        prompt,
        compressed_prompt,
        prompt_metrics,
        placeholder_response,
        placeholder_response,
        response_metrics,
    )
def compare_responses_action(
    original_prompt: str, compressed_prompt: str, hf_token: Optional[str]
) -> Tuple[str, str, str, str, str, str]:
    """Send both prompts to the downstream model and compare the answers.

    Recomputes the prompt token/similarity metrics, queries the chat model once
    per prompt, and reports the cosine similarity of the two replies. Any
    downstream error is appended to the affected pane and flagged in the
    response metrics.

    Returns:
        A 6-tuple in UI output order:
        - the original prompt;
        - the compressed prompt;
        - the prompt-metrics markdown;
        - the original-response markdown;
        - the compressed-response markdown;
        - the response-metrics text.
    """
    token = _resolve_token(hf_token)
    source_text = (original_prompt or "").strip()
    short_text = (compressed_prompt or "").strip()

    if not (source_text and short_text):
        empty_pane = "_No response generated._"
        return (
            source_text,
            short_text,
            "Please compress a prompt before comparing responses.",
            empty_pane,
            empty_pane,
            "Responses unavailable.",
        )

    embedder = _get_client(EMBEDDING_MODEL, token)
    llm = _get_client(DOWNSTREAM_MODEL, token)

    n_source = _count_tokens(source_text)
    n_short = _count_tokens(short_text)
    saved = n_source - n_short
    saved_pct = (saved / n_source * 100) if n_source else 0.0
    prompt_sim = _cosine_similarity(_embed(embedder, source_text), _embed(embedder, short_text))

    prompt_lines = [
        f"**Original tokens:** {n_source}",
        f"**Compressed tokens:** {n_short}",
        f"**Token savings:** {saved} ({saved_pct:.1f}%)",
    ]
    if prompt_sim is not None:
        prompt_lines.append(f"**Prompt cosine similarity:** {prompt_sim:.3f}")

    source_reply, source_err = _chat_completion(llm, source_text)
    short_reply, short_err = _chat_completion(llm, short_text)
    reply_sim = _cosine_similarity(_embed(embedder, source_reply), _embed(embedder, short_reply))

    notes = [] if reply_sim is None else [f"**Response cosine similarity:** {reply_sim:.3f}"]

    source_pane = source_reply or "_No response generated for the original prompt._"
    if source_err:
        source_pane += f"\n\n> {source_err}"
        notes.append("⚠️ Downstream model issue on original prompt.")

    short_pane = short_reply or "_No response generated for the compressed prompt._"
    if short_err:
        short_pane += f"\n\n> {short_err}"
        notes.append("⚠️ Downstream model issue on compressed prompt.")

    return (
        source_text,
        short_text,
        "<br>".join(prompt_lines),
        source_pane,
        short_pane,
        "<br>".join(notes) if notes else "Responses unavailable.",
    )
# Intro copy for the demo page. Model ids come from the module constants so the
# text cannot drift from what the app actually calls; blank lines keep the
# Markdown paragraphs separate.
_INTRO_MARKDOWN = f"""
# Very Small Prompt Compression

Enter a user prompt to see how the [{COMPRESSION_MODEL_ID}](https://huggingface.co/{COMPRESSION_MODEL_ID}) checkpoint trims it down.
The demo reports token savings and semantic drift (embedding cosine similarity), then lets you compare how `{DOWNSTREAM_MODEL}` answers the original and compressed prompts.

The checkpoint was trained on the [gravitee-io/dolly-15k-prompt-compression](https://huggingface.co/datasets/gravitee-io/dolly-15k-prompt-compression) dataset.

**Note:** The downstream and embedding models are called through the Hugging Face Inference API. Paste a token below, or set `HF_TOKEN` (or `HUGGINGFACEHUB_API_TOKEN`) in the environment, so the demo can reach them.
"""

with gr.Blocks(fill_height=True, css=".gradio-container {max-width: 900px;}") as demo:
    gr.Markdown(_INTRO_MARKDOWN)
    token_input = gr.Textbox(
        label="Hugging Face token (optional)",
        type="password",
        placeholder="Paste an access token to use your own Inference quota",
    )
    prompt_input = gr.Textbox(
        label="User prompt",
        placeholder="Describe how to configure a rate limit policy in Gravitee API Management...",
        lines=4,
    )
    with gr.Row():
        compress_btn = gr.Button("Compress Prompt", variant="primary")
        compare_btn = gr.Button("Compare Responses", variant="secondary")
    original_prompt_output = gr.Textbox(label="Original prompt", lines=4, interactive=False)
    compressed_output = gr.Textbox(label="Compressed prompt", lines=4, interactive=False)
    prompt_metrics_output = gr.Markdown()
    with gr.Row():
        # gr.Markdown does not render its `label`, so explicit headings are
        # needed to tell the two response panes apart.
        with gr.Column():
            gr.Markdown("#### Response to original prompt")
            original_response_output = gr.Markdown(label="Response to original prompt")
        with gr.Column():
            gr.Markdown("#### Response to compressed prompt")
            compressed_response_output = gr.Markdown(label="Response to compressed prompt")
    response_metrics_output = gr.Markdown()

    # Both handlers return a 6-tuple in exactly this order.
    action_outputs = [
        original_prompt_output,
        compressed_output,
        prompt_metrics_output,
        original_response_output,
        compressed_response_output,
        response_metrics_output,
    ]
    compress_btn.click(
        fn=compress_prompt_action,
        inputs=[prompt_input, token_input],
        outputs=action_outputs,
    )
    compare_btn.click(
        fn=compare_responses_action,
        inputs=[original_prompt_output, compressed_output, token_input],
        outputs=action_outputs,
    )

if __name__ == "__main__":
    demo.launch()