dotslashderek committed on
Commit 4c908af · verified · 1 Parent(s): 30a29a8

Update app.py

Files changed (1)
  1. app.py +379 -61
app.py CHANGED
@@ -1,69 +1,387 @@
  import gradio as gr
  from huggingface_hub import InferenceClient
- def respond(
-     message,
-     history: list[dict[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
-     hf_token: gr.OAuthToken,
- ):
-     """
-     For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-     """
-     client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
-
-     messages = [{"role": "system", "content": system_message}]
-
-     messages.extend(history)
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         choices = message.choices
-         token = ""
-         if len(choices) and choices[0].delta.content:
-             token = choices[0].delta.content
-
-         response += token
-         yield response
-
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- chatbot = gr.ChatInterface(
-     respond,
-     type="messages",
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )
-
- with gr.Blocks() as demo:
-     with gr.Sidebar():
  gr.LoginButton()
-     chatbot.render()
  if __name__ == "__main__":
+ from __future__ import annotations
+
+ import math
+ import os
+ import re
+ import time
+ from dataclasses import dataclass
+ from typing import Any, List, Optional, Sequence, Tuple
+
  import gradio as gr
+ import numpy as np
+ import torch
  from huggingface_hub import InferenceClient
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+
+ COMPRESSION_MODEL_ID = "gravitee-io/very-small-prompt-compression"
+ DOWNSTREAM_MODEL = "openai/gpt-oss-20b"
+ EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
+ MAX_NEW_TOKENS = 96
+
+ compression_tokenizer = AutoTokenizer.from_pretrained(COMPRESSION_MODEL_ID, use_fast=True)
+ _MODEL_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ compression_model = AutoModelForSeq2SeqLM.from_pretrained(COMPRESSION_MODEL_ID).to(_MODEL_DEVICE)
+ compression_model.eval()
+
+
+ @dataclass
+ class Segment:
+     text: str
+     punctuation: str
+
+
+ def _split_prompt(prompt: str) -> List[Segment]:
+     """Split a prompt into sentence segments while retaining trailing punctuation."""
+     parts = re.findall(r"[^.!?]+[.!?]*", prompt)
+     segments: List[Segment] = []
+     for part in parts:
+         stripped = part.strip()
+         if not stripped:
+             continue
+         punct_len = len(stripped) - len(stripped.rstrip(".?!"))
+         punctuation = stripped[-punct_len:] if punct_len else ""
+         content = stripped[:-punct_len].strip() if punct_len else stripped
+         if content:
+             segments.append(Segment(text=content, punctuation=punctuation))
+     if not segments and prompt.strip():
+         segments.append(Segment(text=prompt.strip(), punctuation=""))
+     return segments
+
+
+ def _combine_segments(segments: Sequence[Segment]) -> str:
+     pieces = []
+     for segment in segments:
+         piece = segment.text.strip()
+         if segment.punctuation:
+             piece = f"{piece}{segment.punctuation}"
+         pieces.append(piece)
+     return " ".join(piece for piece in pieces if piece).strip()
+
+
+ def _count_tokens(text: str) -> int:
+     if not text:
+         return 0
+     return len(compression_tokenizer.encode(text, add_special_tokens=False))
+
+
+ def _call_compression_model(text: str, *, max_new_tokens: int = MAX_NEW_TOKENS) -> str:
+     try:
+         encoded = compression_tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
+         encoded = {k: v.to(_MODEL_DEVICE) for k, v in encoded.items()}
+         with torch.no_grad():
+             output_ids = compression_model.generate(
+                 **encoded,
+                 max_new_tokens=max_new_tokens,
+                 num_beams=4,
+                 no_repeat_ngram_size=3,
+             )
+         compressed = compression_tokenizer.decode(output_ids[0], skip_special_tokens=True)
+     except Exception:
+         # If generation fails, fall back to the original segment text.
+         return text
+     cleaned = compressed.strip().rstrip("?.!,;:")
+     return cleaned or text
+
+
+ def _embed(client: InferenceClient, text: str) -> Optional[np.ndarray]:
+     if not text.strip():
+         return None
+     try:
+         features = client.feature_extraction(text)
+     except Exception:
+         return None
+     if isinstance(features, list):
+         array = np.array(features[0] if features and isinstance(features[0], list) else features, dtype=np.float32)
+     else:
+         array = np.array(features, dtype=np.float32)
+     if array.ndim == 0:
+         return None
+     if array.ndim > 1:
+         array = array.squeeze()
+     norm = np.linalg.norm(array)
+     if not math.isfinite(norm) or norm == 0.0:
+         return None
+     return array
+
+
+ def _cosine_similarity(vec_a: np.ndarray | None, vec_b: np.ndarray | None) -> Optional[float]:
+     if vec_a is None or vec_b is None:
+         return None
+     denom = float(np.linalg.norm(vec_a) * np.linalg.norm(vec_b))
+     if denom == 0.0:
+         return None
+     return float(np.dot(vec_a, vec_b) / denom)
+
+
+ def _extract_text(payload: Any) -> str:
+     if payload is None:
+         return ""
+     if isinstance(payload, str):
+         return payload
+     if isinstance(payload, dict):
+         if "text" in payload and isinstance(payload["text"], str):
+             return payload["text"]
+         content = payload.get("content")
+         if isinstance(content, str):
+             return content
+         if isinstance(content, list):
+             return " ".join(_extract_text(item) for item in content)
+         if content is None:
+             return ""
+     if isinstance(payload, list):
+         return " ".join(_extract_text(item) for item in payload)
+     if hasattr(payload, "content"):
+         return _extract_text(getattr(payload, "content"))
+     return ""
+
+
+ def _chat_completion(client: InferenceClient, prompt: str) -> Tuple[str, Optional[str]]:
+     last_error: Optional[str] = None
+     for attempt in range(2):
+         try:
+             completion = client.chat_completion(
+                 messages=[
+                     {"role": "system", "content": "You are a helpful assistant. Answer concisely."},
+                     {"role": "user", "content": prompt},
+                 ],
+                 max_tokens=1024,
+                 temperature=0.2,
+                 top_p=0.95,
+             )
+         except Exception as exc:
+             last_error = f"{type(exc).__name__}: {exc}"
+             continue
+
+         try:
+             choice = completion.choices[0] if completion.choices else None
+             if choice is None:
+                 last_error = "No choices returned by downstream model."
+                 continue
+             finish_reason = getattr(choice, "finish_reason", None)
+             message = getattr(choice, "message", None)
+             content = _extract_text(message)
+             if not content:
+                 delta = getattr(choice, "delta", None)
+                 content = _extract_text(delta)
+             if not content:
+                 raw_choice = getattr(choice, "content", None)
+                 content = _extract_text(raw_choice)
+             content = content.strip()
+             if content:
+                 return content, None
+             last_error = f"Model returned an empty response (finish_reason={finish_reason})."
+         except Exception as exc:
+             last_error = f"{type(exc).__name__}: {exc}"
+     return "", last_error or "No response generated."
+
+
+ def _get_client(model_id: str, token: Optional[str]) -> InferenceClient:
+     return InferenceClient(model=model_id, token=token)
+
+
+ def _resolve_token(hf_token: Optional[str]) -> Optional[str]:
+     return (hf_token or "").strip() or os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
+
+
+ def compress_prompt_action(prompt: str, hf_token: Optional[str]) -> Tuple[str, str, str, str, str, str]:
+     token = _resolve_token(hf_token)
+     prompt = prompt.strip()
+     if not prompt:
+         message = "Please enter a prompt to compress."
+         placeholder = "_Run **Compare Responses** after compression to see downstream outputs._"
+         return ("", "", message, placeholder, placeholder, "")
+
+     embedding_client = _get_client(EMBEDDING_MODEL, token)
+
+     segments = _split_prompt(prompt)
+     compressed_segments: List[Segment] = []
+     segment_timings: List[float] = []
+
+     for segment in segments:
+         start = time.perf_counter()
+         compressed_text = _call_compression_model(segment.text)
+         segment_timings.append(time.perf_counter() - start)
+         compressed_segments.append(Segment(text=compressed_text, punctuation=segment.punctuation))
+
+     compressed_prompt = _combine_segments(compressed_segments).rstrip("?.!,;:")
+
+     original_tokens = _count_tokens(prompt)
+     compressed_tokens = _count_tokens(compressed_prompt)
+     token_delta = original_tokens - compressed_tokens
+     savings_pct = (token_delta / original_tokens * 100) if original_tokens else 0.0
+
+     prompt_embedding_original = _embed(embedding_client, prompt)
+     prompt_embedding_compressed = _embed(embedding_client, compressed_prompt)
+     prompt_similarity = _cosine_similarity(prompt_embedding_original, prompt_embedding_compressed)
+
+     prompt_metrics_lines = [
+         f"**Original tokens:** {original_tokens}",
+         f"**Compressed tokens:** {compressed_tokens}",
+         f"**Token savings:** {token_delta} ({savings_pct:.1f}%)",
+     ]
+     if prompt_similarity is not None:
+         prompt_metrics_lines.append(f"**Prompt cosine similarity:** {prompt_similarity:.3f}")
+     if segment_timings:
+         min_ms = min(segment_timings) * 1000.0
+         max_ms = max(segment_timings) * 1000.0
+         mean_ms = (sum(segment_timings) / len(segment_timings)) * 1000.0
+         prompt_metrics_lines.append(
+             f"**Segments:** {len(segment_timings)} • **Latency (ms):** min {min_ms:.1f} / mean {mean_ms:.1f} / max {max_ms:.1f}"
+         )
+     prompt_metrics = "<br>".join(prompt_metrics_lines)
+
+     placeholder_response = "_Run **Compare Responses** to query the downstream model._"
+     response_metrics = "Press **Compare Responses** to evaluate downstream behavior."
+
+     return (
+         prompt,
+         compressed_prompt,
+         prompt_metrics,
+         placeholder_response,
+         placeholder_response,
+         response_metrics,
+     )
+
+
+ def compare_responses_action(
+     original_prompt: str, compressed_prompt: str, hf_token: Optional[str]
+ ) -> Tuple[str, str, str, str, str, str]:
+     token = _resolve_token(hf_token)
+     original_prompt = (original_prompt or "").strip()
+     compressed_prompt = (compressed_prompt or "").strip()
+
+     if not original_prompt or not compressed_prompt:
+         message = "Please compress a prompt before comparing responses."
+         placeholder = "_No response generated._"
+         return (
+             original_prompt,
+             compressed_prompt,
+             message,
+             placeholder,
+             placeholder,
+             "Responses unavailable.",
+         )
+
+     embedding_client = _get_client(EMBEDDING_MODEL, token)
+     llm_client = _get_client(DOWNSTREAM_MODEL, token)
+
+     original_tokens = _count_tokens(original_prompt)
+     compressed_tokens = _count_tokens(compressed_prompt)
+     token_delta = original_tokens - compressed_tokens
+     savings_pct = (token_delta / original_tokens * 100) if original_tokens else 0.0
+
+     prompt_embedding_original = _embed(embedding_client, original_prompt)
+     prompt_embedding_compressed = _embed(embedding_client, compressed_prompt)
+     prompt_similarity = _cosine_similarity(prompt_embedding_original, prompt_embedding_compressed)
+
+     prompt_metrics_lines = [
+         f"**Original tokens:** {original_tokens}",
+         f"**Compressed tokens:** {compressed_tokens}",
+         f"**Token savings:** {token_delta} ({savings_pct:.1f}%)",
+     ]
+     if prompt_similarity is not None:
+         prompt_metrics_lines.append(f"**Prompt cosine similarity:** {prompt_similarity:.3f}")
+     prompt_metrics = "<br>".join(prompt_metrics_lines)
+
+     original_response, original_response_error = _chat_completion(llm_client, original_prompt)
+     compressed_response, compressed_response_error = _chat_completion(llm_client, compressed_prompt)
+
+     response_embedding_original = _embed(embedding_client, original_response)
+     response_embedding_compressed = _embed(embedding_client, compressed_response)
+     response_similarity = _cosine_similarity(response_embedding_original, response_embedding_compressed)
+
+     response_metrics_lines = []
+     if response_similarity is not None:
+         response_metrics_lines.append(f"**Response cosine similarity:** {response_similarity:.3f}")
+
+     original_response_display = original_response or "_No response generated for the original prompt._"
+     compressed_response_display = compressed_response or "_No response generated for the compressed prompt._"
+     if original_response_error:
+         original_response_display += f"\n\n> {original_response_error}"
+         response_metrics_lines.append("⚠️ Downstream model issue on original prompt.")
+     if compressed_response_error:
+         compressed_response_display += f"\n\n> {compressed_response_error}"
+         response_metrics_lines.append("⚠️ Downstream model issue on compressed prompt.")
+
+     if not response_metrics_lines:
+         response_metrics_lines.append("Responses unavailable.")
+
+     response_metrics = "<br>".join(response_metrics_lines)
+
+     return (
+         original_prompt,
+         compressed_prompt,
+         prompt_metrics,
+         original_response_display,
+         compressed_response_display,
+         response_metrics,
+     )
+
+
+ with gr.Blocks(fill_height=True, css=".gradio-container {max-width: 900px;}") as demo:
+     gr.Markdown(
+         """
+         # Very Small Prompt Compression
+         Enter a user prompt to see how the [gravitee-io/very-small-prompt-compression](https://huggingface.co/gravitee-io/very-small-prompt-compression) checkpoint trims it down,
+         compares token savings, and checks semantic drift before forwarding to `openai/gpt-oss-20b`.
+
+         Trained using the [gravitee-io/dolly-15k-prompt-compression](https://huggingface.co/datasets/gravitee-io/dolly-15k-prompt-compression) dataset.
+         """
+     )
+
+     token_input = gr.Textbox(
+         label="Hugging Face token (optional)",
+         type="password",
+         placeholder="Paste an access token to use your own Inference quota",
+     )
+     if os.getenv("SPACE_ID"):
+         gr.Markdown("If running on Spaces, leave blank to use the Space token or secrets.")
  gr.LoginButton()
+
+     prompt_input = gr.Textbox(
+         label="User prompt",
+         placeholder="Describe how to configure a rate limit policy in Gravitee API Management...",
+         lines=4,
+     )
+
+     with gr.Row():
+         compress_btn = gr.Button("Compress Prompt", variant="primary")
+         compare_btn = gr.Button("Compare Responses", variant="secondary")
+
+     original_prompt_output = gr.Textbox(label="Original prompt", lines=4, interactive=False)
+     compressed_output = gr.Textbox(label="Compressed prompt", lines=4, interactive=False)
+     prompt_metrics_output = gr.Markdown()
+
+     with gr.Row():
+         original_response_output = gr.Markdown(label="Response to original prompt")
+         compressed_response_output = gr.Markdown(label="Response to compressed prompt")
+
+     response_metrics_output = gr.Markdown()
+
+     compress_btn.click(
+         fn=compress_prompt_action,
+         inputs=[prompt_input, token_input],
+         outputs=[
+             original_prompt_output,
+             compressed_output,
+             prompt_metrics_output,
+             original_response_output,
+             compressed_response_output,
+             response_metrics_output,
+         ],
+     )
+
+     compare_btn.click(
+         fn=compare_responses_action,
+         inputs=[original_prompt_output, compressed_output, token_input],
+         outputs=[
+             original_prompt_output,
+             compressed_output,
+             prompt_metrics_output,
+             original_response_output,
+             compressed_response_output,
+             response_metrics_output,
+         ],
+     )
  if __name__ == "__main__":