Prithvik-1 committed on
Commit
99416ae
·
verified ·
1 Parent(s): 2d75f24

Upload test_new_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. test_new_model.py +204 -0
test_new_model.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Test the newly fine-tuned CodeLlama model on training samples
4
+ """
5
+
6
+ import json
7
+ import sys
8
+ from pathlib import Path
9
+
10
+ sys.path.insert(0, str(Path(__file__).parent / "scripts" / "inference"))
11
+
12
+ from inference_codellama import load_local_model
13
+ import torch
14
+ from transformers import AutoTokenizer
15
+ import re
16
+
17
def extract_code_from_response(text):
    """Extract Verilog code from the first markdown code fence in *text*.

    Preference order:
      1. A closed ```verilog fenced block — return its contents, stripped.
      2. Any closed generic ``` fenced block — skip the fence's own line so
         a language tag (e.g. ```python) is not included in the result.
      3. No closed fence found — return the whole text, stripped.

    Falsy input (``None`` or ``""``) is returned unchanged.
    """
    if not text:
        return text

    # Preferred: an explicitly tagged Verilog block.
    if '```verilog' in text:
        start = text.find('```verilog') + len('```verilog')
        end = text.find('```', start)
        if end != -1:
            return text[start:end].strip()
        # No closing fence after the tag: fall through to the generic scan.

    # Generic fenced block. (Fix: the original re-checked `start != -1`
    # here, which is unreachable-false once '```' is known to be present.)
    if '```' in text:
        start = text.find('```')
        start_marker = text.find('\n', start)
        if start_marker == -1:
            # Fence with no newline at all: content begins right after ```.
            start_marker = start + 3
        else:
            # Skip past the opening fence line (which may carry a tag).
            start_marker += 1

        end = text.find('```', start_marker)
        if end != -1:
            return text[start_marker:end].strip()

    # No extractable fenced block: return the raw text, trimmed.
    return text.strip()
47
def generate_with_chat_format(model, tokenizer, instruction, max_new_tokens=1000, temperature=0.1):
    """Generate a completion for an already chat-formatted *instruction*.

    The instruction is expected to already contain the chat template
    (``<s>[INST] ... [/INST]``), matching the training format where only the
    response + EOS is appended, so the model's generation is the assistant
    response. Only the newly generated text is returned, with a trailing
    EOS marker stripped if present.

    Args:
        model: loaded causal LM (HF ``PreTrainedModel``) already on a device.
        tokenizer: the tokenizer matching *model*.
        instruction: full prompt text, assumed to end with ``[/INST]``.
        max_new_tokens: generation budget for new tokens.
        temperature: sampling temperature; 0 (or negative) means greedy.

    Returns:
        str: the decoded response text (special tokens kept, EOS stripped).
    """
    inputs = tokenizer(
        instruction, return_tensors="pt", truncation=True, max_length=1536
    ).to(model.device)

    do_sample = temperature > 0
    # Bug fix: the original used `if tokenizer.pad_token_id`, which wrongly
    # falls back to EOS when the pad id is the valid-but-falsy value 0.
    pad_id = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )

    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        repetition_penalty=1.2,
        pad_token_id=pad_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    # Only pass sampling knobs when actually sampling: supplying
    # temperature/top_p alongside do_sample=False triggers transformers
    # warnings, and the original's `top_p=None` was meaningless.
    if do_sample:
        gen_kwargs["temperature"] = temperature
        gen_kwargs["top_p"] = 0.95

    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    # Decode only the newly generated tokens (everything past the prompt).
    input_length = inputs["input_ids"].shape[1]
    generated_ids = outputs[0][input_length:]
    generated_text = tokenizer.decode(generated_ids, skip_special_tokens=False)

    # Strip a trailing EOS marker so callers see clean response text.
    if generated_text.endswith(tokenizer.eos_token):
        generated_text = generated_text[: -len(tokenizer.eos_token)].rstrip()

    return generated_text
79
def analyze_code_quality(generated_text):
    """Heuristically classify *generated_text* as Verilog code (or not).

    Returns a dict of boolean signals plus an aggregate ``is_verilog`` that
    requires a module header, a matching ``endmodule``, and at least one
    common Verilog keyword.
    """
    lowered = generated_text.lower()
    # Bug fix: a plain substring test for "module" also matches inside
    # "endmodule", so any text containing only "endmodule" was counted as
    # having a module header. A word-boundary match avoids that, because
    # "endmodule" has no word boundary before its inner "module".
    has_module = re.search(r'\bmodule\b', lowered) is not None
    has_endmodule = "endmodule" in lowered
    has_verilog_keywords = any(
        kw in lowered
        for kw in ["input", "output", "reg", "wire", "assign", "always"]
    )
    has_code_blocks = "```" in generated_text

    return {
        "has_module": has_module,
        "has_endmodule": has_endmodule,
        "has_verilog_keywords": has_verilog_keywords,
        "has_code_blocks": has_code_blocks,
        "is_verilog": has_module and has_endmodule and has_verilog_keywords,
    }
