Nada committed on
Commit 64df380 · 1 Parent(s): fb5028e
Files changed (1)
  1. src/inference/model.py +112 -157
src/inference/model.py CHANGED
@@ -1,23 +1,15 @@
 import os
 import json
-import logging
-
-# Set PyTorch environment variables before import to prevent logging issues
-os.environ.setdefault('TORCH_LOGS', 'torch')
-os.environ.setdefault('TORCH_SHOW_CPP_STACKTRACES', '0')
-os.environ.setdefault('TORCH_USE_CUDA_DSA', '0')
-
 import torch
+import logging
 from typing import Dict, Any, Optional
-
-# Set transformers environment variables
-os.environ.setdefault('TRANSFORMERS_VERBOSITY', 'error')
-os.environ.setdefault('TOKENIZERS_PARALLELISM', 'false')
-
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from peft import PeftModel
 import time
 
+# Fix PyTorch logging issue
+os.environ['TORCH_LOGS'] = 'torch'
+
 logger = logging.getLogger(__name__)
 
 class AgriQAAssistant:
@@ -32,99 +24,85 @@ class AgriQAAssistant:
         self.load_model()
 
     def load_model(self):
-
         logger.info(f"Loading model from Hugging Face: {self.model_path}")
 
-        # Set additional environment variables for model loading
-        os.environ.setdefault('HF_HUB_OFFLINE', 'false')
-        os.environ.setdefault('HF_HUB_DISABLE_TELEMETRY', '1')
-
-        try:
-            # Configuration for the uploaded model
-            self.config = {
-                'base_model': 'Qwen/Qwen1.5-1.8B-Chat',
-                'generation_config': {
-                    'max_new_tokens': 512,  # Increased for complete responses
-                    'do_sample': True,
-                    'temperature': 0.3,  # Lower temperature for more consistent, structured responses
-                    'top_p': 0.85,  # Slightly lower for more focused sampling
-                    'top_k': 40,  # Lower for more focused responses
-                    'repetition_penalty': 1.2,  # Higher penalty to avoid repetition
-                    'length_penalty': 1.1,  # Encourage slightly longer, detailed responses
-                    'no_repeat_ngram_size': 3  # Avoid repeating 3-grams
-                }
-            }
-
-            # Load tokenizer from base model
-            logger.info("Loading tokenizer from base model...")
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                self.config['base_model'],
-                trust_remote_code=True
-            )
-
-            if self.tokenizer.pad_token is None:
-                self.tokenizer.pad_token = self.tokenizer.eos_token
-
-            # Try to load the model directly from Hugging Face first
-            try:
-                logger.info("Attempting to load model directly from Hugging Face...")
-                self.model = AutoModelForCausalLM.from_pretrained(
-                    self.model_path,
-                    torch_dtype=torch.float16,
-                    device_map="auto",
-                    trust_remote_code=True,
-                    attn_implementation="eager",
-                    use_flash_attention_2=False
-                )
-                logger.info("Model loaded directly from Hugging Face successfully")
-            except Exception as direct_load_error:
-                logger.info(f"Direct loading failed: {direct_load_error}")
-                logger.info("Falling back to base model + LoRA adapter approach...")
-
-                # Load base model first
-                logger.info("Loading base model...")
-                base_model = AutoModelForCausalLM.from_pretrained(
-                    self.config['base_model'],
-                    torch_dtype=torch.float16,
-                    device_map="auto"
-                )
-
-                # Try to load the LoRA adapter
-                try:
-                    logger.info("Loading LoRA adapter from Hugging Face...")
-                    self.model = PeftModel.from_pretrained(
-                        base_model,
-                        self.model_path,
-                        torch_dtype=torch.float16,
-                        device_map="auto"
-                    )
-                    logger.info("LoRA adapter loaded successfully")
-                except Exception as lora_error:
-                    logger.warning(f"LoRA adapter loading failed: {lora_error}")
-                    logger.info("Using base model without LoRA adapter...")
-                    self.model = base_model
-
-            # Set to evaluation mode
-            self.model.eval()
-
-            # Log model information
-            logger.info(f"Model loaded successfully from Hugging Face")
-            logger.info(f"Model type: {type(self.model).__name__}")
-            logger.info(f"Device: {self.device}")
-
-            # Check if it's a PeftModel
-            if hasattr(self.model, 'peft_config'):
-                logger.info("LoRA adapter configuration:")
-                for adapter_name, config in self.model.peft_config.items():
-                    logger.info(f" - {adapter_name}: {config.target_modules}")
-
-        except Exception as e:
-            logger.error(f"Failed to load model: {e}")
-            logger.error(f"Model path: {self.model_path}")
-            logger.error(f"Base model: {self.config['base_model']}")
-            import traceback
-            logger.error(f"Traceback: {traceback.format_exc()}")
-            raise
+        # Configuration for the uploaded model
+        self.config = {
+            'base_model': 'Qwen/Qwen1.5-1.8B-Chat',
+            'generation_config': {
+                'max_new_tokens': 512,
+                'do_sample': True,
+                'temperature': 0.3,
+                'top_p': 0.85,
+                'top_k': 40,
+                'repetition_penalty': 1.2,
+                'length_penalty': 1.1,
+                'no_repeat_ngram_size': 3
+            }
+        }
+
+        # Load tokenizer from base model
+        logger.info("Loading tokenizer from base model...")
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.config['base_model'],
+            trust_remote_code=True
+        )
+
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
+        # Try to load the model directly from Hugging Face first
+        try:
+            logger.info("Attempting to load model directly from Hugging Face...")
+            self.model = AutoModelForCausalLM.from_pretrained(
+                self.model_path,
+                torch_dtype=torch.float16,
+                device_map="auto",
+                trust_remote_code=True,
+                attn_implementation="eager",
+                use_flash_attention_2=False
+            )
+            logger.info("Model loaded directly from Hugging Face successfully")
+        except Exception as direct_load_error:
+            logger.info(f"Direct loading failed: {direct_load_error}")
+            logger.info("Falling back to base model + LoRA adapter approach...")
+
+            # Load base model first
+            logger.info("Loading base model...")
+            base_model = AutoModelForCausalLM.from_pretrained(
+                self.config['base_model'],
+                torch_dtype=torch.float16,
+                device_map="auto"
+            )
+
+            # Try to load the LoRA adapter
+            try:
+                logger.info("Loading LoRA adapter from Hugging Face...")
+                self.model = PeftModel.from_pretrained(
+                    base_model,
+                    self.model_path,
+                    torch_dtype=torch.float16,
+                    device_map="auto"
+                )
+                logger.info("LoRA adapter loaded successfully")
+            except Exception as lora_error:
+                logger.warning(f"LoRA adapter loading failed: {lora_error}")
+                logger.info("Using base model without LoRA adapter...")
+                self.model = base_model
+
+        # Set to evaluation mode
+        self.model.eval()
+
+        # Log model information
+        logger.info(f"Model loaded successfully from Hugging Face")
+        logger.info(f"Model type: {type(self.model).__name__}")
+        logger.info(f"Device: {self.device}")
+
+        # Check if it's a PeftModel
+        if hasattr(self.model, 'peft_config'):
+            logger.info("LoRA adapter configuration:")
+            for adapter_name, config in self.model.peft_config.items():
+                logger.info(f" - {adapter_name}: {config.target_modules}")
 
     def format_prompt(self, question: str) -> str:
         """Format the question for the model using proper format."""
