maxholsman commited on
Commit
2a51a8b
·
verified ·
1 Parent(s): 4a9570a

Upload README.md with huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +83 -0
README.md CHANGED
@@ -1,3 +1,86 @@
1
  ---
2
  license: apache-2.0
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: apache-2.0
3
  ---
4
+
5
+ # Fuzzy Speculative Decoding
6
+
7
+ Custom generate function for fuzzy speculative decoding with support for KL divergence, Jensen-Shannon divergence, and draft token-based acceptance criteria. This implementation extends the standard speculative decoding algorithm with additional divergence metrics for more flexible candidate acceptance.
8
+
9
+ ## Features
10
+
11
+ - **Fuzzy Speculative Decoding (FSD)**: Accepts candidate tokens based on distribution divergence thresholds
12
+ - **Multiple Divergence Types**:
13
+ - `kl`: KL divergence between candidate and target distributions
14
+ - `js`: Jensen-Shannon divergence
15
+ - `draft_tokens`: Absolute difference in draft token probabilities
16
+ - **Standard Speculative Decoding**: Falls back to standard speculative decoding acceptance when FSD threshold is not met
17
+ - **Raw Logits Support**: Returns both processed and raw logits for advanced use cases
18
+
19
+ ## Installation
20
+
21
+ ```bash
22
+ pip install -r custom_generate/requirements.txt
23
+ ```
24
+
25
+ ## Usage
26
+
27
+ ### Basic Usage
28
+
29
+ ```python
30
+ from transformers import AutoModelForCausalLM, AutoTokenizer
31
+ import torch
32
+
33
+ # Load models
34
+ target_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
35
+ assistant_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
36
+ tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
37
+
38
+ # Prepare input
39
+ prompt = "What is the capital of France?"
40
+ inputs = tokenizer(prompt, return_tensors="pt")
41
+
42
+ # Generate with custom fuzzy speculative decoding
43
+ outputs = target_model.generate(
44
+ **inputs,
45
+ assistant_model=assistant_model,
46
+ custom_generate="maxholsman/fuzzy-spec-dec",
47
+ trust_remote_code=True,
48
+ fsd_threshold=0.0, # FSD acceptance threshold
49
+ fsd_div_type="kl", # Divergence type: "kl", "js", or "draft_tokens"
50
+ do_sample=True,
51
+ temperature=0.7,
52
+ max_new_tokens=100,
53
+ output_logits=True, # Enable raw logits output
54
+ )
55
+
56
+ # Decode result
57
+ generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
58
+ print(generated_text)
59
+ ```
60
+
61
+ ### Custom Parameters
62
+
63
+ - **`fsd_threshold`** (float, default: 0.0): Threshold for fuzzy speculative decoding acceptance. Tokens with divergence below this threshold are automatically accepted.
64
+ - **`fsd_div_type`** (str, default: "kl"): Type of divergence metric to use:
65
+ - `"kl"`: KL divergence (D_KL(candidate || target))
66
+ - `"js"`: Jensen-Shannon divergence
67
+ - `"draft_tokens"`: Absolute difference in draft token probabilities
68
+
69
+ ### How It Works
70
+
71
+ 1. The assistant model generates candidate tokens
72
+ 2. The target model evaluates these candidates
73
+ 3. For each candidate position:
74
+ - If FSD divergence ≤ threshold: token is accepted
75
+ - Otherwise: standard speculative decoding acceptance is applied
76
+ 4. Accepted tokens are kept, rejected tokens trigger resampling from the target model
77
+
78
+ ## Requirements
79
+
80
+ - `torch>=2.0.0`
81
+ - `transformers>=4.40.0`
82
+ - `scikit-learn` (optional, for confidence threshold features)
83
+
84
+ ## License
85
+
86
+ Apache 2.0