---
license: mit
tags:
- text-generation
- humor
- jokes
- distilgpt2
pipeline_tag: text-generation
---

# Jokes Model

A fine-tuned DistilGPT2 model that generates short, clean, and (sometimes) funny jokes!

## Model Details

- **Model type:** Causal Language Model (DistilGPT2)
- **Fine-tuned on:** 10,000 filtered jokes from [shortjokes.csv](https://www.kaggle.com/datasets/abhinavmoudgil95/short-jokes)
- **Training epochs:** 5
- **Max joke length:** 80 tokens

## Usage

### Direct Inference

```python
from functools import lru_cache

import torch
from transformers import AutoTokenizer, pipeline

# Keep this BLOCKLIST in your script: it is what filters out jokes that are
# not clean. is_safe() matches these terms case-insensitively as substrings,
# so short entries such as "ass" or "hell" also reject harmless words like
# "class" or "hello" (the filter errs on the side of rejecting).
BLOCKLIST = [
    "sex", "naked", "porn", "fuck", "dick", "penis", "ass",
    "blowjob", "orgasm", "rape", "kill", "die", "shit",
    "crap", "bastard", "hell", "damn", "bitch", "underage",
    "pedo", "hit", "shot", "gun", "drug", "drunk", "fag", "cunt",
]


def is_safe(text, blocklist=None):
    """Return True when *text* contains none of the blocked terms.

    Matching is case-insensitive and substring-based: a term counts as
    present if it appears anywhere in the text, even inside a longer word
    (so "ass" also rejects "class"). That keeps inflected or compound forms
    from slipping through at the cost of some false positives.

    Args:
        text: The string to check.
        blocklist: Optional iterable of terms to reject. When omitted, the
            module-level ``BLOCKLIST`` is used.

    Returns:
        bool: False if any blocked term occurs in ``text``, True otherwise.
    """
    if blocklist is None:
        blocklist = BLOCKLIST
    text_lower = text.lower()
    # Lower-case the terms as well so a caller-supplied list with capitals
    # still matches the lower-cased text.
    return not any(bad_word.lower() in text_lower for bad_word in blocklist)


@lru_cache(maxsize=None)
def _load_joke_pipeline():
    """Build the text-generation pipeline once and reuse it on later calls."""
    return pipeline(
        "text-generation",
        model="FaisalGh/jokes-model",
        # Use the first GPU when CUDA is available, otherwise the CPU.
        device=0 if torch.cuda.is_available() else -1,
    )


def _first_sentence(generated_text, prompt):
    """Cut *generated_text* after the first period that follows *prompt*."""
    # By default the pipeline echoes the prompt at the start of its output.
    # Start looking after it so a period inside the prompt does not cut the
    # result down to the prompt itself.
    search_from = len(prompt) if generated_text.startswith(prompt) else 0
    period_at = generated_text.find(".", search_from)
    if period_at != -1:
        return generated_text[: period_at + 1].strip()

    # No period was generated: keep everything, and only add a closing
    # period when the text does not already end a sentence ("!" or "?").
    sentence = generated_text.strip()
    if sentence and sentence[-1] not in ".!?":
        sentence += "."
    return sentence


def generate_joke(prompt="Tell me a clean joke:"):
    """Generate one joke and return its first sentence if it is clean.

    The model pipeline is created on the first call and reused afterwards,
    so repeated calls do not reload the model.

    Args:
        prompt: Text the model continues. It remains part of the returned
            string, as in the pipeline's full-text output.

    Returns:
        str: The generated text up to and including the first period after
        the prompt, or "[Content filtered] Please try again." when that
        text fails ``is_safe``.
    """
    joke_gen = _load_joke_pipeline()

    output = joke_gen(
        prompt,
        max_length=80,  # total length in tokens; the prompt counts too
        temperature=0.7,
        top_k=50,
        top_p=0.9,
        repetition_penalty=1.5,
        no_repeat_ngram_size=2,
        do_sample=True,
        # 50256 is the GPT-2 family's <|endoftext|> token id.
        pad_token_id=50256,
        eos_token_id=50256,
    )

    first_sentence = _first_sentence(output[0]["generated_text"], prompt)
    if not is_safe(first_sentence):
        return "[Content filtered] Please try again."
    return first_sentence


# Run the demo only when executed as a script (or in a notebook cell), so
# importing this module does not download the model and generate a joke.
if __name__ == "__main__":
    print(generate_joke("Tell me a clean joke:"))