CatoG committed on

Add logprob_answer function and improve diagnostics

Added a logprob_answer function to compute the log-probability of an answer given a prompt. Enhanced DPO diagnostics for evaluating model preferences.
app.py
CHANGED
@@ -4,6 +4,7 @@ from datetime import datetime
 
 import torch
 from torch import nn
+import torch.nn.functional as F
 
 import gradio as gr
 import pandas as pd
@@ -21,7 +22,7 @@ from trl import DPOConfig, DPOTrainer
 
 
 # =========================================================
-# MODEL LIST
+# MODEL LIST
 # =========================================================
 
 MODEL_CHOICES = [
@@ -92,7 +93,7 @@ DEFAULT_DPO_CONFIG = DPOConfig(
     logging_steps=1,
     gradient_accumulation_steps=1,
     learning_rate=1e-4,
-    evaluation_strategy="no",
+    evaluation_strategy="no",
     warmup_steps=0,
     fp16=False,
     save_steps=0,
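Since `DPOConfig` inherits from transformers' `TrainingArguments`, the `evaluation_strategy="no"` line above is plain Trainer configuration. As a side note, recent transformers releases (4.41+) deprecate `evaluation_strategy` in favor of `eval_strategy`; a minimal forward-compatible sketch of the same defaults might look like this (`output_dir` is a placeholder, not from the diff):

```python
from trl import DPOConfig

# Sketch only: mirrors the defaults above on a recent transformers version,
# where `evaluation_strategy` has been renamed to `eval_strategy`.
config = DPOConfig(
    output_dir="./dpo_out",   # placeholder path, not part of the original diff
    logging_steps=1,
    gradient_accumulation_steps=1,
    learning_rate=1e-4,
    eval_strategy="no",
    warmup_steps=0,
    fp16=False,
)
```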
@@ -246,7 +247,6 @@ def build_generation_config(
     """
     Helper to build a GenerationConfig from UI settings.
     """
-    # Clamp values a bit just to be safe
     temperature = max(0.0, float(temperature))
     max_new_tokens = int(max_new_tokens)
     return GenerationConfig(
@@ -310,6 +310,41 @@ def list_trained_model_files() -> List[str]:
     return files
 
 
+def logprob_answer(
+    model: nn.Module,
+    tokenizer: AutoTokenizer,
+    prompt: str,
+    answer: str,
+) -> float:
+    """
+    Compute the log-probability of `answer` given `prompt`,
+    using a simple "User/Assistant" format:
+
+        full_text = "User: <prompt>\\nAssistant: <answer>"
+
+    We approximate p(answer | prompt) by summing log-probs of all tokens
+    in the answer region (the shared prompt part cancels in comparisons).
+    """
+    model.eval()
+    with torch.no_grad():
+        full_text = f"User: {prompt}\nAssistant: {answer}"
+        enc = tokenizer(
+            full_text,
+            return_tensors="pt",
+        ).to(device)
+
+        input_ids = enc["input_ids"]
+        out = model(input_ids=input_ids)
+        logits = out.logits[:, :-1, :]  # [B, T-1, V]
+        labels = input_ids[:, 1:]       # [B, T-1]
+
+        log_probs = F.log_softmax(logits, dim=-1)
+        token_log_probs = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
+        total_logprob = token_log_probs.sum().item()
+
+    return float(total_logprob)
+
+
 # =========================================================
 # DPO CALLBACKS
 # =========================================================
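A minimal usage sketch of the new helper, assuming `logprob_answer` (and the `device` global it relies on) are in scope as defined above; `distilgpt2` is just an illustrative small model. Because both candidates share the same prompt prefix, the prompt tokens' contribution cancels when the two totals are compared:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2").to(device)

# The higher (less negative) total log-probability marks the answer the
# model finds more likely as a continuation of the prompt.
prompt = "What is the capital of France?"
lp_a = logprob_answer(model, tokenizer, prompt, "Paris.")
lp_b = logprob_answer(model, tokenizer, prompt, "The moon.")
print("prefers A" if lp_a > lp_b else "prefers B", lp_a, lp_b)
```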
@@ -327,8 +362,6 @@ def generate_candidates(
     if not prompt.strip():
         return "", ""
 
-    # Build two configs from the same UI settings,
-    # but make B slightly more "wild" by bumping top_k / temperature a bit
     balanced_config = build_generation_config(
         do_sample=do_sample,
         temperature=temperature,
@@ -337,8 +370,6 @@ def generate_candidates(
         top_p=0.9,
     )
 
-    # For creative answer, nudge temperature and top_k a bit, but still
-    # keep them tied to UI settings.
     creative_temp = float(temperature) + 0.4
     creative_config = build_generation_config(
         do_sample=do_sample,
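The deleted comments described the intent of these two configs: candidate A follows the UI settings as-is, while candidate B nudges temperature upward for a more adventurous sample. A standalone sketch of the pattern (only the +0.4 temperature nudge is visible in the diff; candidate B's top_k and top_p values here are assumed illustrations):

```python
from transformers import GenerationConfig

ui_temperature, ui_top_k = 0.7, 50  # stand-ins for the UI slider values

# Candidate A: balanced sampling, tied directly to the UI settings.
balanced = GenerationConfig(
    do_sample=True, temperature=ui_temperature, top_k=ui_top_k,
    top_p=0.9, max_new_tokens=64,
)

# Candidate B: same base settings, nudged toward more diverse output.
creative = GenerationConfig(
    do_sample=True, temperature=ui_temperature + 0.4, top_k=ui_top_k + 50,
    top_p=0.95, max_new_tokens=64,
)
```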
@@ -534,6 +565,76 @@ You can download them using the file list below.
     return msg, last_trained_msg, files
 
 
+def dpo_diagnostics(state_preferences: List[Dict]) -> str:
+    """
+    Compute how often the policy_model and ref_model
+    assign higher log-probability to the CHOSEN answer
+    than to the REJECTED answer.
+
+    Returns a markdown report with:
+    - number of pairs
+    - policy win rate
+    - ref win rate
+    - average logprob margins
+    """
+    if not state_preferences:
+        return "No preferences collected yet – nothing to evaluate."
+
+    if policy_model is None or ref_model is None or tokenizer is None:
+        return "Models not loaded – reload base model first."
+
+    n = len(state_preferences)
+    policy_wins = 0
+    ref_wins = 0
+
+    policy_margins = []
+    ref_margins = []
+
+    for ex in state_preferences:
+        prompt = ex["prompt"]
+        chosen = ex["chosen"]
+        rejected = ex["rejected"]
+
+        # Policy model logprobs
+        lp_pol_ch = logprob_answer(policy_model, tokenizer, prompt, chosen)
+        lp_pol_rj = logprob_answer(policy_model, tokenizer, prompt, rejected)
+        margin_pol = lp_pol_ch - lp_pol_rj
+        policy_margins.append(margin_pol)
+        if margin_pol > 0:
+            policy_wins += 1
+
+        # Reference model logprobs
+        lp_ref_ch = logprob_answer(ref_model, tokenizer, prompt, chosen)
+        lp_ref_rj = logprob_answer(ref_model, tokenizer, prompt, rejected)
+        margin_ref = lp_ref_ch - lp_ref_rj
+        ref_margins.append(margin_ref)
+        if margin_ref > 0:
+            ref_wins += 1
+
+    policy_winrate = policy_wins / n
+    ref_winrate = ref_wins / n
+
+    avg_pol_margin = sum(policy_margins) / n
+    avg_ref_margin = sum(ref_margins) / n
+
+    report = f"""### DPO Diagnostics
+
+Preference pairs evaluated: **{n}**
+
+**Policy model (after DPO)**
+- Win rate (chosen > rejected): **{policy_winrate:.2%}**
+- Avg logprob(chosen − rejected): **{avg_pol_margin:.3f}**
+
+**Reference model (base)**
+- Win rate (chosen > rejected): **{ref_winrate:.2%}**
+- Avg logprob(chosen − rejected): **{avg_ref_margin:.3f}**
+
+> A higher win rate and margin for the policy model compared to the reference model
+> indicate that DPO training is successfully shifting the model toward your preferences.
+"""
+    return report
+
+
 def generate_from_aligned_model(
     prompt: str,
     do_sample: bool,
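To make the report's numbers concrete, here is a toy run of the same win-rate and margin bookkeeping over three made-up (chosen, rejected) log-probability pairs:

```python
# Hypothetical (chosen_logprob, rejected_logprob) pairs
pairs = [(-12.3, -15.1), (-8.0, -7.5), (-20.4, -22.0)]

margins = [chosen - rejected for chosen, rejected in pairs]  # ~[2.8, -0.5, 1.6]
win_rate = sum(m > 0 for m in margins) / len(pairs)          # 2 of 3 pairs won
avg_margin = sum(margins) / len(pairs)                       # ~1.3

print(f"Win rate: {win_rate:.2%}, avg margin: {avg_margin:.3f}")
# -> Win rate: 66.67%, avg margin: 1.300
```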
@@ -602,7 +703,8 @@ with gr.Blocks() as demo:
     - Collect several preferences and **train the model with DPO**.
     - Test how the aligned policy model behaves on new prompts.
     - Download the tuned model (LoRA adapter + tokenizer) after training.
-    - **
+    - Use **DPO diagnostics** to see if the aligned model prefers your chosen answers
+      more often than the base model.
     """
     )
 
@@ -806,6 +908,17 @@ with gr.Blocks() as demo:
         outputs=test_answer,
     )
 
+    gr.Markdown("## DPO diagnostics")
+
+    diag_btn = gr.Button("Compute preference win rates (policy vs base)")
+    diag_output = gr.Markdown("")
+
+    diag_btn.click(
+        fn=dpo_diagnostics,
+        inputs=[state_preferences],
+        outputs=[diag_output],
+    )
+
     # model change: reload + clear prefs + reset train status + last trained + downloads
     model_dropdown.change(
         fn=on_model_change,
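The wiring follows the usual Gradio pattern: the `gr.State` holding the collected preference list is passed as the input, and the markdown report string lands in a `gr.Markdown` output. A self-contained sketch of the same pattern, with a stand-in for `dpo_diagnostics`:

```python
import gradio as gr

def report(prefs: list) -> str:
    # Stand-in for dpo_diagnostics: just summarize the stored pairs.
    return f"Preference pairs stored: **{len(prefs)}**"

with gr.Blocks() as demo:
    prefs_state = gr.State([])  # populated elsewhere in the real app
    btn = gr.Button("Compute diagnostics")
    out = gr.Markdown("")
    btn.click(fn=report, inputs=[prefs_state], outputs=[out])

if __name__ == "__main__":
    demo.queue().launch()
```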
@@ -822,4 +935,3 @@ with gr.Blocks() as demo:
 
 if __name__ == "__main__":
     demo.queue().launch()
-