File size: 2,151 Bytes
65ef412
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#!/usr/bin/env python3
"""
Quick test of the trained Prothom Alo model
"""

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

def test_model():
    """Smoke-test the fine-tuned Prothom Alo model.

    Loads the model and tokenizer from the local output directory, generates
    text for a few fixed prompts, then verifies the exported safetensors file
    can be opened and enumerated.

    Returns:
        bool: True when the script runs to completion (generation errors are
        not caught; the safetensors check reports failure via a message).
    """

    print("🚀 Testing Prothom Alo Fine-tuned Model")
    print("=" * 50)

    # Load the fine-tuned model from the training output directory.
    model_path = "./prothomalo_model/final_model"
    print(f"Loading model from: {model_path}")

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path)
    # Disable dropout etc. for inference — torch.no_grad() alone does not
    # switch layers out of training mode.
    model.eval()

    # Test text generation on a few representative prompts.
    prompts = [
        "The latest news from Bangladesh",
        "In today's opinion piece",
        "Government announces new policy"
    ]

    for i, prompt in enumerate(prompts, 1):
        print(f"\n🧪 Test {i}: {prompt}")
        print("-" * 40)

        # Tokenize input
        inputs = tokenizer(prompt, return_tensors="pt")

        # Generate text; pad_token_id is set explicitly because GPT-style
        # tokenizers often have no pad token configured.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=150,
                num_return_sequences=1,
                do_sample=True,
                temperature=0.8,
                pad_token_id=tokenizer.eos_token_id
            )

        # Decode and display
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"Generated: {generated_text}")

    # Verify the exported safetensors checkpoint is readable.
    print("\n🔒 Testing Safetensors Format")
    print("-" * 40)

    try:
        from safetensors import safe_open

        # device="cpu": this is an inspection-only open; the previous
        # device=0 required CUDA and failed on CPU-only machines.
        with safe_open("./prothomalo_model.safetensors", framework="pt", device="cpu") as f:
            keys = list(f.keys())
            print("✅ Safetensors loaded successfully!")
            print(f"📊 Contains {len(keys)} tensors")
            print("🔍 First 3 tensor names:")
            for key in keys[:3]:
                print(f"  - {key}")

    except Exception as e:
        print(f"❌ Safetensors test failed: {e}")

    print("\n🎉 Model testing completed!")
    return True

# Run the smoke test only when executed as a script, not on import.
if __name__ == "__main__":
    test_model()