Commit
·
2bfe150
1
Parent(s):
7355cf2
Update README.md
Browse files
README.md
CHANGED
|
@@ -42,7 +42,8 @@ This model is not suitable for all use cases due to its limited training time on
|
|
| 42 |
```python
|
| 43 |
import torch
|
| 44 |
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
| 45 |
-
|
|
|
|
| 46 |
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
|
| 47 |
model = GPT2LMHeadModel.from_pretrained('gpt2-medium')
|
| 48 |
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
|
@@ -68,9 +69,6 @@ def generate_text(model, tokenizer, prompt, max_length=256):
|
|
| 68 |
eos_token_id=tokenizer.eos_token_id,
|
| 69 |
attention_mask=attention_mask)
|
| 70 |
output_ids = tokenizer.decode(output[0], skip_special_tokens=False)
|
| 71 |
-
assistant_token_index = output_ids.index('<|ASSISTANT|>') + len('<|ASSISTANT|>')
|
| 72 |
-
next_token_index = output_ids.find('<|', assistant_token_index)
|
| 73 |
-
output_ids = output_ids[assistant_token_index:next_token_index]
|
| 74 |
return output_ids
|
| 75 |
# Loop to interact with the model
|
| 76 |
while True:
|
|
@@ -78,7 +76,9 @@ while True:
|
|
| 78 |
if prompt == "q":
|
| 79 |
break
|
| 80 |
output_text = generate_text(model, tokenizer, prompt)
|
| 81 |
-
|
|
|
|
|
|
|
| 82 |
```
|
| 83 |
## Deploying and training the model
|
| 84 |
The model has been fine-tuned on a specific input format that goes like this: ```"<|USER|> {user prompt} <|ASSISTANT|> {model prediction} <|End|>"```. For the best performance from the model, the input text should be as follows: ```<|USER|> {dataset prompt} <|ASSISTANT|> ``` and the target/label should be as follows: ```<|USER|> {dataset prompt} <|ASSISTANT|> {dataset output} <|End|>```
|
|
|
|
| 42 |
```python
|
| 43 |
import torch
|
| 44 |
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
| 45 |
+
start_token = "<|ASSISTANT|>"
|
| 46 |
+
end_token = "<|"
|
| 47 |
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
|
| 48 |
model = GPT2LMHeadModel.from_pretrained('gpt2-medium')
|
| 49 |
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
|
|
|
| 69 |
eos_token_id=tokenizer.eos_token_id,
|
| 70 |
attention_mask=attention_mask)
|
| 71 |
output_ids = tokenizer.decode(output[0], skip_special_tokens=False)
|
|
|
|
|
|
|
|
|
|
| 72 |
return output_ids
|
| 73 |
# Loop to interact with the model
|
| 74 |
while True:
|
|
|
|
| 76 |
if prompt == "q":
|
| 77 |
break
|
| 78 |
output_text = generate_text(model, tokenizer, prompt)
|
| 79 |
+
text_between_tokens = output_text[output_text.find(start_token) + len(start_token):]
|
| 80 |
+
out = text_between_tokens[:text_between_tokens.find(end_token)]
|
| 81 |
+
print(out)
|
| 82 |
```
|
| 83 |
## Deploying and training the model
|
| 84 |
The model has been fine-tuned on a specific input format that goes like this: ```"<|USER|> {user prompt} <|ASSISTANT|> {model prediction} <|End|>"```. For the best performance from the model, the input text should be as follows: ```<|USER|> {dataset prompt} <|ASSISTANT|> ``` and the target/label should be as follows: ```<|USER|> {dataset prompt} <|ASSISTANT|> {dataset output} <|End|>```
|