Prithvik-1 commited on
Commit
e465de3
Β·
verified Β·
1 Parent(s): c1b0ab6

Upload test_single_training_sample.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. test_single_training_sample.py +188 -0
test_single_training_sample.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Test inference on a single training sample with exact training format
4
+ """
5
+
6
+ import json
7
+ import sys
8
+ from pathlib import Path
9
+
10
+ # Add scripts to path
11
+ sys.path.insert(0, str(Path(__file__).parent / "scripts" / "inference"))
12
+
13
+ from inference_codellama import load_local_model
14
+ import torch
15
+
16
+ def generate_with_exact_format(model, tokenizer, instruction, max_new_tokens=800, temperature=0.1):
17
+ """Generate using EXACT training format: instruction + EOS (model continues from here)"""
18
+
19
+ # Use EXACT training format: instruction + EOS token
20
+ # During training: instruction + EOS + response + EOS
21
+ # During inference: instruction + EOS (model will generate response)
22
+ prompt = f"{instruction}{tokenizer.eos_token}"
23
+
24
+ print(f"\nπŸ“ Prompt Format (matching training):")
25
+ print(f" Length: {len(prompt)} chars")
26
+ print(f" First 200 chars: {prompt[:200]}...")
27
+ print()
28
+
29
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1536).to(model.device)
30
+
31
+ print(f"πŸ“Š Tokenized:")
32
+ print(f" Input tokens: {inputs['input_ids'].shape[1]}")
33
+ print()
34
+
35
+ print("πŸ€– Generating...")
36
+ print("=" * 80)
37
+
38
+ with torch.no_grad():
39
+ outputs = model.generate(
40
+ **inputs,
41
+ max_new_tokens=max_new_tokens,
42
+ temperature=temperature,
43
+ do_sample=temperature > 0,
44
+ top_p=0.9 if temperature > 0 else None,
45
+ repetition_penalty=1.2, # Higher to prevent repetition
46
+ pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id else tokenizer.eos_token_id,
47
+ eos_token_id=tokenizer.eos_token_id,
48
+ )
49
+
50
+ # Decode only the newly generated tokens (after the prompt)
51
+ generated_ids = outputs[0][inputs['input_ids'].shape[1]:]
52
+ generated_text = tokenizer.decode(generated_ids, skip_special_tokens=False)
53
+
54
+ # Remove EOS token if present at the end
55
+ if generated_text.endswith(tokenizer.eos_token):
56
+ generated_text = generated_text[:-len(tokenizer.eos_token)].rstrip()
57
+
58
+ return generated_text
59
+
60
+ def extract_code_from_response(text):
61
+ """Extract Verilog code from markdown code blocks"""
62
+ if not text:
63
+ return text
64
+
65
+ # Check for verilog code block
66
+ if '```verilog' in text:
67
+ start = text.find('```verilog') + len('```verilog')
68
+ end = text.find('```', start)
69
+ if end != -1:
70
+ extracted = text[start:end].strip()
71
+ return extracted
72
+
73
+ # Check for generic code block
74
+ if '```' in text:
75
+ start = text.find('```')
76
+ if start != -1:
77
+ start_marker = text.find('\n', start)
78
+ if start_marker == -1:
79
+ start_marker = start + 3
80
+ else:
81
+ start_marker += 1
82
+
83
+ end = text.find('```', start_marker)
84
+ if end != -1:
85
+ extracted = text[start_marker:end].strip()
86
+ return extracted
87
+
88
+ return text.strip()
89
+
90
+ def main():
91
+ # Paths
92
+ script_dir = Path(__file__).parent
93
+ model_path = script_dir / "training-outputs" / "codellama-fifo-v1"
94
+ base_model_path = script_dir / "models" / "base-models" / "CodeLlama-7B-Instruct"
95
+ train_dataset = script_dir / "datasets" / "processed" / "split" / "train.jsonl"
96
+
97
+ print("=" * 80)
98
+ print("πŸ§ͺ TESTING SINGLE TRAINING SAMPLE (EXACT TRAINING FORMAT)")
99
+ print("=" * 80)
100
+ print(f"Model: {model_path}")
101
+ print(f"Base: {base_model_path}")
102
+ print("=" * 80)
103
+
104
+ # Load first sample
105
+ print("\nπŸ“š Loading training sample #1...")
106
+ with open(train_dataset, 'r') as f:
107
+ first_line = f.readline()
108
+ sample = json.loads(first_line)
109
+
110
+ instruction = sample.get("instruction", "")
111
+ expected_response = sample.get("response", "")
112
+ expected_code = extract_code_from_response(expected_response)
113
+
114
+ print(f"\nπŸ“ Instruction ({len(instruction)} chars):")
115
+ print("-" * 80)
116
+ print(instruction)
117
+ print("-" * 80)
118
+
119
+ print(f"\n🎯 Expected Response ({len(expected_response)} chars):")
120
+ print("-" * 80)
121
+ print(expected_response[:500] + "..." if len(expected_response) > 500 else expected_response)
122
+ print("-" * 80)
123
+
124
+ # Load model
125
+ print("\nπŸ“¦ Loading model...")
126
+ model, tokenizer = load_local_model(
127
+ str(model_path),
128
+ str(base_model_path) if base_model_path.exists() else None,
129
+ use_quantization=None,
130
+ merge_weights=False
131
+ )
132
+ print("βœ… Model loaded!\n")
133
+
134
+ # Test with different temperatures
135
+ temperatures = [0.1, 0.2, 0.3]
136
+
137
+ for temp in temperatures:
138
+ print("\n" + "=" * 80)
139
+ print(f"πŸ”₯ TESTING WITH TEMPERATURE: {temp}")
140
+ print("=" * 80)
141
+
142
+ try:
143
+ generated_response = generate_with_exact_format(
144
+ model,
145
+ tokenizer,
146
+ instruction,
147
+ max_new_tokens=800,
148
+ temperature=temp
149
+ )
150
+
151
+ generated_code = extract_code_from_response(generated_response)
152
+
153
+ print("\n" + "=" * 80)
154
+ print(f"βœ… GENERATED OUTPUT (Temperature {temp}):")
155
+ print("=" * 80)
156
+ print(generated_response)
157
+ print("=" * 80)
158
+
159
+ print(f"\nπŸ“Š Statistics:")
160
+ print(f" Full response length: {len(generated_response)} chars")
161
+ print(f" Extracted code length: {len(generated_code)} chars")
162
+ print(f" Expected code length: {len(expected_code)} chars")
163
+
164
+ # Quick check if it contains module declaration
165
+ has_module = "module" in generated_response.lower()
166
+ has_endmodule = "endmodule" in generated_response.lower()
167
+ has_verilog_code = "```verilog" in generated_response or ("module" in generated_response and "input" in generated_response)
168
+
169
+ print(f"\nβœ… Code Quality Check:")
170
+ print(f" Contains 'module': {has_module}")
171
+ print(f" Contains 'endmodule': {has_endmodule}")
172
+ print(f" Looks like Verilog code: {has_verilog_code}")
173
+
174
+ if has_verilog_code and has_endmodule:
175
+ print(f" βœ… STATUS: Generated Verilog code!")
176
+ elif has_module:
177
+ print(f" ⚠️ STATUS: Partial code (missing endmodule or full implementation)")
178
+ else:
179
+ print(f" ❌ STATUS: Not generating code (generating text instead)")
180
+
181
+ except Exception as e:
182
+ print(f"❌ Error with temperature {temp}: {e}")
183
+ import traceback
184
+ traceback.print_exc()
185
+
186
+ if __name__ == "__main__":
187
+ main()
188
+