Spaces:
Running
on
Zero
Running
on
Zero
try higher temperature
Browse files- generate.py +2 -1
generate.py
CHANGED
|
@@ -37,6 +37,7 @@ model = models.transformers(model_id, device=device)
|
|
| 37 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 38 |
sampler = PenalizedMultinomialSampler()
|
| 39 |
low_temperature_sampler = PenalizedMultinomialSampler(temperature=0.3)
|
|
|
|
| 40 |
empty_tokens = [token_id for token_id in range(tokenizer.vocab_size) if not tokenizer.decode([token_id], skip_special_tokens=True).strip()]
|
| 41 |
sampler.set_max_repeats(empty_tokens, 1)
|
| 42 |
disallowed_patterns = [regex.compile(r"\p{Han}")] # focus on english for now
|
|
@@ -60,7 +61,7 @@ class Dataset(BaseModel):
|
|
| 60 |
data: conlist(Sample, min_length=2, max_length=3) # type: ignore
|
| 61 |
|
| 62 |
|
| 63 |
-
samples_generator_template = generate.json(model, Dataset, sampler=
|
| 64 |
|
| 65 |
class Columns(BaseModel):
|
| 66 |
columns: conset(Annotated[str, StringConstraints(pattern=r'[a-z0-9_]+')], min_length=2, max_length=len(Sample.model_fields) - 1) # type: ignore
|
|
|
|
| 37 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 38 |
sampler = PenalizedMultinomialSampler()
|
| 39 |
low_temperature_sampler = PenalizedMultinomialSampler(temperature=0.3)
|
| 40 |
+
high_temperature_sampler = PenalizedMultinomialSampler(temperature=1.1)
|
| 41 |
empty_tokens = [token_id for token_id in range(tokenizer.vocab_size) if not tokenizer.decode([token_id], skip_special_tokens=True).strip()]
|
| 42 |
sampler.set_max_repeats(empty_tokens, 1)
|
| 43 |
disallowed_patterns = [regex.compile(r"\p{Han}")] # focus on english for now
|
|
|
|
| 61 |
data: conlist(Sample, min_length=2, max_length=3) # type: ignore
|
| 62 |
|
| 63 |
|
| 64 |
+
samples_generator_template = generate.json(model, Dataset, sampler=high_temperature_sampler)
|
| 65 |
|
| 66 |
class Columns(BaseModel):
|
| 67 |
columns: conset(Annotated[str, StringConstraints(pattern=r'[a-z0-9_]+')], min_length=2, max_length=len(Sample.model_fields) - 1) # type: ignore
|