othdu commited on
Commit
95d2da1
·
verified ·
1 Parent(s): e8002d0

Upload 5 files

Browse files
Files changed (2) hide show
  1. src/inference/model.py +236 -170
  2. src/training/finetune.py +287 -0
src/inference/model.py CHANGED
@@ -1,171 +1,237 @@
1
- import os
2
- import json
3
- import torch
4
- import logging
5
- from typing import Dict, Any, Optional
6
- from transformers import AutoModelForCausalLM, AutoTokenizer
7
- from peft import PeftModel
8
- import time
9
-
10
- logger = logging.getLogger(__name__)
11
-
12
- class AgriQAAssistant:
13
-
14
- def __init__(self, model_path: str = "nada013/agriqa-assistant"):
15
- self.model_path = model_path
16
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
17
- self.model = None
18
- self.tokenizer = None
19
- self.config = None
20
-
21
- self.load_model()
22
-
23
- def load_model(self):
24
-
25
- logger.info(f"Loading model from Hugging Face: {self.model_path}")
26
-
27
- try:
28
- # Configuration for the uploaded model
29
- self.config = {
30
- 'base_model': 'Qwen/Qwen1.5-1.8B-Chat',
31
- 'generation_config': {
32
- 'max_new_tokens': 512, # Increased for complete responses
33
- 'do_sample': True,
34
- 'temperature': 0.3, # Lower temperature for more consistent, structured responses
35
- 'top_p': 0.85, # Slightly lower for more focused sampling
36
- 'top_k': 40, # Lower for more focused responses
37
- 'repetition_penalty': 1.2, # Higher penalty to avoid repetition
38
- 'length_penalty': 1.1, # Encourage slightly longer, detailed responses
39
- 'no_repeat_ngram_size': 3 # Avoid repeating 3-grams
40
- }
41
- }
42
-
43
- # Load tokenizer from base model
44
- logger.info("Loading tokenizer from base model...")
45
- self.tokenizer = AutoTokenizer.from_pretrained(
46
- self.config['base_model'],
47
- trust_remote_code=True
48
- )
49
-
50
- if self.tokenizer.pad_token is None:
51
- self.tokenizer.pad_token = self.tokenizer.eos_token
52
-
53
- # Load base model first
54
- logger.info("Loading base model...")
55
- base_model = AutoModelForCausalLM.from_pretrained(
56
- self.config['base_model'],
57
- torch_dtype=torch.float16,
58
- device_map="auto",
59
- trust_remote_code=True,
60
-
61
- )
62
-
63
- # Load the LoRA adapter from Hugging Face
64
- logger.info("Loading LoRA adapter from Hugging Face...")
65
- self.model = PeftModel.from_pretrained(
66
- base_model,
67
- self.model_path,
68
- torch_dtype=torch.float16,
69
- device_map="auto",
70
-
71
- )
72
-
73
- # Set to evaluation mode
74
- self.model.eval()
75
-
76
- logger.info("Model loaded successfully from Hugging Face")
77
-
78
- except Exception as e:
79
- logger.error(f"Failed to load model: {e}")
80
- raise
81
-
82
- def format_prompt(self, question: str) -> str:
83
- """Format the question for the model using proper format."""
84
- # Use the tokenizer's chat template if available
85
- if hasattr(self.tokenizer, 'apply_chat_template'):
86
- try:
87
- messages = [
88
- {"role": "system", "content": "You are AgriQA, an agricultural expert assistant. Your job is to answer farmers' questions with clear, practical, and accurate steps they can directly apply in the field.\n\nWhen answering:\n1. Start with a short, direct answer to the question.\n2. Provide a numbered step-by-step solution.\n3. Include specific details like measurements, quantities, time intervals, and names of products or tools.\n4. Mention any safety precautions if needed.\n5. End with an extra tip or follow-up advice.\n\nFormat Example:\nQuestion: How to control aphid infestation in mustard crops?\nAnswer:\n1. Inspect the crop daily to detect early signs of infestation.\n2. Spray Imidacloprid 17.8% SL at a rate of 0.3 ml per liter of water.\n3. Ensure thorough coverage, especially under the leaves.\n4. Remove surrounding weeds that may host aphids.\n5. Repeat spraying after 7 days if infestation continues.\nNote: Wear gloves and a mask during spraying.\n\nAlways keep your language clear, concise, and easy to understand."},
89
- {"role": "user", "content": question}
90
- ]
91
- formatted_prompt = self.tokenizer.apply_chat_template(
92
- messages,
93
- tokenize=False,
94
- add_generation_prompt=True
95
- )
96
- return formatted_prompt
97
- except Exception as e:
98
- logger.warning(f"Failed to use chat template: {e}. Using fallback format.")
99
-
100
- # Fallback format for Qwen1.5-Chat
101
- system_prompt = "You are AgriQA, an agricultural expert assistant. Your job is to answer farmers' questions with clear, practical, and accurate steps they can directly apply in the field.\n\nWhen answering:\n1. Start with a short, direct answer to the question.\n2. Provide a numbered step-by-step solution.\n3. Include specific details like measurements, quantities, time intervals, and names of products or tools.\n4. Mention any safety precautions if needed.\n5. End with an extra tip or follow-up advice.\n\nFormat Example:\nQuestion: How to control aphid infestation in mustard crops?\nAnswer:\n1. Inspect the crop daily to detect early signs of infestation.\n2. Spray Imidacloprid 17.8% SL at a rate of 0.3 ml per liter of water.\n3. Ensure thorough coverage, especially under the leaves.\n4. Remove surrounding weeds that may host aphids.\n5. Repeat spraying after 7 days if infestation continues.\nNote: Wear gloves and a mask during spraying.\n\nAlways keep your language clear, concise, and easy to understand."
102
- formatted_prompt = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n"
103
- return formatted_prompt
104
-
105
- def generate_response(self, question: str, max_length: Optional[int] = None) -> Dict[str, Any]:
106
- start_time = time.time()
107
-
108
- try:
109
- # Format the prompt
110
- prompt = self.format_prompt(question)
111
-
112
- # Tokenize input
113
- inputs = self.tokenizer(
114
- prompt,
115
- return_tensors="pt",
116
- truncation=True,
117
- max_length=2048
118
- ).to(self.device)
119
-
120
- # Generation parameters
121
- gen_config = self.config['generation_config'].copy()
122
- if max_length:
123
- gen_config['max_new_tokens'] = max_length
124
-
125
- # Generate response
126
- with torch.no_grad():
127
- outputs = self.model.generate(
128
- **inputs,
129
- **gen_config,
130
- pad_token_id=self.tokenizer.eos_token_id
131
- )
132
-
133
- # Decode response
134
- response = self.tokenizer.decode(
135
- outputs[0][inputs['input_ids'].shape[1]:],
136
- skip_special_tokens=True
137
- ).strip()
138
-
139
- # Calculate response time
140
- response_time = time.time() - start_time
141
-
142
- return {
143
- 'answer': response,
144
- 'response_time': response_time,
145
- 'model_info': {
146
- 'model_name': 'agriqa-assistant',
147
- 'model_source': 'Hugging Face',
148
- 'model_path': self.model_path,
149
- 'base_model': self.config['base_model']
150
- }
151
- }
152
-
153
- except Exception as e:
154
- logger.error(f"Error generating response: {e}")
155
- return {
156
- 'answer': "I apologize, but I encountered an error while processing your question. Please try again.",
157
- 'confidence': 0.0,
158
- 'response_time': time.time() - start_time,
159
- 'error': str(e)
160
- }
161
-
162
- def get_model_info(self) -> Dict[str, Any]:
163
- """Get information about the loaded model."""
164
- return {
165
- 'model_name': 'agriqa-assistant',
166
- 'model_source': 'Hugging Face',
167
- 'model_path': self.model_path,
168
- 'base_model': self.config['base_model'],
169
- 'device': self.device,
170
- 'generation_config': self.config['generation_config']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  }
 
