<!-- This model card has been generated automatically according to the information the Trainer had access to. You
should probably proofread and complete it, then remove this comment. -->

# Qwen25_Coder_MultipleChoice

* This project distills YAML-based, structured multi-step reasoning from a GPT-4o teacher model into the smaller Qwen2.5 Coder 1.5B-Instruct LLM.

* This document provides guidance on getting started with `tuandunghcmut/Qwen25_Coder_MultipleChoice_v4`, a model fine-tuned for multiple-choice coding questions.

* The source repository for this project is at [https://github.com/tuandunghcmut/Small-Qwen-Coding-Multiple-Choice](https://github.com/tuandunghcmut/Small-Qwen-Coding-Multiple-Choice).

* The dataset used for training is available at [https://huggingface.co/datasets/tuandunghcmut/coding-mcq-reasoning](https://huggingface.co/datasets/tuandunghcmut/coding-mcq-reasoning).

<!-- /workspace/Small-Qwen-Coding-Multiple-Choice/notebooks/inference_examples.ipynb -->
* A demonstration notebook is available on Google Colab (click the badge below). Note that the training code has been omitted from this notebook; it is intended solely for testing and inference with the latest checkpoint.
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tuandunghcmut/Small-Qwen-Coding-Multiple-Choice/blob/main/notebooks/inference_examples.ipynb)

* The guide below explains the code presented in the notebook, which is a condensed version of the whole project.

## Installation

First, install the required dependencies:

```bash
# Install core dependencies
pip install transformers torch

# For faster inference (important)
pip install unsloth bitsandbytes

# Flash Attention (highly recommended for speed)
pip install flash-attn --no-build-isolation

# For dataset handling and YAML parsing
pip install datasets pyyaml
```

## Environment Setup

This project requires several API keys for authentication. Create a `.env` file in the root directory with the following variables:

```
# API Keys for authentication
OPENAI_API_KEY=your_openai_api_key_here
HF_TOKEN=your_huggingface_token_here
WANDB_API_KEY=your_wandb_api_key_here
```

You can copy the provided `.env.example` file and fill in your credentials:

```bash
cp .env.example .env
# Edit the .env file with your actual API keys
```

These environment variables are required for (a loading sketch follows the list):
- `HF_TOKEN`: Accessing Hugging Face models and datasets
- `WANDB_API_KEY`: Logging experiments to Weights & Biases
- `OPENAI_API_KEY`: Used if generating teacher completions with OpenAI models

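
If you load these variables from Python rather than exporting them in your shell, one option is the `python-dotenv` package. This is a minimal sketch, assuming `python-dotenv` is installed (`pip install python-dotenv`); the variable names match the `.env` file above:

```python
import os

from dotenv import load_dotenv  # assumes: pip install python-dotenv

# Read the .env file in the current working directory into os.environ
load_dotenv()

hf_token = os.environ.get("HF_TOKEN")
wandb_api_key = os.environ.get("WANDB_API_KEY")
openai_api_key = os.environ.get("OPENAI_API_KEY")

if hf_token is None:
    raise RuntimeError("HF_TOKEN is not set; check your .env file")
```
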
## Key Classes

The project provides several key classes for working with the model:

### 1. QwenModelHandler
```python
class QwenModelHandler:
    """Handler for Qwen models with inference and saving capabilities using Unsloth"""

    def __init__(self, model_name="unsloth/Qwen2.5-7B", max_seq_length=768,
                 quantization=None, device_map="auto", cache_dir=None):
        """
        Initialize model and tokenizer using Unsloth

        Args:
            model_name: Name or path of the model (preferably an unsloth model)
            max_seq_length: Maximum sequence length for the model
            quantization: Quantization type (None, '4bit', '8bit') - for compatibility
            device_map: Device mapping strategy
            cache_dir: Cache directory for models
        """
```

This class handles the core model operations (a short usage sketch follows the list):
- Model loading and initialization
- Text generation with streaming support
- Perplexity calculation
- Model saving and pushing to HuggingFace Hub

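
As a quick orientation, here is a minimal, hedged usage sketch of the handler. It assumes the classes from this guide are available in your session (for example by running the notebook cells) and uses the same checkpoint name as the examples later in this document:

```python
# Minimal sketch: load the fine-tuned checkpoint and generate one completion
handler = QwenModelHandler(
    model_name="tuandunghcmut/Qwen25_Coder_MultipleChoice",
    max_seq_length=2048,
    quantization="4bit",   # mapped to load_in_4bit=True for Unsloth
)

text = handler.generate_with_streaming(
    prompt="Explain what a Python list comprehension is.",
    temperature=0.0001,
    max_tokens=128,
    stream=False,          # return the full string instead of a streamer
)
print(text)
```
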
### 2. PromptCreator
```python
class PromptCreator:
    """Creates and formats prompts for multiple choice questions"""

    # Prompt types
    BASIC = "basic"              # Simple answer-only format
    YAML_REASONING = "yaml"      # YAML formatted reasoning
    TEACHER_REASONED = "teacher" # Same YAML format but using teacher completions
```

This class manages prompt creation with three modes (a usage sketch follows the list):
- Basic: Simple answer-only format
- YAML Reasoning: Structured reasoning in YAML format
- Teacher Reasoned: YAML format with teacher completions for training

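
For example, the sketch below builds a YAML-reasoning prompt for a small hand-written question (the question and choices are illustrative only):

```python
# Build a YAML-reasoning prompt for a sample question
creator = PromptCreator(PromptCreator.YAML_REASONING)

question = "Which keyword defines a function in Python?"
choices = ["func", "def", "lambda", "fn"]

prompt = creator.create_inference_prompt(question, choices)
print(prompt)  # contains the question, lettered choices A-D, and the YAML instructions
```
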
### 3. ResponseParser
```python
class ResponseParser:
    """Parser for model responses with support for different formats"""

    # Parser modes
    BASIC = "basic" # Extract single letter answer
    YAML = "yaml"   # Parse YAML formatted response with reasoning
```

This class handles response parsing (a usage sketch follows the list):
- Extracts answers from model responses
- Parses YAML-formatted reasoning
- Supports both basic and YAML formats

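
A small sketch of the parser on a hand-written YAML response (the sample text is illustrative; the `(answer, reasoning)` tuple mirrors the implementation shown later):

```python
# Parse a YAML-formatted model response into (answer, reasoning)
parser = ResponseParser(parser_mode=ResponseParser.YAML)

sample_response = """understanding: |
  The question asks which option opens a file for reading.
analysis: |
  Option A uses the built-in open() with mode 'r'.
reasoning: |
  open('file.txt', 'r') is the idiomatic way to open a file for reading.
conclusion: |
  Option A is correct.
answer: A
"""

answer, reasoning = parser.parse(sample_response)
print(answer)     # -> "A"
print(reasoning)  # combined understanding/reasoning/conclusion text
```
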
### 4. MultipleChoiceTester
```python
class MultipleChoiceTester:
    """Framework for testing Qwen models on multiple choice questions"""

    def __init__(self, model_handler, prompt_creator=None):
        """
        Initialize with model handler and prompt configuration

        Args:
            model_handler: The QwenModelHandler instance
            prompt_creator: Optional PromptCreator instance
        """
```

This class provides a complete testing framework (an evaluation sketch follows the list):
- Single example inference
- Batch processing
- Dataset evaluation
- Performance metrics tracking
- Results saving and visualization

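
As an illustration, here is a hedged sketch of evaluating part of the training dataset with the tester. It assumes `model_handler` was created as in the handler sketch above, and the split name is an assumption; check the dataset card at `tuandunghcmut/coding-mcq-reasoning` for the real split and column names:

```python
from datasets import load_dataset

# Hypothetical split name; inspect the dataset card for the actual one
dataset = load_dataset("tuandunghcmut/coding-mcq-reasoning", split="test")

tester = MultipleChoiceTester(
    model_handler,
    prompt_creator=PromptCreator(PromptCreator.YAML_REASONING),
)

summary = tester.evaluate_dataset(
    dataset,
    temperature=0.0001,
    num_examples=50,   # evaluate a small subset first
    batch_size=4,
    verbose=True,
)
print(f"Accuracy: {summary['accuracy']:.2%}")
```
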
+ ## Full Class Implementations
159
+
160
+ <details>
161
+ <summary>Click to expand/collapse full class implementations</summary>
162
+
163
+ ### 1. QwenModelHandler
164
+ ```python
165
+ class QwenModelHandler:
166
+ """Handler for Qwen models with inference and saving capabilities using Unsloth"""
167
+
168
+ def __init__(self, model_name="unsloth/Qwen2.5-7B", max_seq_length=768,
169
+ quantization=None, device_map="auto", cache_dir=None):
170
+ self.model_name = model_name
171
+ self.max_seq_length = max_seq_length
172
+ self.device_map = device_map
173
+ self.quantization = quantization
174
+ self.cache_dir = cache_dir
175
+
176
+ # Convert quantization parameter to load_in_4bit parameter for Unsloth
177
+ self.load_in_4bit = quantization == "4bit"
178
+
179
+ # Load tokenizer and model
180
+ self.tokenizer, self.model = self._load_model()
181
+ self.response_parser = ResponseParser()
182
+
183
+ def _load_model(self):
184
+ """Load model and tokenizer with Unsloth for optimization"""
185
+ from unsloth import FastLanguageModel
186
+ import torch
187
+
188
+ print(f"Loading {self.model_name} with Unsloth, max_seq_length={self.max_seq_length}")
189
+
190
+ # Set dtype based on hardware
191
+ dtype = None # None for auto detection
192
+
193
+ # Load model and tokenizer with Unsloth
194
+ model, tokenizer = FastLanguageModel.from_pretrained(
195
+ model_name=self.model_name,
196
+ max_seq_length=self.max_seq_length,
197
+ dtype=dtype,
198
+ load_in_4bit=self.load_in_4bit,
199
+ cache_dir=self.cache_dir,
200
+ )
201
+
202
+ return tokenizer, model
203
+
204
+ def generate_with_streaming(self, prompt, temperature=0.7, max_tokens=1024, stream=True):
205
+ """Generate completion with optional streaming using Unsloth's optimized inference"""
206
+ # Enable faster inference
207
+ from unsloth import FastLanguageModel
208
+ FastLanguageModel.for_inference(self.model)
209
+
210
+ # Format as chat
211
+ messages = [{"role": "user", "content": prompt}]
212
+ chat_text = self.tokenizer.apply_chat_template(
213
+ messages,
214
+ tokenize=False,
215
+ add_generation_prompt=True
216
+ )
217
+
218
+ # Tokenize input
219
+ model_inputs = self.tokenizer([chat_text], return_tensors="pt").to(self.model.device)
220
+
221
+ # Generate with streaming if requested
222
+ if stream:
223
+ from transformers import TextIteratorStreamer
224
+ import threading
225
+
226
+ # Set up streamer
227
+ streamer = TextIteratorStreamer(
228
+ self.tokenizer,
229
+ skip_prompt=True,
230
+ skip_special_tokens=True
231
+ )
232
+
233
+ # Start generation in a thread
234
+ generation_kwargs = {
235
+ "input_ids": model_inputs.input_ids,
236
+ "attention_mask": model_inputs.attention_mask,
237
+ "temperature": temperature,
238
+ "max_new_tokens": max_tokens,
239
+ "streamer": streamer,
240
+ "do_sample": temperature > 0.0,
241
+ "use_cache": True,
242
+ "min_p": 0.1 if temperature > 0.0 else None,
243
+ }
244
+
245
+ thread = threading.Thread(target=self.model.generate, kwargs=generation_kwargs)
246
+ thread.start()
247
+
248
+ return streamer
249
+ else:
250
+ # Generate without streaming
251
+ generated_ids = self.model.generate(
252
+ input_ids=model_inputs.input_ids,
253
+ attention_mask=model_inputs.attention_mask,
254
+ temperature=temperature,
255
+ max_new_tokens=max_tokens,
256
+ do_sample=temperature > 0.0,
257
+ use_cache=True,
258
+ min_p=0.1 if temperature > 0.0 else None,
259
+ )
260
+
261
+ # Decode the generated text
262
+ generated_text = self.tokenizer.decode(
263
+ generated_ids[0][model_inputs.input_ids.shape[1]:],
264
+ skip_special_tokens=True
265
+ )
266
+
267
+ return generated_text
268
+
269
+ def calculate_perplexity(self, prompt, answer, temperature=0.0):
270
+ """Calculate perplexity for a prompt and answer pair"""
271
+ import torch
272
+
273
+ # Format chat for perplexity calculation
274
+ messages = [
275
+ {"role": "user", "content": prompt},
276
+ {"role": "assistant", "content": answer}
277
+ ]
278
+ chat_text = self.tokenizer.apply_chat_template(
279
+ messages,
280
+ tokenize=False
281
+ )
282
+
283
+ # Tokenize the text
284
+ encodings = self.tokenizer(chat_text, return_tensors="pt").to(self.model.device)
285
+
286
+ # Calculate loss
287
+ with torch.no_grad():
288
+ outputs = self.model(**encodings, labels=encodings.input_ids)
289
+
290
+ # Get loss and calculate perplexity
291
+ neg_log_likelihood = outputs.loss.item()
292
+ perplexity = torch.exp(torch.tensor(neg_log_likelihood)).item()
293
+
294
+ return perplexity
295
+
296
+ def save_model(self, output_dir, save_method="lora"):
297
+ """Save model to disk using Unsloth's optimized methods"""
298
+ import os
299
+
300
+ os.makedirs(output_dir, exist_ok=True)
301
+
302
+ # Use Unsloth's saving methods
303
+ if save_method == "lora":
304
+ self.model.save_pretrained(output_dir)
305
+ self.tokenizer.save_pretrained(output_dir)
306
+ elif save_method == "merged_16bit":
307
+ self.model.save_pretrained_merged(output_dir, self.tokenizer, save_method="merged_16bit")
308
+ elif save_method == "merged_4bit":
309
+ self.model.save_pretrained_merged(output_dir, self.tokenizer, save_method="merged_4bit")
310
+ elif save_method == "gguf":
311
+ self.model.save_pretrained_gguf(output_dir, self.tokenizer, quantization_method="q4_k_m")
312
+ else:
313
+ raise ValueError(f"Unknown save method: {save_method}")
314
+
315
+ print(f"Model saved to {output_dir} using method {save_method}")
316
+ return output_dir
317
+
318
+ def push_to_hub(self, repo_id, token=None, save_method="lora", private=False):
319
+ """Push model to Hugging Face Hub using Unsloth's optimized methods"""
320
+ if save_method == "lora":
321
+ self.model.push_to_hub_merged(repo_id, self.tokenizer, save_method="lora", token=token)
322
+ elif save_method == "merged_16bit":
323
+ self.model.push_to_hub_merged(repo_id, self.tokenizer, save_method="merged_16bit", token=token)
324
+ elif save_method == "merged_4bit":
325
+ self.model.push_to_hub_merged(repo_id, self.tokenizer, save_method="merged_4bit", token=token)
326
+ elif save_method == "gguf":
327
+ self.model.push_to_hub_gguf(
328
+ repo_id,
329
+ self.tokenizer,
330
+ quantization_method=["q4_k_m", "q5_k_m"],
331
+ token=token
332
+ )
333
+ else:
334
+ raise ValueError(f"Unknown save method: {save_method}")
335
+
336
+ print(f"Model successfully pushed to: https://huggingface.co/{repo_id}")
337
+ return f"https://huggingface.co/{repo_id}"
338
+ ```
339
+
340
+ ### 2. PromptCreator
341
+ ```python
342
+ class PromptCreator:
343
+ """Creates and formats prompts for multiple choice questions"""
344
+
345
+ # Prompt types
346
+ BASIC = "basic" # Simple answer-only format
347
+ YAML_REASONING = "yaml" # YAML formatted reasoning
348
+ TEACHER_REASONED = "teacher" # Same YAML format but using teacher completions
349
+
350
+ def __init__(self, prompt_type=BASIC):
351
+ if prompt_type == self.TEACHER_REASONED:
352
+ prompt_type = self.YAML_REASONING
353
+ self.prompt_type = prompt_type
354
+ self.original_type = prompt_type
355
+
356
+ def format_choices(self, choices):
357
+ """Format choices as a lettered list"""
358
+ return "\n".join(
359
+ [f"{chr(65 + i)}. {choice}" for i, choice in enumerate(choices)]
360
+ )
361
+
362
+ def get_max_letter(self, choices):
363
+ """Get the maximum letter based on number of choices"""
364
+ return chr(65 + len(choices) - 1)
365
+
366
+ def create_inference_prompt(self, question, choices):
367
+ """Create a prompt for inference based on current prompt type"""
368
+ formatted_choices = self.format_choices(choices)
369
+ max_letter = self.get_max_letter(choices)
370
+
371
+ if self.prompt_type == self.YAML_REASONING:
372
+ return self._create_yaml_prompt(question, formatted_choices, max_letter)
373
+ else:
374
+ return self._create_basic_prompt(question, formatted_choices, max_letter)
375
+
376
+ def _create_basic_prompt(self, question, formatted_choices, max_letter):
377
+ """Create a basic prompt asking for just the answer letter"""
378
+ return f"""
379
+ QUESTION:
380
+ {question}
381
+
382
+ CHOICES:
383
+ {formatted_choices}
384
+
385
+ Answer with a single letter from A through {max_letter} without any additional explanation or commentary.
386
+ """
387
+
388
+ def _create_yaml_prompt(self, question, formatted_choices, max_letter):
389
+ """Create a prompt requesting YAML-formatted reasoning"""
390
+ return f"""
391
+ QUESTION:
392
+ {question}
393
+
394
+ CHOICES:
395
+ {formatted_choices}
396
+
397
+ Analyze this question step-by-step and provide a detailed explanation.
398
+ Your response MUST be in YAML format as follows:
399
+
400
+ understanding: |
401
+ <your understanding of what the question is asking>
402
+ analysis: |
403
+ <your analysis of each option>
404
+ reasoning: |
405
+ <your step-by-step reasoning process>
406
+ conclusion: |
407
+ <your final conclusion>
408
+ answer: <single letter A through {max_letter}>
409
+
410
+ The answer field MUST contain ONLY a single character letter.
411
+ """
412
+
413
+ def create_training_prompt(self, question, choices):
414
+ """Create a prompt for training with the current prompt type"""
415
+ formatted_choices = self.format_choices(choices)
416
+ max_letter = self.get_max_letter(choices)
417
+
418
+ if self.prompt_type == self.YAML_REASONING:
419
+ return self._create_yaml_training_prompt(
420
+ question, formatted_choices, max_letter
421
+ )
422
+ else:
423
+ return self._create_basic_training_prompt(
424
+ question, formatted_choices, max_letter
425
+ )
426
+
427
+ def _create_basic_training_prompt(self, question, formatted_choices, max_letter):
428
+ """Create a basic training prompt"""
429
+ return f"""
430
+ QUESTION:
431
+ {question}
432
+
433
+ CHOICES:
434
+ {formatted_choices}
435
+
436
+ The answer is a single letter (A, B, C, etc.). Only provide ONE character as your answer:
437
+ """
438
+
439
+ def _create_yaml_training_prompt(self, question, formatted_choices, max_letter):
440
+ """Create a YAML-formatted training prompt"""
441
+ return f"""
442
+ QUESTION:
443
+ {question}
444
+
445
+ CHOICES:
446
+ {formatted_choices}
447
+
448
+ Analyze this question step-by-step and provide a detailed explanation.
449
+ Follow the YAML format in your response:
450
+
451
+ understanding: |
452
+ <your understanding of the question>
453
+ analysis: |
454
+ <your analysis of each option>
455
+ reasoning: |
456
+ <your reasoning about the correct answer>
457
+ conclusion: |
458
+ <your final conclusion>
459
+ answer: <single letter A through {max_letter}>
460
+ """
461
+
462
+ def set_prompt_type(self, prompt_type):
463
+ """Set the prompt type"""
464
+ self.original_type = prompt_type
465
+ if prompt_type == self.TEACHER_REASONED:
466
+ pass
467
+ self.prompt_type = prompt_type
468
+ return self
469
+
470
+ def is_teacher_mode(self):
471
+ """Check if we're using teacher mode"""
472
+ return self.original_type == self.TEACHER_REASONED
473
+ ```
474
+
475
+ ### 3. ResponseParser
476
+ ```python
477
+ class ResponseParser:
478
+ """Parser for model responses with support for different formats"""
479
+
480
+ # Parser modes
481
+ BASIC = "basic" # Extract single letter answer
482
+ YAML = "yaml" # Parse YAML formatted response with reasoning
483
+
484
+ def __init__(self, parser_mode=BASIC):
485
+ self.parser_mode = parser_mode
486
+
487
+ def parse(self, response_text):
488
+ """Parse the model's response according to the current mode"""
489
+ if self.parser_mode == self.YAML:
490
+ return self._parse_yaml_response(response_text)
491
+ else:
492
+ return self._parse_basic_response(response_text)
493
+
494
+ def _parse_basic_response(self, response_text):
495
+ """Parse basic response looking for a letter answer"""
496
+ import re
497
+
498
+ # Try to extract a single letter answer (A-Z)
499
+ answer_match = re.search(r"(?:^|\s)([A-Z])(?:\s|$|\.)", response_text)
500
+ if answer_match:
501
+ answer = answer_match.group(1)
502
+ else:
503
+ # Take first character if it's a letter
504
+ if response_text and response_text[0].isalpha():
505
+ answer = response_text[0].upper()
506
+ else:
507
+ answer = None
508
+
509
+ # For basic mode, we don't extract detailed reasoning
510
+ reasoning = ""
511
+
512
+ return answer, reasoning
513
+
514
+ def _parse_yaml_response(self, response_text):
515
+ """Parse YAML formatted response extracting answer and reasoning"""
516
+ import re
517
+ import yaml
518
+
519
+ # First try to find answer in YAML format
520
+ yaml_match = re.search(r"answer:\s*([A-Z])", response_text)
521
+ if yaml_match:
522
+ answer = yaml_match.group(1)
523
+ else:
524
+ # Fall back to basic extraction if YAML parsing fails
525
+ answer_match = re.search(r"(?:^|\s)([A-Z])(?:\s|$|\.)", response_text)
526
+ if answer_match:
527
+ answer = answer_match.group(1)
528
+ elif response_text and response_text[0].isalpha():
529
+ answer = response_text[0].upper()
530
+ else:
531
+ answer = None
532
+
533
+ # Try to parse reasoning from YAML format
534
+ reasoning = ""
535
+ if "reasoning:" in response_text:
536
+ yaml_content = yaml.safe_load("---\n" + response_text)
537
+ if isinstance(yaml_content, dict) and "reasoning" in yaml_content:
538
+ reasoning = yaml_content["reasoning"]
539
+
540
+ # Add other YAML fields if available
541
+ if "understanding" in yaml_content:
542
+ reasoning = f"Understanding: {yaml_content['understanding']}\n\n{reasoning}"
543
+ if "conclusion" in yaml_content:
544
+ reasoning = f"{reasoning}\n\nConclusion: {yaml_content['conclusion']}"
545
+ else:
546
+ # Use the full response as reasoning if not in YAML format
547
+ reasoning = response_text
548
+
549
+ return answer, reasoning
550
+
551
+ def set_parser_mode(self, parser_mode):
552
+ """Set the parser mode"""
553
+ self.parser_mode = parser_mode
554
+ return self
555
+
556
+ @classmethod
557
+ def from_prompt_type(cls, prompt_type):
558
+ """Create a parser instance with mode matching the prompt type"""
559
+ if prompt_type == PromptCreator.YAML_REASONING or prompt_type == PromptCreator.TEACHER_REASONED:
560
+ return cls(parser_mode=cls.YAML)
561
+ else:
562
+ return cls(parser_mode=cls.BASIC)
563
+ ```
564
+
565
+ ### 4. MultipleChoiceTester
566
```python
# Module-level imports used by the methods below (imported inline or omitted in the original)
import os
import time
from datetime import datetime

import wandb  # optional; only needed when log_to_wandb=True

class MultipleChoiceTester:
568
+ """Framework for testing Qwen models on multiple choice questions"""
569
+
570
+ def __init__(self, model_handler, prompt_creator=None):
571
+ self.model_handler = model_handler
572
+ self.prompt_creator = prompt_creator or PromptCreator(PromptCreator.BASIC)
573
+ self.response_parser = ResponseParser.from_prompt_type(self.prompt_creator.prompt_type)
574
+
575
+ def infer_example(self, example, temperature=0.7, max_tokens=1024, prompt_type=None, stream=False):
576
+ """Inference on a single example for visualization/demonstration"""
577
+ # Allow temporary override of prompt type
578
+ original_prompt_type = None
579
+ if prompt_type is not None:
580
+ original_prompt_type = self.prompt_creator.prompt_type
581
+ self.prompt_creator.set_prompt_type(prompt_type)
582
+ self.response_parser = ResponseParser.from_prompt_type(prompt_type)
583
+
584
+ # Prepare data
585
+ question = example["question"]
586
+
587
+ # Handle different formats of choices
588
+ if isinstance(example["choices"], list):
589
+ choices = example["choices"]
590
+ elif isinstance(example["choices"], str) and example["choices"].startswith("["):
591
+ import ast
592
+ choices = ast.literal_eval(example["choices"]) if "[" in example["choices"] else example["choices"].split(",")
593
+ else:
594
+ choices = str(example["choices"]).split(",")
595
+
596
+ # Generate the prompt using prompt creator
597
+ prompt = self.prompt_creator.create_inference_prompt(question, choices)
598
+
599
+ # Start timing
600
+ start_time = time.time()
601
+
602
+ if stream:
603
+ # Use streaming generation
604
+ streamer = self.model_handler.generate_with_streaming(
605
+ prompt=prompt,
606
+ temperature=temperature,
607
+ max_tokens=max_tokens,
608
+ stream=True
609
+ )
610
+
611
+ # Collect output from streamer
612
+ raw_response = ""
613
+ print("Model response:")
614
+ for text_chunk in streamer:
615
+ print(text_chunk, end="", flush=True)
616
+ raw_response += text_chunk
617
+ print("\n")
618
+ else:
619
+ # Generate without streaming
620
+ raw_response = self.model_handler.generate_with_streaming(
621
+ prompt=prompt,
622
+ temperature=temperature,
623
+ max_tokens=max_tokens,
624
+ stream=False
625
+ )
626
+
627
+ response_time = time.time() - start_time
628
+
629
+ # Parse the response using the response parser
630
+ predicted_answer, reasoning = self.response_parser.parse(raw_response)
631
+
632
+ # Prepare results
633
+ result = {
634
+ "question": question,
635
+ "choices": choices,
636
+ "predicted_answer": predicted_answer,
637
+ "reasoning": reasoning,
638
+ "response_time": response_time,
639
+ "raw_response": raw_response,
640
+ "prompt_type": self.prompt_creator.prompt_type,
641
+ }
642
+
643
+ # Add task_id if available
644
+ if "task_id" in example:
645
+ result["task_id"] = example["task_id"]
646
+
647
+ # Calculate metrics if label is provided
648
+ if "answer" in example:
649
+ label = example["answer"]
650
+ result["correct_answer"] = label
651
+ result["is_correct"] = predicted_answer == label
652
+
653
+ # Calculate perplexity if requested
654
+ if hasattr(self.model_handler, "calculate_perplexity"):
655
+ perplexity = self.model_handler.calculate_perplexity(prompt, raw_response)
656
+ result["perplexity"] = perplexity
657
+
658
+ # Restore original prompt type if it was overridden
659
+ if original_prompt_type is not None:
660
+ self.prompt_creator.set_prompt_type(original_prompt_type)
661
+ self.response_parser = ResponseParser.from_prompt_type(original_prompt_type)
662
+
663
+ return result
664
+
665
+ def infer_batch(self, examples, temperature=0.7, max_tokens=1024, prompt_type=None, batch_size=4):
666
+ """Inference on a batch of examples"""
667
+ # Allow temporary override of prompt type
668
+ original_prompt_type = None
669
+ if prompt_type is not None:
670
+ original_prompt_type = self.prompt_creator.prompt_type
671
+ self.prompt_creator.set_prompt_type(prompt_type)
672
+ self.response_parser = ResponseParser.from_prompt_type(prompt_type)
673
+
674
+ # Prepare all prompts
675
+ prompts = []
676
+ metadata = []
677
 
678
+ for i, example in enumerate(examples):
679
+ # Extract data
680
+ question = example["question"]
681
 
682
+ # Handle different formats of choices
683
+ if isinstance(example["choices"], list):
684
+ choices = example["choices"]
685
+ elif isinstance(example["choices"], str) and example["choices"].startswith("["):
686
+ import ast
687
+ choices = ast.literal_eval(example["choices"]) if "[" in example["choices"] else example["choices"].split(",")
688
+ else:
689
+ choices = str(example["choices"]).split(",")
690
 
691
+ # Generate the prompt using prompt creator
692
+ prompt = self.prompt_creator.create_inference_prompt(question, choices)
693
+ prompts.append(prompt)
694
 
695
+ # Store metadata for later
696
+ meta = {
697
+ "question": question,
698
+ "choices": choices,
699
+ "index": i,
700
+ }
701
 
702
+ # Add label if available
703
+ if "answer" in example:
704
+ meta["label"] = example["answer"]
705
+
706
+ if "task_id" in example:
707
+ meta["task_id"] = example["task_id"]
708
+
709
+ metadata.append(meta)
710
+
711
+ # Process in batches
712
+ results = []
713
+ correct_count = 0
714
+ total_count = 0
715
+ perplexities = []
716
+
717
+ for i in range(0, len(prompts), batch_size):
718
+ batch_prompts = prompts[i:i+batch_size]
719
+ batch_meta = metadata[i:i+batch_size]
720
+
721
+ # Process batch
722
+ start_time = time.time()
723
+ batch_responses = []
724
+
725
+ for prompt in batch_prompts:
726
+ response = self.model_handler.generate_with_streaming(
727
+ prompt=prompt,
728
+ temperature=temperature,
729
+ max_tokens=max_tokens,
730
+ stream=False
731
+ )
732
+ batch_responses.append(response)
733
+
734
+ batch_time = time.time() - start_time
735
+
736
+ # Process each response in the batch
737
+ for j, (response, meta) in enumerate(zip(batch_responses, batch_meta)):
738
+ # Parse response
739
+ predicted_answer, reasoning = self.response_parser.parse(response)
740
+
741
+ # Create result
742
+ result = {
743
+ "question": meta["question"],
744
+ "choices": meta["choices"],
745
+ "predicted_answer": predicted_answer,
746
+ "reasoning": reasoning,
747
+ "raw_response": response,
748
+ "prompt_type": self.prompt_creator.prompt_type,
749
+ "response_time": batch_time / len(batch_prompts),
750
+ }
751
+
752
+ # Add task_id if available
753
+ if "task_id" in meta:
754
+ result["task_id"] = meta["task_id"]
755
+
756
+ # Add metrics if label available
757
+ if "label" in meta:
758
+ label = meta["label"]
759
+ result["correct_answer"] = label
760
+ result["is_correct"] = predicted_answer == label
761
+
762
+ # Update counts for accuracy
763
+ total_count += 1
764
+ if result["is_correct"]:
765
+ correct_count += 1
766
+
767
+ # Calculate perplexity if possible
768
+ if hasattr(self.model_handler, "calculate_perplexity"):
769
+ prompt = batch_prompts[j]
770
+ perplexity = self.model_handler.calculate_perplexity(prompt, response)
771
+ result["perplexity"] = perplexity
772
+ perplexities.append(perplexity)
773
+
774
+ results.append(result)
775
+
776
+ # Calculate aggregate metrics
777
+ summary_metrics = {}
778
+ if total_count > 0:
779
+ summary_metrics["accuracy"] = correct_count / total_count
780
+ summary_metrics["correct_count"] = correct_count
781
+ summary_metrics["total_count"] = total_count
782
+
783
+ if perplexities:
784
+ summary_metrics["avg_perplexity"] = sum(perplexities) / len(perplexities)
785
+ summary_metrics["min_perplexity"] = min(perplexities)
786
+ summary_metrics["max_perplexity"] = max(perplexities)
787
+
788
+ # Restore original prompt type if it was overridden
789
+ if original_prompt_type is not None:
790
+ self.prompt_creator.set_prompt_type(original_prompt_type)
791
+ self.response_parser = ResponseParser.from_prompt_type(original_prompt_type)
792
+
793
+ return results, summary_metrics
794
+
795
+ def evaluate_dataset(self, dataset, temperature=0.7, max_tokens=1024, num_examples=None,
796
+ verbose=True, prompt_type=None, batch_size=4, log_to_wandb=False):
797
+ """Inference on a whole dataset with metrics calculation"""
798
+ # Allow overriding the prompt type for this evaluation
799
+ original_prompt_type = self.prompt_creator.prompt_type
800
+ if prompt_type is not None:
801
+ self.prompt_creator.set_prompt_type(prompt_type)
802
+ self.response_parser = ResponseParser.from_prompt_type(prompt_type)
803
+
804
+ # Select subset if specified
805
+ if num_examples is not None:
806
+ dataset = dataset.select(range(min(num_examples, len(dataset))))
807
+
808
+ results = []
809
+ correct_count = 0
810
+ total_count = 0
811
+ perplexities = []
812
+
813
+ # Process examples in batches
814
+ for i in range(0, len(dataset), batch_size):
815
+ batch_examples = dataset[i:i+batch_size]
816
+
817
+ if verbose:
818
+ batch_desc = f"Batch {i//batch_size + 1}/{(len(dataset) + batch_size - 1) // batch_size}"
819
+ print(f"\nProcessing {batch_desc} with {len(batch_examples)} examples...")
820
+
821
+ # Infer batch
822
+ batch_results, batch_metrics = self.infer_batch(
823
+ examples=batch_examples,
824
+ temperature=temperature,
825
+ max_tokens=max_tokens,
826
+ batch_size=batch_size
827
+ )
828
+
829
+ # Update metrics
830
+ results.extend(batch_results)
831
+ if "correct_count" in batch_metrics:
832
+ correct_count += batch_metrics["correct_count"]
833
+ total_count += batch_metrics["total_count"]
834
+
835
+ if verbose:
836
+ batch_accuracy = batch_metrics["accuracy"]
837
+ overall_accuracy = correct_count / total_count
838
+ print(f"Batch accuracy: {batch_accuracy:.2%}, Overall: {overall_accuracy:.2%} ({correct_count}/{total_count})")
839
+
840
+ # Collect perplexities
841
+ if "avg_perplexity" in batch_metrics:
842
+ for result in batch_results:
843
+ if "perplexity" in result:
844
+ perplexities.append(result["perplexity"])
845
+
846
+ # Calculate final accuracy
847
+ accuracy = correct_count / total_count if total_count > 0 else 0.0
848
+
849
+ if verbose:
850
+ prompt_type_str = self.prompt_creator.prompt_type
851
+ print(f"\nFinal accuracy with {prompt_type_str} prompts: {accuracy:.2%} ({correct_count}/{total_count})")
852
+ if perplexities:
853
+ avg_perplexity = sum(perplexities) / len(perplexities)
854
+ print(f"Average perplexity: {avg_perplexity:.4f}")
855
+
856
+ # Prepare comprehensive summary
857
+ summary = {
858
+ "accuracy": accuracy,
859
+ "correct_count": correct_count,
860
+ "total_count": total_count,
861
+ "prompt_type": self.prompt_creator.prompt_type,
862
+ "results": results,
863
+ }
864
+
865
+ # Add perplexity metrics if available
866
+ if perplexities:
867
+ summary["avg_perplexity"] = sum(perplexities) / len(perplexities)
868
+ summary["min_perplexity"] = min(perplexities)
869
+ summary["max_perplexity"] = max(perplexities)
870
+
871
+ # Log results to wandb if requested
872
+ if log_to_wandb and wandb.run is not None:
873
+ metrics = {
874
+ "test/accuracy": accuracy,
875
+ "test/correct_count": correct_count,
876
+ "test/total_count": total_count,
877
+ }
878
+ if perplexities:
879
+ metrics["test/avg_perplexity"] = summary["avg_perplexity"]
880
+ metrics["test/min_perplexity"] = summary["min_perplexity"]
881
+ metrics["test/max_perplexity"] = summary["max_perplexity"]
882
+
883
+ wandb.log(metrics)
884
+
885
+ # Create a table of results for visualization if task_id exists
886
+ if "task_id" in dataset.features:
887
+ columns = ["task_id", "question", "correct_answer", "predicted_answer", "is_correct"]
888
+ table = wandb.Table(columns=columns)
889
+
890
+ for res in results[:min(100, len(results))]:
891
+ table.add_data(
892
+ res.get("task_id", "unknown"),
893
+ res["question"][:100] + "...",
894
+ res.get("correct_answer", ""),
895
+ res.get("predicted_answer", ""),
896
+ res.get("is_correct", False)
897
+ )
898
+
899
+ wandb.log({"test_samples": table})
900
+
901
+ # Restore original prompt type
902
+ self.prompt_creator.set_prompt_type(original_prompt_type)
903
+ self.response_parser = ResponseParser.from_prompt_type(original_prompt_type)
904
+
905
+ return summary
906
+
907
+ def save_results(self, results, output_dir="./results"):
908
+ """Save evaluation results to file"""
909
+ os.makedirs(output_dir, exist_ok=True)
910
+
911
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
912
+ results_file = os.path.join(output_dir, f"results_{timestamp}.json")
913
+
914
+ # Create serializable results
915
+ serializable_results = {
916
+ "accuracy": results.get("accuracy", 0.0),
917
+ "correct_count": results.get("correct_count", 0),
918
+ "total_count": results.get("total_count", 0),
919
+ "timestamp": timestamp,
920
+ "prompt_type": results.get("prompt_type", "unknown"),
921
+ }
922
+
923
+ # Add perplexity metrics if available
924
+ if "avg_perplexity" in results:
925
+ serializable_results["avg_perplexity"] = results["avg_perplexity"]
926
+ serializable_results["min_perplexity"] = results["min_perplexity"]
927
+ serializable_results["max_perplexity"] = results["max_perplexity"]
928
+
929
+ # Process individual results
930
+ serializable_results["individual_results"] = []
931
+ for result in results["results"]:
932
+ # Skip perplexity in individual results to save space
933
+ result_copy = result.copy()
934
+ if "perplexity" in result_copy:
935
+ del result_copy["perplexity"]
936
+
937
+ # Convert choices if needed
938
+ choices = result_copy["choices"]
939
+ if not isinstance(choices, list):
940
+ try:
941
+ import ast
942
+ result_copy["choices"] = ast.literal_eval(choices)
943
+ except (SyntaxError, ValueError):
944
+ pass
945
+
946
+ serializable_results["individual_results"].append(result_copy)
947
+
948
+ # Save to file
949
+ with open(results_file, "w") as f:
950
+ import json
951
+ json.dump(serializable_results, f, indent=2)
952
+
953
+ print(f"Results saved to {results_file}")
954
+ return results_file
955
+ ```
956
+
957
+ </details>
958
+
959
## Quick Start

Here's a simple example of how to use the model:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load the model and tokenizer
model_id = "tuandunghcmut/Qwen25_Coder_MultipleChoice"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True
)

# Example question
question = "What is the correct way to open a file in Python for reading?"
choices = [
    "open('file.txt', 'r')",
    "file.open('file.txt', 'read')",
    "read('file.txt')",
    "File.open('file.txt')"
]

# Format the choices as a lettered list, then build the prompt
formatted_choices = "\n".join(f"{chr(65 + i)}. {choice}" for i, choice in enumerate(choices))

prompt = f"""
QUESTION:
{question}

CHOICES:
{formatted_choices}

Answer with a single letter from A through {chr(65 + len(choices) - 1)} without any additional explanation or commentary.
"""

# Generate response
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=10)

# Decode only the newly generated tokens (exclude the prompt)
response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)

print(f"Model's answer: {response}")
```

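
Note that the project's own `QwenModelHandler` (shown above) wraps prompts with the tokenizer's chat template before generating. If you want the plain-`transformers` quick start to match that behaviour, a minimal variant of the generation step looks like this (same `tokenizer`, `model`, and `prompt` as above):

```python
# Wrap the prompt with the chat template, mirroring QwenModelHandler
messages = [{"role": "user", "content": prompt}]
chat_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

inputs = tokenizer([chat_text], return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=10)

answer = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(f"Model's answer: {answer}")
```
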
1006
+ ## Advanced Usage
1007
+
1008
+ ### Using the MultipleChoiceTester Framework
1009
+
1010
+ For more advanced usage, you can use the provided `MultipleChoiceTester` framework:
1011
+
1012
+ ```python
1013
+ from save import QwenModelHandler, MultipleChoiceTester, PromptCreator
1014
+
1015
+ # Initialize the model handler
1016
+ model_handler = QwenModelHandler(
1017
+ model_name="tuandunghcmut/Qwen25_Coder_MultipleChoice",
1018
+ max_seq_length=2048,
1019
+ quantization="4bit",
1020
+ device_map="auto"
1021
+ )
1022
+
1023
+ # Create a prompt creator with YAML reasoning format
1024
+ prompt_creator = PromptCreator(PromptCreator.YAML_REASONING)
1025
+
1026
+ # Initialize the tester
1027
+ tester = MultipleChoiceTester(model_handler, prompt_creator=prompt_creator)
1028
+
1029
+ # Example question
1030
+ example = {
1031
+ "question": "What is the correct way to open a file in Python for reading?",
1032
+ "choices": [
1033
+ "open('file.txt', 'r')",
1034
+ "file.open('file.txt', 'read')",
1035
+ "read('file.txt')",
1036
+ "File.open('file.txt')"
1037
+ ],
1038
+ "answer": "A" # Optional ground truth
1039
+ }
1040
+
1041
+ # Get prediction with reasoning
1042
+ result = tester.infer_example(example, temperature=0.0001, stream=True)
1043
+ print(f"Predicted answer: {result['predicted_answer']}")
1044
+ print("Reasoning:")
1045
+ print(result['reasoning'])
1046
+ ```
1047
+
1048
+ ### Batch Processing
1049
+
1050
+ You can also process multiple questions in batches:
1051
+
1052
+ ```python
1053
+ # List of examples
1054
+ examples = [
1055
+ {
1056
+ "question": "What is the correct way to open a file in Python for reading?",
1057
+ "choices": ["open('file.txt', 'r')", "file.open('file.txt', 'read')", "read('file.txt')", "File.open('file.txt')"],
1058
+ "answer": "A"
1059
+ },
1060
+ # Add more examples...
1061
+ ]
1062
+
1063
+ # Process batch
1064
+ results, metrics = tester.infer_batch(examples, batch_size=4)
1065
+ print(f"Batch accuracy: {metrics['accuracy']:.2%}")
1066
+ ```
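
To persist batch results, the tester's `save_results` helper (see the full implementation above) expects a summary dictionary that contains a `results` list, such as the one returned by `evaluate_dataset`. A small sketch that adapts the `infer_batch` output accordingly:

```python
# Build a summary dict in the shape save_results expects, then write it to disk
summary = {
    **metrics,                                   # accuracy, counts, perplexity stats
    "results": results,                          # per-example records from infer_batch
    "prompt_type": tester.prompt_creator.prompt_type,
}
results_file = tester.save_results(summary, output_dir="./results")
print(f"Results written to {results_file}")
```
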
1067
+
1068
+ ### Streaming Inference
1069
+
1070
+ The model supports streaming inference, which provides real-time output as the model generates its response. This is particularly useful for interactive applications and when you want to see the reasoning process in real-time.
1071
+
1072
+ #### Basic Streaming Usage
1073
+
1074
+ Here's how to use streaming inference:
1075
+
1076
+ ```python
1077
+ # Initialize model handler and tester as before
1078
+ model_handler = QwenModelHandler(
1079
+ model_name="tuandunghcmut/Qwen25_Coder_MultipleChoice",
1080
+ max_seq_length=2048
1081
+ )
1082
+ tester = MultipleChoiceTester(model_handler)
1083
+
1084
+ # Example with streaming
1085
+ example = {
1086
+ "question": "Which Python method is used to remove whitespace from both ends of a string?",
1087
+ "choices": [
1088
+ "strip()",
1089
+ "trim()",
1090
+ "clean()",
1091
+ "remove_whitespace()"
1092
+ ],
1093
+ "answer": "A"
1094
+ }
1095
+
1096
+ # Enable streaming with stream=True
1097
+ result = tester.infer_example(
1098
+ example,
1099
+ temperature=0.0001,
1100
+ max_tokens=1024,
1101
+ stream=True # Enable streaming
1102
+ )
1103
+
1104
+ # The output will be printed in real-time as the model generates it
1105
+ # You can also access the complete response after generation
1106
+ print("\nFinal result:")
1107
+ print(f"Predicted answer: {result['predicted_answer']}")
1108
+ print("Complete reasoning:")
1109
+ print(result['reasoning'])
1110
+ ```
1111
+
1112
+ #### Advanced Streaming Patterns
1113
+
1114
+ ##### 1. Custom Stream Processing
1115
+
1116
+ You can process the streamed output in custom ways:
1117
+
1118
+ ```python
1119
+ def process_stream(streamer):
1120
+ """Custom stream processing function"""
1121
+ collected_text = ""
1122
+ for chunk in streamer:
1123
+ # Process each chunk as it arrives
1124
+ collected_text += chunk
1125
+ # You can do custom processing here
1126
+ # For example, parse partial YAML, update UI, etc.
1127
+ yield chunk, collected_text
1128
+
1129
# infer_example() returns parsed results, not a raw stream; to consume chunks
# yourself, request a streamer directly from the model handler
prompt = tester.prompt_creator.create_inference_prompt(example["question"], example["choices"])
streamer = model_handler.generate_with_streaming(
    prompt=prompt,
    temperature=0.0001,
    stream=True
)

# Process the stream with custom logic
for chunk, full_text in process_stream(streamer):
    # Do something with each chunk
    print(f"Chunk: {chunk}")
    print(f"Full text so far: {full_text}")
1141
+ ```
1142
+
1143
+ ##### 2. YAML Streaming with Real-time Parsing
1144
+
1145
+ When using YAML reasoning format, you can parse the output as it streams:
1146
+
1147
+ ```python
1148
+ import yaml
1149
+ from io import StringIO
1150
+
1151
+ def parse_yaml_stream(streamer):
1152
+ """Parse YAML content as it streams"""
1153
+ buffer = StringIO()
1154
+ for chunk in streamer:
1155
+ buffer.write(chunk)
1156
+ try:
1157
+ # Try to parse the current buffer as YAML
1158
+ yaml_content = yaml.safe_load(buffer.getvalue())
1159
+ if yaml_content:
1160
+ yield chunk, yaml_content
1161
+ except yaml.YAMLError:
1162
+ # Not enough content for valid YAML yet
1163
+ continue
1164
+
1165
# Build a YAML-reasoning prompt, then stream it directly from the model handler
yaml_creator = PromptCreator(PromptCreator.YAML_REASONING)
prompt = yaml_creator.create_inference_prompt(example["question"], example["choices"])
streamer = model_handler.generate_with_streaming(
    prompt=prompt,
    temperature=0.0001,
    stream=True
)

# Process YAML content as it streams
for chunk, yaml_content in parse_yaml_stream(streamer):
    if isinstance(yaml_content, dict):
        # Access YAML fields as they become available
        if 'understanding' in yaml_content:
            print(f"Understanding: {yaml_content['understanding']}")
        if 'reasoning' in yaml_content:
            print(f"Reasoning: {yaml_content['reasoning']}")
        if 'answer' in yaml_content:
            print(f"Answer: {yaml_content['answer']}")
1183
+ ```
1184
+
1185
+ ##### 3. Streaming with Progress Tracking
1186
+
1187
+ You can track generation progress and timing:
1188
+
1189
+ ```python
1190
+ import time
1191
+
1192
+ def stream_with_progress(streamer):
1193
+ """Stream with progress tracking"""
1194
+ start_time = time.time()
1195
+ tokens_generated = 0
1196
+
1197
+ for chunk in streamer:
1198
+ tokens_generated += len(chunk.split())
1199
+ elapsed = time.time() - start_time
1200
+ tokens_per_second = tokens_generated / elapsed if elapsed > 0 else 0
1201
+
1202
+ yield {
1203
+ 'chunk': chunk,
1204
+ 'tokens': tokens_generated,
1205
+ 'tokens_per_second': tokens_per_second,
1206
+ 'elapsed': elapsed
1207
+ }
1208
+
1209
# Use streaming with progress tracking; get a streamer directly from the model handler
prompt = tester.prompt_creator.create_inference_prompt(example["question"], example["choices"])
streamer = model_handler.generate_with_streaming(
    prompt=prompt,
    temperature=0.0001,
    stream=True
)

for progress in stream_with_progress(streamer):
    print(f"Generated {progress['tokens']} tokens "
          f"({progress['tokens_per_second']:.2f} tokens/sec)")
    print(f"Chunk: {progress['chunk']}")
1220
+ ```
1221
+
1222
+ #### Implementation Details
1223
+
1224
+ The streaming implementation uses Unsloth's optimized inference with the following key features:
1225
+
1226
+ 1. **Efficient Token Generation**
1227
+ - Uses Unsloth's `FastLanguageModel` for optimized inference
1228
+ - Implements streaming using `TextIteratorStreamer`
1229
+ - Supports both greedy and temperature-based sampling
1230
+
1231
+ 2. **Memory Management**
1232
+ - Streams tokens without storing the entire response in memory
1233
+ - Efficiently handles long responses
1234
+ - Supports batch processing with streaming
1235
+
1236
+ 3. **Performance Optimizations**
1237
+ - Uses `use_cache=True` for faster generation
1238
+ - Implements `min_p` sampling for better quality
1239
+ - Supports 4-bit quantization for reduced memory usage
1240
+
1241
+ 4. **Error Handling**
1242
+ - Gracefully handles streaming interruptions
1243
+ - Provides partial results if generation is interrupted
1244
+ - Maintains context for resumed generation
1245
+
1246
+ The streaming output will show the model's reasoning process in real-time, including:
1247
+ - Understanding of the question
1248
+ - Analysis of each option
1249
+ - Step-by-step reasoning
1250
+ - Final conclusion
1251
+ - Answer selection
1252
+
1253
+ This is particularly useful for:
1254
+ - Debugging model behavior
1255
+ - Creating interactive demos
1256
+ - Understanding the model's reasoning process
1257
+ - Providing immediate feedback to users
1258
+ - Building real-time applications
1259
+
1260
+ ## Model Features
1261
+
1262
+ - **YAML-Based Reasoning**: The model provides structured reasoning in YAML format
1263
+ - **Multiple Prompt Types**: Supports both basic and YAML-formatted reasoning prompts
1264
+ - **Batch Processing**: Efficiently process multiple questions at once
1265
+ - **Performance Metrics**: Tracks accuracy, perplexity, and response times
1266
+ - **Streaming Support**: Real-time output streaming for interactive use
1267
+
1268
+ ## Examples
1269
+
1270
Check out the [example notebook](https://colab.research.google.com/github/tuandunghcmut/Small-Qwen-Coding-Multiple-Choice/blob/main/notebooks/inference_examples.ipynb) for more detailed usage examples and demonstrations.
1271
+
1272
+ ## Contributing
1273
+
1274
+ Contributions are welcome! Please feel free to submit a Pull Request.
1275
+
1276
+ ## License
1277
+
1278
+ This project is licensed under the MIT License - see the LICENSE file for details.

# Small-Qwen-Coding-Multiple-Choice

# Qwen25_Coder_MultipleChoice_v4 (validation card created automatically during training)

This model is a fine-tuned version of [unsloth/qwen2.5-coder-1.5b-instruct-bnb-4bit](https://huggingface.co/unsloth/qwen2.5-coder-1.5b-instruct-bnb-4bit).
Weights & Biases run for experiment tracking and reproducing results: [wandb link](https://wandb.ai/tuandung/Qwen2.5-Coder-1.5B-Instruct-LoRA-Training/runs/k00n15pk?nw=nwuserdungvo20csehcmut)
It achieves the following results on the evaluation set:
- eval_loss: 0.5559
- eval_runtime: 23.4917
- eval_samples_per_second: 3.831
- eval_steps_per_second: 0.511
- epoch: 2.7687
- step: 390

### Training hyperparameters