likhonsheikh commited on
Commit
bc8a950
ยท
verified ยท
1 Parent(s): 30c47bb

Add model testing and validation script

Browse files
Files changed (1) hide show
  1. test_model.py +73 -0
test_model.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Quick test of the trained Prothom Alo model
4
+ """
5
+
6
+ import torch
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM
8
+
9
+ def test_model():
10
+ """Test the fine-tuned model"""
11
+
12
+ print("๐Ÿš€ Testing Prothom Alo Fine-tuned Model")
13
+ print("=" * 50)
14
+
15
+ # Load the fine-tuned model
16
+ model_path = "./prothomalo_model/final_model"
17
+ print(f"Loading model from: {model_path}")
18
+
19
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
20
+ model = AutoModelForCausalLM.from_pretrained(model_path)
21
+
22
+ # Test text generation
23
+ prompts = [
24
+ "The latest news from Bangladesh",
25
+ "In today's opinion piece",
26
+ "Government announces new policy"
27
+ ]
28
+
29
+ for i, prompt in enumerate(prompts, 1):
30
+ print(f"\n๐Ÿงช Test {i}: {prompt}")
31
+ print("-" * 40)
32
+
33
+ # Tokenize input
34
+ inputs = tokenizer(prompt, return_tensors="pt")
35
+
36
+ # Generate text
37
+ with torch.no_grad():
38
+ outputs = model.generate(
39
+ **inputs,
40
+ max_length=150,
41
+ num_return_sequences=1,
42
+ do_sample=True,
43
+ temperature=0.8,
44
+ pad_token_id=tokenizer.eos_token_id
45
+ )
46
+
47
+ # Decode and display
48
+ generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
49
+ print(f"Generated: {generated_text}")
50
+
51
+ # Test Safetensors loading
52
+ print(f"\n๐Ÿ”’ Testing Safetensors Format")
53
+ print("-" * 40)
54
+
55
+ try:
56
+ from safetensors import safe_open
57
+
58
+ with safe_open("./prothomalo_model.safetensors", framework="pt", device=0) as f:
59
+ keys = list(f.keys())
60
+ print(f"โœ… Safetensors loaded successfully!")
61
+ print(f"๐Ÿ“Š Contains {len(keys)} tensors")
62
+ print(f"๐Ÿ” First 3 tensor names:")
63
+ for key in keys[:3]:
64
+ print(f" - {key}")
65
+
66
+ except Exception as e:
67
+ print(f"โŒ Safetensors test failed: {e}")
68
+
69
+ print(f"\n๐ŸŽ‰ Model testing completed!")
70
+ return True
71
+
72
+ if __name__ == "__main__":
73
+ test_model()