1
+ import os
2
+ import json
3
+ import torch
4
+ import logging
5
+ from typing import Dict, Any, Optional
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer
7
+ from peft import PeftModel
8
+ import time
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
class AgriQAAssistant:
    """Inference wrapper for the fine-tuned AgriQA agricultural Q&A model.

    Loads the Qwen1.5-1.8B-Chat base model together with the LoRA adapter
    published at ``model_path`` on the Hugging Face Hub, and exposes a simple
    question -> answer API that also reports timing and model metadata.

    NOTE(review): the previous commit left unresolved git merge-conflict
    markers inside ``load_model``; this version resolves them in favor of the
    more robust load strategy (direct load, then base + LoRA, then base-only).
    """

    # System prompt shared by the chat-template path and the manual fallback
    # format (previously duplicated verbatim in two places).
    _SYSTEM_PROMPT = "You are AgriQA, an agricultural expert assistant. Your job is to answer farmers' questions with clear, practical, and accurate steps they can directly apply in the field.\n\nWhen answering:\n1. Start with a short, direct answer to the question.\n2. Provide a numbered step-by-step solution.\n3. Include specific details like measurements, quantities, time intervals, and names of products or tools.\n4. Mention any safety precautions if needed.\n5. End with an extra tip or follow-up advice.\n\nFormat Example:\nQuestion: How to control aphid infestation in mustard crops?\nAnswer:\n1. Inspect the crop daily to detect early signs of infestation.\n2. Spray Imidacloprid 17.8% SL at a rate of 0.3 ml per liter of water.\n3. Ensure thorough coverage, especially under the leaves.\n4. Remove surrounding weeds that may host aphids.\n5. Repeat spraying after 7 days if infestation continues.\nNote: Wear gloves and a mask during spraying.\n\nAlways keep your language clear, concise, and easy to understand."

    def __init__(self, model_path: str = "nada013/agriqa-assistant"):
        """Initialize the assistant and eagerly load model + tokenizer.

        Args:
            model_path: Hugging Face Hub repo id of the fine-tuned model /
                LoRA adapter.
        """
        self.model_path = model_path
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        self.tokenizer = None
        self.config = None

        self.load_model()

    def load_model(self):
        """Load tokenizer and model weights from the Hugging Face Hub.

        Strategy, in order:
          1. Try loading ``self.model_path`` directly as a full causal-LM
             checkpoint.
          2. On failure, load the base model and attach the LoRA adapter.
          3. If the adapter also fails to load, fall back to the plain base
             model so the assistant stays usable (degraded quality).

        Raises:
            Exception: re-raises any error that prevents every strategy from
                producing a usable model.
        """
        logger.info(f"Loading model from Hugging Face: {self.model_path}")

        try:
            # Static configuration for the uploaded model.
            self.config = {
                'base_model': 'Qwen/Qwen1.5-1.8B-Chat',
                'generation_config': {
                    'max_new_tokens': 512,        # Increased for complete responses
                    'do_sample': True,
                    'temperature': 0.3,           # Lower temperature for more consistent, structured responses
                    'top_p': 0.85,                # Slightly lower for more focused sampling
                    'top_k': 40,                  # Lower for more focused responses
                    'repetition_penalty': 1.2,    # Higher penalty to avoid repetition
                    'length_penalty': 1.1,        # Encourage slightly longer, detailed responses
                    'no_repeat_ngram_size': 3     # Avoid repeating 3-grams
                }
            }

            # Tokenizer always comes from the base model.
            logger.info("Loading tokenizer from base model...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.config['base_model'],
                trust_remote_code=True
            )
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Strategy 1: the repo may be a full merged checkpoint.
            # NOTE: `use_flash_attention_2` was dropped — it is deprecated and
            # conflicts with `attn_implementation` in recent transformers.
            try:
                logger.info("Attempting to load model directly from Hugging Face...")
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_path,
                    torch_dtype=torch.float16,
                    device_map="auto",
                    trust_remote_code=True,
                    attn_implementation="eager",
                )
                logger.info("Model loaded directly from Hugging Face successfully")
            except Exception as direct_load_error:
                logger.info(f"Direct loading failed: {direct_load_error}")
                logger.info("Falling back to base model + LoRA adapter approach...")

                # Strategy 2: base model first...
                logger.info("Loading base model...")
                base_model = AutoModelForCausalLM.from_pretrained(
                    self.config['base_model'],
                    torch_dtype=torch.float16,
                    device_map="auto",
                    trust_remote_code=True,
                    attn_implementation="eager",
                )

                # ...then the LoRA adapter on top. Attention kwargs are not
                # forwarded here: they are not part of the PEFT API.
                try:
                    logger.info("Loading LoRA adapter from Hugging Face...")
                    self.model = PeftModel.from_pretrained(
                        base_model,
                        self.model_path,
                        torch_dtype=torch.float16,
                        device_map="auto",
                    )
                    logger.info("LoRA adapter loaded successfully")
                except Exception as lora_error:
                    # Strategy 3: keep serving with the un-adapted base model.
                    logger.warning(f"LoRA adapter loading failed: {lora_error}")
                    logger.info("Using base model without LoRA adapter...")
                    self.model = base_model

            # Inference only.
            self.model.eval()

            logger.info(f"Model loaded successfully from Hugging Face")
            logger.info(f"Model type: {type(self.model).__name__}")
            logger.info(f"Device: {self.device}")

            # If a LoRA adapter is attached, log its configuration for debugging.
            if hasattr(self.model, 'peft_config'):
                logger.info("LoRA adapter configuration:")
                for adapter_name, config in self.model.peft_config.items():
                    logger.info(f"  - {adapter_name}: {config.target_modules}")

        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            logger.error(f"Model path: {self.model_path}")
            logger.error(f"Base model: {self.config['base_model']}")
            import traceback
            logger.error(f"Traceback: {traceback.format_exc()}")
            raise

    def format_prompt(self, question: str) -> str:
        """Format the question for the model using the proper chat format.

        Prefers the tokenizer's built-in chat template; falls back to the
        hand-written Qwen1.5 `<|im_start|>` format if the template is missing
        or fails.
        """
        if hasattr(self.tokenizer, 'apply_chat_template'):
            try:
                messages = [
                    {"role": "system", "content": self._SYSTEM_PROMPT},
                    {"role": "user", "content": question}
                ]
                formatted_prompt = self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
                return formatted_prompt
            except Exception as e:
                logger.warning(f"Failed to use chat template: {e}. Using fallback format.")

        # Fallback format for Qwen1.5-Chat (ChatML-style markers).
        system_prompt = self._SYSTEM_PROMPT
        formatted_prompt = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n"
        return formatted_prompt

    def generate_response(self, question: str, max_length: Optional[int] = None) -> Dict[str, Any]:
        """Generate an answer for ``question``.

        Args:
            question: the farmer's question, plain text.
            max_length: optional override for ``max_new_tokens``.

        Returns:
            Dict with ``answer``, ``response_time`` and ``model_info`` on
            success; on failure a dict with an apology ``answer``,
            ``confidence`` 0.0 and the ``error`` string (never raises).
        """
        start_time = time.time()

        try:
            prompt = self.format_prompt(question)

            # Tokenize; truncate long prompts to the model's context budget.
            # NOTE(review): with device_map="auto" the model may be sharded;
            # moving inputs to self.device assumes the first shard lives there.
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=2048
            ).to(self.device)

            gen_config = self.config['generation_config'].copy()
            if max_length:
                gen_config['max_new_tokens'] = max_length

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    **gen_config,
                    pad_token_id=self.tokenizer.eos_token_id
                )

            # Decode only the newly generated tokens (skip the prompt echo).
            response = self.tokenizer.decode(
                outputs[0][inputs['input_ids'].shape[1]:],
                skip_special_tokens=True
            ).strip()

            response_time = time.time() - start_time

            return {
                'answer': response,
                'response_time': response_time,
                'model_info': {
                    'model_name': 'agriqa-assistant',
                    'model_source': 'Hugging Face',
                    'model_path': self.model_path,
                    'base_model': self.config['base_model']
                }
            }

        except Exception as e:
            logger.error(f"Error generating response: {e}")
            return {
                'answer': "I apologize, but I encountered an error while processing your question. Please try again.",
                'confidence': 0.0,
                'response_time': time.time() - start_time,
                'error': str(e)
            }

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the loaded model."""
        return {
            'model_name': 'agriqa-assistant',
            'model_source': 'Hugging Face',
            'model_path': self.model_path,
            'base_model': self.config['base_model'],
            'device': self.device,
            'generation_config': self.config['generation_config']
        }
src/training/finetune.py CHANGED
@@ -1,3 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import sys
3
  import yaml
@@ -282,4 +568,5 @@ def main():
282
  fine_tuner.run()
283
 
284
  if __name__ == "__main__":
 
285
  main()
 
1
# Defect fixed: a stray `<<<<<<< HEAD` merge-conflict marker was committed at
# the top of this file, making it unparseable. All original imports are kept;
# they are regrouped stdlib / third-party per PEP 8.
import argparse
import logging
import os
import sys
from typing import Dict, Any

import torch
import yaml
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback,
    BitsAndBytesConfig
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    TaskType
)
from datasets import Dataset
from tqdm import tqdm

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
31
class AgriQAFineTuner:
    """LoRA fine-tuning pipeline for the AgriQA model, driven by a YAML config.

    The config is expected to provide `model`, `hardware`, `lora`, `training`,
    `data`, `logging` and `generation` sections — confirm against
    configs/training_config.yaml.
    """

    def __init__(self, config_path: str):
        """Load the YAML config and prepare the runtime environment."""
        self.config = self.load_config(config_path) # load the config file
        self.setup_environment()

    def load_config(self, config_path: str) -> Dict[str, Any]:
        """Parse the YAML training configuration into a dict."""
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)
        return config

    def setup_environment(self) -> None:
        """Pick the compute device and create the training output directory."""
        # Set device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")

        # Create output directory
        os.makedirs(self.config['training']['output_dir'], exist_ok=True)

    def load_model_and_tokenizer(self):
        """Load the base model and tokenizer, optionally 4-bit quantized.

        Sets `self.tokenizer` and `self.model`. When `hardware.use_4bit` is
        truthy, loads with bitsandbytes NF4-style quantization and prepares
        the model for k-bit training.
        """
        logger.info(f"Loading model: {self.config['model']['base_model']}")

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config['model']['base_model'],
            trust_remote_code=self.config['model']['trust_remote_code']
        )

        # Add padding token if not present (Qwen-style tokenizers may lack one)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load model with quantization if specified
        if self.config['hardware']['use_4bit']:
            logger.info("Loading model with 4-bit quantization")
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_quant_type=self.config['hardware']['bnb_4bit_quant_type'],
                bnb_4bit_use_double_quant=self.config['hardware']['bnb_4bit_use_double_quant'],
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                self.config['model']['base_model'],
                quantization_config=quantization_config,
                device_map=self.config['hardware']['device_map'],
                trust_remote_code=self.config['model']['trust_remote_code']
            )
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                self.config['model']['base_model'],
                device_map=self.config['hardware']['device_map'],
                trust_remote_code=self.config['model']['trust_remote_code']
            )

        # Prepare model for k-bit training (required before attaching LoRA
        # to a quantized model)
        if self.config['hardware']['use_4bit']:
            self.model = prepare_model_for_kbit_training(self.model)

        logger.info("Model and tokenizer loaded successfully")

    def setup_lora(self):
        """Wrap `self.model` with LoRA adapters per the `lora` config section."""
        # Apply LoRA configuration
        logger.info("Setting up LoRA configuration")
        lora_config = LoraConfig(
            r=self.config['lora']['r'],
            lora_alpha=self.config['lora']['lora_alpha'],
            target_modules=self.config['lora']['target_modules'],
            lora_dropout=self.config['lora']['lora_dropout'],
            bias=self.config['lora']['bias'],
            task_type=self.config['lora']['task_type'],
        )

        # Enable gradient checkpointing for memory optimization
        if self.config['training']['gradient_checkpointing']:
            self.model.gradient_checkpointing_enable()
            logger.info("Gradient checkpointing enabled for memory optimization")

        # Apply LoRA
        self.model = get_peft_model(self.model, lora_config)
        self.model.print_trainable_parameters()

        logger.info("LoRA configuration applied successfully")

    def load_dataset(self):
        """Load the tokenized datasets.

        Reads pre-tokenized train/validation splits from
        `data.tokenized_dir`, optionally capped by `data.max_samples`
        (validation keeps 10% of that cap).

        Returns:
            (train_dataset, val_dataset) tuple of `datasets.Dataset`.
        """
        logger.info("Loading dataset")

        # Load pre-tokenized datasets
        logger.info("Loading pre-tokenized datasets...")
        train_dataset = Dataset.load_from_disk(os.path.join(self.config['data']['tokenized_dir'], "train"))
        val_dataset = Dataset.load_from_disk(os.path.join(self.config['data']['tokenized_dir'], "validation"))

        # Limit samples if specified
        # NOTE(review): max_samples < 10 yields an empty validation split here
        max_samples = self.config['data'].get('max_samples', None)
        if max_samples:
            logger.info(f"Limiting training samples to {max_samples}")
            train_dataset = train_dataset.select(range(min(max_samples, len(train_dataset))))
            val_dataset = val_dataset.select(range(min(max_samples // 10, len(val_dataset)))) # 10% for validation

        logger.info(f"Loaded tokenized training samples: {len(train_dataset)}")
        logger.info(f"Loaded tokenized validation samples: {len(val_dataset)}")

        return train_dataset, val_dataset

    def setup_training(self, train_dataset, val_dataset):
        """Build `TrainingArguments`, the data collator and the `Trainer`.

        Sets `self.trainer`. Numeric config values may arrive as strings from
        YAML, hence the `convert_numeric` shim.
        """
        logger.info("Setting up training configuration")

        # Convert numeric values from config (YAML scalars like "2e-5" can
        # parse as strings)
        def convert_numeric(value):
            if isinstance(value, str):
                try:
                    return float(value)
                except ValueError:
                    return value
            return value

        # Training arguments with memory optimizations
        training_args = TrainingArguments(
            output_dir=self.config['training']['output_dir'],
            num_train_epochs=convert_numeric(self.config['training']['num_train_epochs']),
            per_device_train_batch_size=convert_numeric(self.config['training']['per_device_train_batch_size']),
            per_device_eval_batch_size=convert_numeric(self.config['training']['per_device_eval_batch_size']),
            gradient_accumulation_steps=convert_numeric(self.config['training']['gradient_accumulation_steps']),
            learning_rate=convert_numeric(self.config['training']['learning_rate']),
            weight_decay=convert_numeric(self.config['training']['weight_decay']),
            warmup_steps=convert_numeric(self.config['training']['warmup_steps']),
            logging_steps=convert_numeric(self.config['training']['logging_steps']),
            save_steps=convert_numeric(self.config['training']['save_steps']),
            eval_steps=convert_numeric(self.config['training']['eval_steps']),
            # NOTE(review): `evaluation_strategy` was renamed `eval_strategy`
            # in newer transformers releases — confirm the pinned version
            evaluation_strategy=self.config['training']['evaluation_strategy'],
            save_strategy=self.config['training']['save_strategy'],
            save_total_limit=convert_numeric(self.config['training']['save_total_limit']),
            load_best_model_at_end=self.config['training']['load_best_model_at_end'],
            metric_for_best_model=self.config['training']['metric_for_best_model'],
            greater_is_better=self.config['training']['greater_is_better'],
            fp16=self.config['training']['fp16'],
            dataloader_num_workers=convert_numeric(self.config['training']['dataloader_num_workers']),
            gradient_checkpointing=self.config['training']['gradient_checkpointing'],
            max_grad_norm=convert_numeric(self.config['training']['max_grad_norm']),
            report_to=self.config['logging']['report_to'],
            run_name=self.config['logging']['run_name'],
            log_level=self.config['logging']['log_level'],
            # Memory optimization settings
            dataloader_drop_last=True,
            group_by_length=True,
            length_column_name="length",
            # Disable features that use more memory
            ddp_find_unused_parameters=False,
            dataloader_pin_memory=False,
            # Additional memory optimizations
            optim="adamw_torch_fused",  # Use fused optimizer for speed
            torch_compile=False,  # Disable torch.compile for memory
            use_cpu=False,  # Keep on GPU but optimize memory
            # Reduce memory fragmentation
            dataloader_persistent_workers=False,
        )

        # Data collator for pre-tokenized data (causal LM, no masked-LM)
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,
        )

        # Trainer with early stopping after 3 non-improving evals
        self.trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            data_collator=data_collator,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
        )

        logger.info("Training setup completed")

    def train(self):
        """Run training, then save the model, metrics and trainer state.

        Raises:
            Exception: re-raises any training failure after logging it.
        """
        logger.info("Starting training...")

        try:
            # Train the model
            train_result = self.trainer.train()

            # Save the final model
            self.trainer.save_model()

            # Save training metrics
            metrics = train_result.metrics
            self.trainer.log_metrics("train", metrics)
            self.trainer.save_metrics("train", metrics)
            self.trainer.save_state()

            logger.info("Training completed successfully!")
            logger.info(f"Training metrics: {metrics}")

        except Exception as e:
            logger.error(f"Training failed: {e}")
            raise

    def save_model(self):
        """Save the tokenizer and a JSON snapshot of the model configuration."""
        logger.info("Saving model...")

        output_dir = self.config['training']['output_dir']

        # Save tokenizer
        self.tokenizer.save_pretrained(output_dir)

        # Save model configuration (base model id, LoRA and generation params)
        model_config = {
            'base_model': self.config['model']['base_model'],
            'lora_config': self.config['lora'],
            'generation_config': self.config['generation']
        }

        config_path = os.path.join(output_dir, 'model_config.json')
        import json
        with open(config_path, 'w') as f:
            json.dump(model_config, f, indent=2)

        logger.info(f"Model saved to {output_dir}")

    def run(self):
        """Execute the full fine-tuning pipeline end to end."""
        logger.info("Starting agriQA fine-tuning pipeline...")

        # Load model and tokenizer
        self.load_model_and_tokenizer()

        # Setup LoRA
        self.setup_lora()

        # Load and prepare datasets
        train_dataset, val_dataset = self.load_dataset()

        # Setup training
        self.setup_training(train_dataset, val_dataset)

        # Train the model
        self.train()

        # Save the model
        self.save_model()

        logger.info("Fine-tuning pipeline completed successfully!")
274
def main():
    """CLI entry point: parse the --config path and run the fine-tuning pipeline.

    Defect fixed: the committed file had an unresolved merge-conflict tail
    (`=======`, a duplicated stale import block, `>>>>>>> 3b1d9d4...`) between
    the `__main__` guard and the `main()` call, leaving the guard body invalid.
    """
    parser = argparse.ArgumentParser(description="Fine-tune Qwen model on agriQA dataset")
    parser.add_argument("--config", type=str, default="configs/training_config.yaml",
                        help="Path to training configuration file")

    args = parser.parse_args()

    # Initialize and run fine-tuning
    fine_tuner = AgriQAFineTuner(args.config)
    fine_tuner.run()

if __name__ == "__main__":
    main()