@@ -152,72 +130,49 @@
     def generate_response(self, question: str, max_length: Optional[int] = None) -> Dict[str, Any]:
        start_time = time.time()
 
-        try:
-            # Format the prompt
-            prompt = self.format_prompt(question)
-
-            # Set device for inputs
-            device = self.device if hasattr(self, 'device') else 'cpu'
-
-            # Tokenize input
-            inputs = self.tokenizer(
-                prompt,
-                return_tensors="pt",
-                truncation=True,
-                max_length=2048
-            ).to(device)
-
-            # Generation parameters
-            gen_config = self.config['generation_config'].copy()
-            if max_length:
-                gen_config['max_new_tokens'] = max_length
-
-            # Generate response
-            with torch.no_grad():
-                try:
-                    outputs = self.model.generate(
-                        **inputs,
-                        **gen_config,
-                        pad_token_id=self.tokenizer.eos_token_id
-                    )
-                except Exception as gen_error:
-                    logger.error(f"Generation error: {gen_error}")
-                    # Fallback to simpler generation
-                    outputs = self.model.generate(
-                        **inputs,
-                        max_new_tokens=gen_config.get('max_new_tokens', 512),
-                        do_sample=False,
-                        pad_token_id=self.tokenizer.eos_token_id
-                    )
-
-            # Decode response
-            response = self.tokenizer.decode(
-                outputs[0][inputs['input_ids'].shape[1]:],
-                skip_special_tokens=True
-            ).strip()
-
-            # Calculate response time
-            response_time = time.time() - start_time
-
-            return {
-                'answer': response,
-                'response_time': response_time,
-                'model_info': {
-                    'model_name': 'agriqa-assistant',
-                    'model_source': 'Hugging Face',
-                    'model_path': self.model_path,
-                    'base_model': self.config['base_model']
-                }
-            }
-
-        except Exception as e:
-            logger.error(f"Error generating response: {e}")
-            return {
-                'answer': "I apologize, but I encountered an error while processing your question. Please try again.",
-                'confidence': 0.0,
-                'response_time': time.time() - start_time,
-                'error': str(e)
-            }
+        # Format the prompt
+        prompt = self.format_prompt(question)
+
+        # Tokenize input
+        inputs = self.tokenizer(
+            prompt,
+            return_tensors="pt",
+            truncation=True,
+            max_length=2048
+        ).to(self.device)
+
+        # Generation parameters
+        gen_config = self.config['generation_config'].copy()
+        if max_length:
+            gen_config['max_new_tokens'] = max_length
+
+        # Generate response
+        with torch.no_grad():
+            outputs = self.model.generate(
+                **inputs,
+                **gen_config,
+                pad_token_id=self.tokenizer.eos_token_id
+            )
+
+        # Decode response
+        response = self.tokenizer.decode(
+            outputs[0][inputs['input_ids'].shape[1]:],
+            skip_special_tokens=True
+        ).strip()
+
+        # Calculate response time
+        response_time = time.time() - start_time
+
+        return {
+            'answer': response,
+            'response_time': response_time,
+            'model_info': {
+                'model_name': 'agriqa-assistant',
+                'model_source': 'Hugging Face',
+                'model_path': self.model_path,
+                'base_model': self.config['base_model']
+            }
+        }
 
     def get_model_info(self) -> Dict[str, Any]:
         """Get information about the loaded model."""