"""
Example script showing how to use the Hello World model with its dataset.
"""

from transformers import PreTrainedTokenizerFast
from model import HelloWorldModel, HelloWorldConfig
from datasets import load_dataset
import torch


def main():
    """Demonstrate the Hello World model together with its dataset.

    Walks through six numbered steps: loading the model and tokenizer,
    loading the dataset (via the model's helper and directly with the
    ``datasets`` library), batching, inference, text generation, and
    iterating the test split.

    NOTE(review): requires network access (or a warm Hugging Face cache)
    to download the "chiedo/hello-world" artifacts.
    """
    # Single source of truth for the Hub repository id (was repeated inline).
    repo_id = "chiedo/hello-world"

    print("Loading Hello World Model and Dataset Example\n")
    print("=" * 50)

    # Load model and tokenizer.
    # The checkpoint bundles its own config, so the previous separate
    # HelloWorldConfig.from_pretrained() call was redundant and its result
    # was never used — removed.
    print("Loading model and tokenizer...")
    model = HelloWorldModel.from_pretrained(repo_id)
    tokenizer = PreTrainedTokenizerFast.from_pretrained(repo_id)

    # Method 1: Load dataset using the model's built-in method
    print("\n1. Loading dataset using model's load_dataset method:")
    dataset = HelloWorldModel.load_dataset(repo_id)

    # Presumably load_dataset returns None/empty on failure — guard before use.
    if dataset:
        print("Dataset loaded successfully!")
        print(f"Splits available: {list(dataset.keys())}")
        print(f"Train examples: {len(dataset['train'])}")
        print(f"Validation examples: {len(dataset['validation'])}")
        print(f"Test examples: {len(dataset['test'])}")

        # Show first few examples (at most 3, fewer if the split is tiny).
        print("\nFirst 3 training examples:")
        for i in range(min(3, len(dataset['train']))):
            example = dataset['train'][i]
            print(f"  {i+1}. Text: '{example['text']}', Label: {example['label']}")

    # Method 2: Load dataset directly
    print("\n2. Loading dataset directly with datasets library:")
    dataset_direct = load_dataset(repo_id)

    # Get human-readable label names from the ClassLabel feature.
    label_names = dataset_direct['train'].features['label'].names
    print(f"Label categories: {label_names}")

    # Process a batch from the dataset
    print("\n3. Processing a batch from the dataset:")
    batch_texts = dataset_direct['train']['text'][:3]
    print(f"Batch texts: {batch_texts}")

    # Prepare batch for model (tokenization/padding handled by the helper).
    inputs = model.prepare_dataset_batch(batch_texts, tokenizer)
    print(f"Tokenized input shape: {inputs['input_ids'].shape}")

    # Run model inference — no_grad() since this is inference-only.
    print("\n4. Running model inference on dataset batch:")
    with torch.no_grad():
        outputs = model(**inputs)
        print(f"Model output shape: {outputs.logits.shape}")

    # Demonstrate the generate_hello_world function
    print("\n5. Testing generate_hello_world function:")
    result = model.generate_hello_world()
    print(f"Generated output: {result}")

    # Show how to iterate through dataset (first 3 test examples only).
    print("\n6. Iterating through test set:")
    for i, example in enumerate(dataset_direct['test']):
        if i >= 3:  # Only show first 3
            break
        text = example['text']
        label_id = example['label']
        label_name = label_names[label_id]

        # Tokenize one example and take the argmax over the final position's
        # logits as the predicted next token id.
        inputs = tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
            predicted_token = outputs.logits[0, -1].argmax().item()

        print(f"  Text: '{text}' | Label: {label_name} | Predicted next token ID: {predicted_token}")

    print("\n" + "=" * 50)
    print("Example completed successfully!")


# Standard script entry point: run the demo only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()