94
def main():
    """Smoke-test the fine-tuned CodeLlama adapter on the first two
    training samples, printing the expected vs. generated Verilog and a
    heuristic quality analysis for each.

    All paths are resolved relative to this script's directory; output goes
    to stdout only (nothing is written to disk).
    """
    script_dir = Path(__file__).parent
    # Fine-tuned adapter output, base model weights, and the chat-format
    # training split this script samples from.
    model_path = script_dir / "training-outputs" / "codellama-fifo-v2-chat"
    base_model_path = script_dir / "models" / "base-models" / "CodeLlama-7B-Instruct"
    train_dataset = script_dir / "datasets" / "processed" / "split_chat_format" / "train.jsonl"

    print("=" * 80)
    print("🧪 TESTING NEW FINE-TUNED MODEL ON TRAINING SAMPLES")
    print("=" * 80)
    print(f"Model: {model_path}")
    print(f"Dataset: {train_dataset}")
    print("=" * 80)

    # Load two samples (one JSON object per line; blank lines skipped).
    samples = []
    with open(train_dataset, 'r') as f:
        for i, line in enumerate(f):
            if i >= 2:  # Get first 2 samples
                break
            if line.strip():
                samples.append(json.loads(line))

    if len(samples) < 2:
        print(f"❌ Error: Only found {len(samples)} samples in dataset")
        return

    # Load model (project helper; base path is optional and only passed
    # when the local base-model directory actually exists).
    print("\n📦 Loading model...")
    model, tokenizer = load_local_model(
        str(model_path),
        str(base_model_path) if base_model_path.exists() else None,
        use_quantization=None,
        merge_weights=False
    )
    print("✅ Model loaded!\n")

    # Test each sample
    for sample_idx, sample in enumerate(samples, 1):
        print("\n" + "=" * 80)
        print(f"📝 SAMPLE {sample_idx}")
        print("=" * 80)

        instruction = sample.get("instruction", "")
        expected_response = sample.get("response", "")
        expected_code = extract_code_from_response(expected_response)

        # Extract user message from instruction for display only: take the
        # text before [/INST], then the part after "Generate" if present,
        # otherwise a tail slice as a rough preview.
        if "[/INST]" in instruction:
            user_part = instruction.split("[/INST]")[0]
            user_part = user_part.split("Generate")[1] if "Generate" in user_part else user_part[-100:]
        else:
            user_part = instruction[-200:]

        print(f"\n📋 Task:")
        print("-" * 80)
        if "Generate" in user_part:
            print(user_part.split("Generate")[1].strip())
        else:
            print(user_part[-150:])
        print("-" * 80)

        # Show the reference answer, truncated for readability.
        print(f"\n🎯 Expected Response ({len(expected_response)} chars):")
        print("-" * 80)
        print(expected_code[:400] + "..." if len(expected_code) > 400 else expected_code)
        print("-" * 80)

        # Generate (near-greedy: temperature 0.1 keeps output mostly
        # deterministic for a reproducible smoke test).
        print(f"\n🤖 Generating with model...")
        generated_response = generate_with_chat_format(
            model,
            tokenizer,
            instruction,
            max_new_tokens=1000,
            temperature=0.1
        )

        generated_code = extract_code_from_response(generated_response)

        print("\n" + "=" * 80)
        print(f"✅ GENERATED OUTPUT:")
        print("=" * 80)
        print(generated_response[:1000] + "..." if len(generated_response) > 1000 else generated_response)
        print("=" * 80)

        # Analysis: heuristic signals from the raw (un-extracted) response.
        quality = analyze_code_quality(generated_response)

        print(f"\n📊 Analysis:")
        print(f" Response length: {len(generated_response)} chars")
        print(f" Extracted code length: {len(generated_code)} chars")
        print(f" Contains 'module': {quality['has_module']}")
        print(f" Contains 'endmodule': {quality['has_endmodule']}")
        print(f" Contains Verilog keywords: {quality['has_verilog_keywords']}")
        print(f" Contains code blocks: {quality['has_code_blocks']}")

        # Three-way verdict: full Verilog, partial, or not Verilog at all.
        if quality['is_verilog']:
            print(f" ✅ STATUS: Generated valid Verilog code!")
        elif quality['has_module']:
            print(f" ⚠️ STATUS: Partial Verilog code (missing endmodule or keywords)")
        else:
            print(f" ❌ STATUS: Not generating Verilog code")

        print("\n" + "-" * 80)

    print("\n" + "=" * 80)
    print("✅ TESTING COMPLETE")
    print("=" * 80)

if __name__ == "__main__":
    main()