# hello-world / example_with_dataset.py
# Author: Chiedo John
# Commit: "Add dataset integration to Hello World model" (d0c3c53)
"""
Example script showing how to use the Hello World model with its dataset.
"""
from transformers import PreTrainedTokenizerFast
from model import HelloWorldModel, HelloWorldConfig
from datasets import load_dataset
import torch
def main():
    """Walk through loading the Hello World model and its companion dataset.

    Demonstrates, in order: loading the model/tokenizer from the Hub,
    loading the dataset via the model's helper and via the ``datasets``
    library directly, batching dataset text through the model, the
    ``generate_hello_world`` helper, and per-example inference over the
    test split. Prints progress to stdout; returns ``None``.

    Requires network access to the Hugging Face Hub on first run.
    """
    print("Loading Hello World Model and Dataset Example\n")
    print("=" * 50)

    # Load model and tokenizer from the Hub.
    print("Loading model and tokenizer...")
    # NOTE(review): config is loaded for demonstration only — nothing below
    # reads it (from_pretrained on the model resolves its own config).
    config = HelloWorldConfig.from_pretrained("chiedo/hello-world")
    model = HelloWorldModel.from_pretrained("chiedo/hello-world")
    tokenizer = PreTrainedTokenizerFast.from_pretrained("chiedo/hello-world")

    # Method 1: load the dataset through the model's built-in helper.
    print("\n1. Loading dataset using model's load_dataset method:")
    dataset = HelloWorldModel.load_dataset("chiedo/hello-world")
    if dataset:
        # Fixed: was an f-string with no placeholders.
        print("Dataset loaded successfully!")
        print(f"Splits available: {list(dataset.keys())}")
        print(f"Train examples: {len(dataset['train'])}")
        print(f"Validation examples: {len(dataset['validation'])}")
        print(f"Test examples: {len(dataset['test'])}")

        # Show the first few training examples.
        print("\nFirst 3 training examples:")
        for i in range(min(3, len(dataset['train']))):
            example = dataset['train'][i]
            print(f" {i+1}. Text: '{example['text']}', Label: {example['label']}")
    else:
        # Previously a silent no-op; make the failure visible to the reader.
        print("Dataset could not be loaded via the model helper; continuing.")

    # Method 2: load the dataset directly with the datasets library.
    print("\n2. Loading dataset directly with datasets library:")
    dataset_direct = load_dataset("chiedo/hello-world")

    # Map integer label ids to their human-readable names.
    label_names = dataset_direct['train'].features['label'].names
    print(f"Label categories: {label_names}")

    # Tokenize a small batch of dataset text for the model.
    print("\n3. Processing a batch from the dataset:")
    batch_texts = dataset_direct['train']['text'][:3]
    print(f"Batch texts: {batch_texts}")
    inputs = model.prepare_dataset_batch(batch_texts, tokenizer)
    print(f"Tokenized input shape: {inputs['input_ids'].shape}")

    # Run the batch through the model; no_grad since this is inference only.
    print("\n4. Running model inference on dataset batch:")
    with torch.no_grad():
        outputs = model(**inputs)
    print(f"Model output shape: {outputs.logits.shape}")

    # Exercise the model's generate_hello_world helper.
    print("\n5. Testing generate_hello_world function:")
    result = model.generate_hello_world()
    print(f"Generated output: {result}")

    # Per-example inference over the first few test-split entries.
    print("\n6. Iterating through test set:")
    for i, example in enumerate(dataset_direct['test']):
        if i >= 3:  # Only show first 3
            break
        text = example['text']
        label_id = example['label']
        label_name = label_names[label_id]

        # Tokenize a single example and predict the next token id
        # (greedy argmax over the final position's logits).
        inputs = tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        predicted_token = outputs.logits[0, -1].argmax().item()
        print(f" Text: '{text}' | Label: {label_name} | Predicted next token ID: {predicted_token}")

    print("\n" + "=" * 50)
    print("Example completed successfully!")


if __name__ == "__main__":
    main()