|
|
""" |
|
|
Example script showing how to use the Hello World model with its dataset. |
|
|
""" |
|
|
|
|
|
from transformers import PreTrainedTokenizerFast |
|
|
from model import HelloWorldModel, HelloWorldConfig |
|
|
from datasets import load_dataset |
|
|
import torch |
|
|
|
|
|
|
|
|
def main():
    """Demonstrate loading the Hello World model, tokenizer, and dataset.

    Walks through six steps: loading the artifacts from the Hugging Face
    Hub, loading the dataset (via the model's helper and directly with the
    ``datasets`` library), batch tokenization, batch inference, the
    ``generate_hello_world`` helper, and per-example inference over the
    test split.

    Note: requires network access to download ``chiedo/hello-world``.
    """
    print("Loading Hello World Model and Dataset Example\n")
    print("=" * 50)

    # Download model weights, config, and tokenizer from the Hub.
    print("Loading model and tokenizer...")
    config = HelloWorldConfig.from_pretrained("chiedo/hello-world")
    model = HelloWorldModel.from_pretrained("chiedo/hello-world")
    tokenizer = PreTrainedTokenizerFast.from_pretrained("chiedo/hello-world")
    # Report the loaded config so the load above isn't dead code
    # (previously `config` was assigned but never used).
    print(f"Config loaded: {type(config).__name__}")

    print("\n1. Loading dataset using model's load_dataset method:")
    dataset = HelloWorldModel.load_dataset("chiedo/hello-world")

    if dataset:
        # Fixed: was an f-string with no placeholders.
        print("Dataset loaded successfully!")
        print(f"Splits available: {list(dataset.keys())}")
        print(f"Train examples: {len(dataset['train'])}")
        print(f"Validation examples: {len(dataset['validation'])}")
        print(f"Test examples: {len(dataset['test'])}")

        print("\nFirst 3 training examples:")
        for i in range(min(3, len(dataset['train']))):
            example = dataset['train'][i]
            print(f" {i+1}. Text: '{example['text']}', Label: {example['label']}")
    else:
        # Previously a falsy result fell through silently; say so explicitly.
        print("Dataset could not be loaded via the model's helper.")

    print("\n2. Loading dataset directly with datasets library:")
    dataset_direct = load_dataset("chiedo/hello-world")

    # Map integer label ids to human-readable names via the split's features.
    label_names = dataset_direct['train'].features['label'].names
    print(f"Label categories: {label_names}")

    print("\n3. Processing a batch from the dataset:")
    batch_texts = dataset_direct['train']['text'][:3]
    print(f"Batch texts: {batch_texts}")

    inputs = model.prepare_dataset_batch(batch_texts, tokenizer)
    print(f"Tokenized input shape: {inputs['input_ids'].shape}")

    print("\n4. Running model inference on dataset batch:")
    # Inference only — skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)
    print(f"Model output shape: {outputs.logits.shape}")

    print("\n5. Testing generate_hello_world function:")
    result = model.generate_hello_world()
    print(f"Generated output: {result}")

    print("\n6. Iterating through test set:")
    for i, example in enumerate(dataset_direct['test']):
        if i >= 3:  # show only the first three test examples
            break
        text = example['text']
        label_id = example['label']
        label_name = label_names[label_id]

        inputs = tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        # Greedy pick of the next-token id from the last position's logits.
        predicted_token = outputs.logits[0, -1].argmax().item()

        print(f" Text: '{text}' | Label: {label_name} | Predicted next token ID: {predicted_token}")

    print("\n" + "=" * 50)
    print("Example completed successfully!")
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |