WildnerveAI committed on
Commit
05ca8fc
·
verified ·
1 Parent(s): 6a1d9d9

Upload 8 files

Browse files
Files changed (8) hide show
  1. config.py +49 -45
  2. dataloader.py +340 -0
  3. dataset.py +435 -0
  4. dependency_helpers.py +104 -98
  5. handler.py +33 -6
  6. model_Custm.py +6 -5
  7. service_registry.py +45 -0
  8. transformer_patches.py +23 -0
config.py CHANGED
@@ -6,13 +6,17 @@ import argparse
6
  from pathlib import Path
7
  from typing import Optional, Dict, List, Literal, Any
8
 
 
 
 
9
  # --- gracefully handle missing pydantic ---
10
  try:
11
  from pydantic import BaseModel, Field, ValidationError, ConfigDict
12
  except ImportError:
13
- import logging
 
14
  logger = logging.getLogger(__name__)
15
- logger.warning("pydantic not available, using dummy BaseModel/Field")
16
  class BaseModel:
17
  def __init__(self, **kwargs):
18
  for k, v in kwargs.items(): setattr(self, k, v)
@@ -103,82 +107,82 @@ SPECIALIZATIONS = [
103
  # Define DATASET_PATHS so that each specialization is a string or a list of strings
104
  DATASET_PATHS = {
105
  "computer": [
106
- List[str(DATA_DIR / "data" / "computer_advanced_debugging.json")],
107
- List[str(DATA_DIR / "data" / "computer_agenticAI.json")],
108
- List[str(DATA_DIR / "data" / "computer_architecture.json")],
109
- List[str(DATA_DIR / "data" / "computer_cloud_security.json")],
110
- List[str(DATA_DIR / "data" / "computer_creativity.json")],
111
- List[str(DATA_DIR / "data" / "computer_crossplatform.json")],
112
- List[str(DATA_DIR / "data" / "computer_cybersecurity.json")],
113
- List[str(DATA_DIR / "data" / "computer_error_handling_examples.json")],
114
- List[str(DATA_DIR / "data" / "computer_gitInstruct.json")]
115
  ],
116
 
117
  "cpp": [
118
- List[str(DATA_DIR / "data" / "cpp_advanced_debugging.json")],
119
- List[str(DATA_DIR / "data" / "cpp_blockchain.json")],
120
- List[str(DATA_DIR / "data" / "cpp_mbcppp.json")],
121
- List[str(DATA_DIR / "data" / "cpp_programming.json")]
122
  ],
123
 
124
  "java": [
125
- List[str(DATA_DIR / "data" / "java_ai_language_model.json")],
126
- List[str(DATA_DIR / "data" / "java_blockchain.json")],
127
- List[str(DATA_DIR / "data" / "java_mbjp.json")],
128
- List[str(DATA_DIR / "data" / "java_transformer_language_model.json")],
129
  ],
130
 
131
  "go": [
132
- List[str(DATA_DIR / "data" / "golang_ai_language_model.json")],
133
- List[str(DATA_DIR / "data" / "golang_mbgp.json")],
134
- List[str(DATA_DIR / "data" / "golang_programming.json")]
135
  ],
136
 
137
  "javascript": [
138
- List[str(DATA_DIR / "data" / "javascript_chatbot.json")],
139
- List[str(DATA_DIR / "data" / "javascript_n_Typescript_frontend.json")],
140
- List[str(DATA_DIR / "data" / "javascript_n_Typescript_backend.json")],
141
- List[str(DATA_DIR / "data" / "javascript_programming.json")]
142
  ],
143
 
144
  "nim": [
145
- List[str(DATA_DIR / "data" / "nim_ai_language_model.json")],
146
- List[str(DATA_DIR / "data" / "nim_blockchain.json")],
147
- List[str(DATA_DIR / "data" / "nim_chatbot.json")],
148
- List[str(DATA_DIR / "data" / "nim_mbnp.json")],
149
- List[str(DATA_DIR / "data" / "nim_programming.json")]
150
  ],
151
 
152
  "python": [
153
- List[str(DATA_DIR / "data" / "python_chatbot_guide.json")],
154
- List[str(DATA_DIR / "data" / "python_mbpp.json")],
155
- List[str(DATA_DIR / "data" / "python_programming.json")],
156
- List[str(DATA_DIR / "data" / "python_transformer_model.json")]
157
  ],
158
 
159
  "rust": [
160
- List[str(DATA_DIR / "data" / "rust_ai_language_model.json")],
161
- List[str(DATA_DIR / "data" / "rust_blockchain.json")],
162
- List[str(DATA_DIR / "data" / "rust_mbrp.json")],
163
- List[str(DATA_DIR / "data" / "rust_programming.json")]
164
  ],
165
 
166
  "solidity": [
167
- List[str(DATA_DIR / "data" / "solidity_programming.json")]
168
  ],
169
 
170
  "mathematics": [
171
- List[str(DATA_DIR / "data" / "mathematics.json")],
172
- List[str(DATA_DIR / "data" / "mathematics_training.json")]
173
  ],
174
 
175
  "physics": [
176
- List[str(DATA_DIR / "data" / "physics_n_engineering.json")],
177
- List[str(DATA_DIR / "data" / "physics_n_engineering_applied.json")]
178
  ],
179
 
180
  "other_information": [
181
- List[str(DATA_DIR / "data" / "other_information.json")]
182
  ]
183
  }
184
 
 
6
  from pathlib import Path
7
  from typing import Optional, Dict, List, Literal, Any
8
 
9
+ # Import dependency helpers first
10
+ import dependency_helpers
11
+
12
  # --- gracefully handle missing pydantic ---
13
  try:
14
  from pydantic import BaseModel, Field, ValidationError, ConfigDict
15
  except ImportError:
16
+ # The import error should be handled by dependency_helpers
17
+ # But we'll add one more fallback just to be safe
18
  logger = logging.getLogger(__name__)
19
+ logger.warning("pydantic not available, using dummy implementation")
20
  class BaseModel:
21
  def __init__(self, **kwargs):
22
  for k, v in kwargs.items(): setattr(self, k, v)
 
107
# Define DATASET_PATHS so that each specialization is a string or a list of strings
def _spec_files(*filenames):
    """Resolve dataset file names to absolute path strings under DATA_DIR/data."""
    return [str(DATA_DIR / "data" / name) for name in filenames]

DATASET_PATHS = {
    "computer": _spec_files(
        "computer_advanced_debugging.json",
        "computer_agenticAI.json",
        "computer_architecture.json",
        "computer_cloud_security.json",
        "computer_creativity.json",
        "computer_crossplatform.json",
        "computer_cybersecurity.json",
        "computer_error_handling_examples.json",
        "computer_gitInstruct.json",
    ),
    "cpp": _spec_files(
        "cpp_advanced_debugging.json",
        "cpp_blockchain.json",
        "cpp_mbcppp.json",
        "cpp_programming.json",
    ),
    "java": _spec_files(
        "java_ai_language_model.json",
        "java_blockchain.json",
        "java_mbjp.json",
        "java_transformer_language_model.json",
    ),
    "go": _spec_files(
        "golang_ai_language_model.json",
        "golang_mbgp.json",
        "golang_programming.json",
    ),
    "javascript": _spec_files(
        "javascript_chatbot.json",
        "javascript_n_Typescript_frontend.json",
        "javascript_n_Typescript_backend.json",
        "javascript_programming.json",
    ),
    "nim": _spec_files(
        "nim_ai_language_model.json",
        "nim_blockchain.json",
        "nim_chatbot.json",
        "nim_mbnp.json",
        "nim_programming.json",
    ),
    "python": _spec_files(
        "python_chatbot_guide.json",
        "python_mbpp.json",
        "python_programming.json",
        "python_transformer_model.json",
    ),
    "rust": _spec_files(
        "rust_ai_language_model.json",
        "rust_blockchain.json",
        "rust_mbrp.json",
        "rust_programming.json",
    ),
    "solidity": _spec_files("solidity_programming.json"),
    "mathematics": _spec_files("mathematics.json", "mathematics_training.json"),
    "physics": _spec_files(
        "physics_n_engineering.json",
        "physics_n_engineering_applied.json",
    ),
    "other_information": _spec_files("other_information.json"),
}
188
 
dataloader.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Data loader factory and utilities for transformer models.
3
+ """
4
+ import os
5
+ import json
6
+ import torch
7
+ import logging
8
+ import pandas as pd
9
+ import numpy as np
10
+ from typing import Dict, List, Optional, Union, Any, Tuple
11
+ from torch.utils.data import Dataset, DataLoader, TensorDataset
12
+ from pathlib import Path
13
+ from config import app_config
14
+ from tokenizer import TokenizerWrapper
15
+ from datagrower.Crawl4MyAI import AdvancedWebCrawler
16
+ from datagrower.Webconverter import WebConverter
17
+ from dataset import DatasetManager
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
class TransformerDataset(Dataset):
    """Base dataset for transformer models that handles multiple input formats.

    Supports CSV, JSON and plain-text sources; every record is reduced to a
    'text' field (plus optional 'label') and tokenized lazily on access.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: "TokenizerWrapper",
        max_length: int = 512,
        format_type: Optional[str] = None
    ):
        """
        Initialize dataset.

        Args:
            data_path: Path to the data file
            tokenizer: Tokenizer to use for encoding
            max_length: Maximum sequence length
            format_type: Format of data file ('csv', 'json', 'txt');
                detected from the file extension when omitted.
                FIX: was annotated `str = None` although None is the default.

        Raises:
            FileNotFoundError: If data_path does not exist.
        """
        self.data_path = data_path
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.format_type = format_type or self._detect_format(data_path)

        # Load data eagerly so __len__/__getitem__ stay cheap.
        self.data = self._load_data()
        logger.info(f"Loaded {len(self.data)} samples from {data_path}")

    def _detect_format(self, path: str) -> str:
        """Detect file format from extension, defaulting to CSV."""
        ext = os.path.splitext(path)[1].lower().lstrip('.')
        if ext in ['csv']:
            return 'csv'
        elif ext in ['json']:
            return 'json'
        elif ext in ['txt', 'text']:
            return 'txt'
        else:
            logger.warning(f"Unknown file extension: {ext}, defaulting to CSV")
            return 'csv'

    def _load_data(self) -> List[Dict[str, Any]]:
        """Load data based on format type; logs and re-raises any failure."""
        if not os.path.exists(self.data_path):
            raise FileNotFoundError(f"Data file not found: {self.data_path}")

        try:
            if self.format_type == 'csv':
                return self._load_csv()
            elif self.format_type == 'json':
                return self._load_json()
            elif self.format_type == 'txt':
                return self._load_txt()
            else:
                raise ValueError(f"Unsupported format type: {self.format_type}")
        except Exception as e:
            logger.error(f"Error loading data from {self.data_path}: {e}")
            raise

    def _load_csv(self) -> List[Dict[str, Any]]:
        """Load data from CSV file, coercing columns to 'text'/'label'."""
        df = pd.read_csv(self.data_path)
        # Check for required columns
        if 'text' not in df.columns:
            # Try to find a column with text data
            text_cols = [col for col in df.columns if 'text' in col.lower() or 'content' in col.lower()]
            if text_cols:
                df = df.rename(columns={text_cols[0]: 'text'})
            else:
                # Use the first column as text
                df = df.rename(columns={df.columns[0]: 'text'})

        # Check for label column
        if 'label' not in df.columns and len(df.columns) > 1:
            # Use the second column as label if present
            df = df.rename(columns={df.columns[1]: 'label'})

        return df.to_dict('records')

    def _load_json(self) -> List[Dict[str, Any]]:
        """Load data from JSON file (a list, or a dict with known wrapper keys)."""
        with open(self.data_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Handle different JSON formats
        if isinstance(data, list):
            # Already in list format
            return data
        elif isinstance(data, dict):
            # Extract data from known wrapper keys
            if 'data' in data:
                return data['data']
            elif 'examples' in data:
                return data['examples']
            elif 'user_inputs' in data:
                return data['user_inputs']
            else:
                # Convert flat dictionary to list of records
                return [{'text': str(value), 'id': key} for key, value in data.items()]
        else:
            raise ValueError(f"Unsupported JSON data structure: {type(data)}")

    def _load_txt(self) -> List[Dict[str, Any]]:
        """Load data from text file, one sample per line (blank lines skipped)."""
        with open(self.data_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        # Clean and convert to dictionaries
        return [{'text': line.strip(), 'id': i} for i, line in enumerate(lines) if line.strip()]

    def __len__(self) -> int:
        """Get dataset length."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Tokenize and return the sample at idx as model-ready tensors."""
        item = self.data[idx]
        text = item.get('text', '')

        # Handle empty text
        if not text:
            text = " "  # Use space to avoid tokenizer errors

        # Tokenize text
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        # Extract tensors and squeeze batch dimension
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        # Get label if available
        label = item.get('label', 0)
        if isinstance(label, str):
            try:
                label = float(label)
            except ValueError:
                # FIX: the original used builtin hash(), which is salted per
                # process (PYTHONHASHSEED), so categorical string labels mapped
                # to different ids on every run. crc32 is stable across runs
                # and platforms.
                import zlib
                label = zlib.crc32(label.encode('utf-8')) % 100  # Limit to 100 categories

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': torch.tensor(label, dtype=torch.long)
        }
171
+
172
def prepare_data_loaders_extended(
    data_path: Union[str, Dict[str, str]],
    tokenizer: Any,
    batch_size: int = 16,
    max_length: int = 512,
    val_split: float = 0.1,
    format_type: Optional[str] = None,
    num_workers: int = 0
) -> Dict[str, DataLoader]:
    """
    Create data loaders for training and validation.

    Args:
        data_path: Path to data file or dictionary mapping split to path
        tokenizer: Tokenizer to use for encoding
        batch_size: Batch size
        max_length: Maximum sequence length
        val_split: Validation split ratio when only one path is provided
        format_type: Format of data file
        num_workers: Number of workers for DataLoader

    Returns:
        Dictionary mapping split names to DataLoaders
    """
    def build_dataset(path):
        # One TransformerDataset per source file, sharing tokenizer settings.
        return TransformerDataset(
            data_path=path,
            tokenizer=tokenizer,
            max_length=max_length,
            format_type=format_type,
        )

    def build_loader(ds, shuffle):
        # Only the training split is shuffled.
        return DataLoader(
            ds,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
        )

    if isinstance(data_path, dict):
        # Multiple explicit splits: one loader per (split, path) pair.
        return {
            split_name: build_loader(build_dataset(path), split_name == 'train')
            for split_name, path in data_path.items()
        }

    # Single path: carve out a validation subset when val_split allows it.
    whole = build_dataset(data_path)
    n_val = int(len(whole) * val_split)

    if n_val == 0:
        # Dataset too small (or val_split zero): everything goes to training.
        return {'train': build_loader(whole, True)}

    train_part, val_part = torch.utils.data.random_split(
        whole, [len(whole) - n_val, n_val]
    )
    return {
        'train': build_loader(train_part, True),
        'validation': build_loader(val_part, False),
    }
256
+
257
def prepare_data_loaders(
    data_path: str,
    tokenizer: Any,
    batch_size: int = 16,
    val_split: float = 0.1
) -> Tuple[DataLoader, Optional[DataLoader]]:
    """
    Simplified version that returns train and validation loaders directly.

    Args:
        data_path: Path to data file
        tokenizer: Tokenizer to use for encoding
        batch_size: Batch size
        val_split: Validation split ratio

    Returns:
        Tuple of (train_loader, val_loader); val_loader is None when no
        validation split was produced.
    """
    loader_map = prepare_data_loaders_extended(
        data_path=data_path,
        tokenizer=tokenizer,
        batch_size=batch_size,
        val_split=val_split,
    )
    # Delegate entirely to the extended factory and unpack the two splits.
    return loader_map.get('train'), loader_map.get('validation')
286
+
287
def load_dataset(
    specialization: str,
    tokenizer: Any = None,
    split: str = 'train'
) -> Dataset:
    """
    Load a dataset for a specific specialization.

    Args:
        specialization: Name of the specialization (key into DATASET_PATHS)
        tokenizer: Tokenizer to use; a default TokenizerWrapper is created
            when omitted
        split: Dataset split to load (currently unused; kept for API
            compatibility)

    Returns:
        Dataset instance (a ConcatDataset when the specialization maps to
        several files)
    """
    # Get dataset path(s) from config.  NOTE: DATASET_PATHS values are lists
    # of path strings (see config.py); the original code assumed a single
    # string and crashed on `.startswith` for list values.
    if hasattr(app_config, 'DATASET_PATHS') and specialization in app_config.DATASET_PATHS:
        data_path = app_config.DATASET_PATHS[specialization]
    else:
        data_path = os.path.join(app_config.BASE_DATA_DIR, f"{specialization}.csv")

    # Get or create tokenizer
    if tokenizer is None:
        from tokenizer import TokenizerWrapper
        tokenizer = TokenizerWrapper()

    max_length = app_config.TRANSFORMER_CONFIG.MAX_SEQ_LENGTH

    # Normalize to a list so single-path and multi-path configs share one code path.
    paths = [data_path] if isinstance(data_path, str) else list(data_path)

    datasets = []
    for path in paths:
        if path.startswith(("http://", "https://")):
            # URL source: crawl + convert, then wrap the records in an
            # in-memory dataset.  FIX: the original called a nonexistent
            # TransformerDataset._process_records(); instead we build a
            # dataset instance and inject the converted records directly.
            crawler = AdvancedWebCrawler()
            converter = WebConverter(crawler=crawler)
            raw_entries = converter.get_converted_web_data([path])
            # raw_entries assumed to be a list of dicts {"text": ..., "label": ...}
            ds = TransformerDataset.__new__(TransformerDataset)
            ds.data_path = path
            ds.tokenizer = tokenizer
            ds.max_length = max_length
            ds.format_type = 'json'
            ds.data = raw_entries
            datasets.append(ds)
        else:
            # Local file source
            datasets.append(TransformerDataset(
                data_path=path,
                tokenizer=tokenizer,
                max_length=max_length
            ))

    # Preserve the original single-dataset return type when possible.
    if len(datasets) == 1:
        return datasets[0]
    return torch.utils.data.ConcatDataset(datasets)
330
+
331
def load_for_specialization(spec: str):
    """Load the dataset registered for a specialization via DatasetManager.

    Args:
        spec: Specialization key into app_config.DATASET_PATHS.

    Returns:
        List of record dicts from DatasetManager.load_dataset (empty when
        the dataset is unknown).
    """
    # FIX: app_config is accessed attribute-style everywhere else in this
    # module (see load_dataset above); the original dict-style
    # `app_config.get(...)` would raise AttributeError on a namespace object.
    paths = getattr(app_config, "DATASET_PATHS", {}).get(spec, [])
    # normalize to list
    if isinstance(paths, str):
        paths = [paths]
    if not paths:
        logger.warning(f"No dataset paths configured for specialization '{spec}'")
    manager = DatasetManager()
    # FIX: DatasetManager.load_dataset takes a single dataset *name*; the
    # original passed (paths, spec) and raised TypeError.  Load by the
    # specialization name — the manager resolves <data_dir>/<spec>.json.
    return manager.load_dataset(spec)

# Short alias for common use case
get_dataloader = prepare_data_loaders
dataset.py ADDED
@@ -0,0 +1,435 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # dataset.py
2
+ import os
3
+ import csv
4
+ import json
5
+ import torch
6
+ import logging
7
+ from preprocess import Preprocessor
8
+ from torch.utils.data import Dataset
9
+ from typing import List, Dict, Any, Optional, Union
10
+ from functools import wraps
11
+ from time import time
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
def safe_file_operation(func):
    """Decorator guarding dataset file operations.

    I/O and parse failures are logged; loader methods (``_load_*``) degrade
    to an empty list, everything else re-raises.  The "timeout" is advisory
    only: the wrapped call is never interrupted, we merely log a warning
    afterwards when it ran longer than the five-minute budget.
    """
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        began = time()
        budget = 300  # 5-minute advisory budget

        try:
            # Run the wrapped file operation.
            outcome = func(self, *args, **kwargs)

            # Post-hoc warning when the call exceeded the budget.
            if time() - began > budget:
                logger.warning(f"File operation {func.__name__} took more than {budget} seconds")

            return outcome
        except (IOError, OSError) as e:
            logger.error(f"File operation error in {func.__name__}: {str(e)}")
            # Loader methods degrade gracefully to an empty result.
            if func.__name__.startswith('_load_'):
                return []
            raise
        except json.JSONDecodeError as e:
            logger.error(f"JSON decode error in {self.file_path}: {str(e)}")
            return []
        except csv.Error as e:
            logger.error(f"CSV error in {self.file_path}: {str(e)}")
            return []
        except Exception as e:
            logger.error(f"Unexpected error in {func.__name__}: {str(e)}")
            raise

    return wrapper
48
+
49
class TensorDataset(Dataset):
    """Minimal paired (features, labels) dataset.

    Wraps two equally-indexed collections and yields (feature, label)
    tuples by position.
    """

    def __init__(self, features, labels):
        """
        Args:
            features (Tensor): Feature tensors.
            labels (Tensor): Label tensors.
        """
        # Store both collections as-is; indexing is delegated to them.
        self.features, self.labels = features, labels

    def __len__(self):
        # Length is defined by the feature collection.
        return len(self.features)

    def __getitem__(self, idx):
        # Yield the aligned (feature, label) pair at position idx.
        return self.features[idx], self.labels[idx]
67
+
68
class CustomDataset(Dataset):
    """A dataset that supports loading JSON, CSV, and TXT formats.
    It auto-detects the file type (if not specified) and filters out any
    records that are not dictionaries. If a preprocessor is provided, it
    applies it to each record. Additionally, it can standardize sample keys
    dynamically using a provided header mapping. For example, you can define a
    mapping like:
        mapping = {
            "title": ["Title", "Headline", "Article Title"],
            "content": ["Content", "Body", "Text"],
        }
    so that regardless of the CSV's header names your trainer always sees a
    standardized set of keys."""
    def __init__(
        self,
        file_path: Optional[str] = None,
        tokenizer = None,
        max_length: Optional[int] = None,
        file_format: Optional[str] = None,
        preprocessor: Optional[Preprocessor] = None,
        header_mapping: Optional[Dict[str, List[str]]] = None,
        data: Optional[List[Dict[str, Any]]] = None,  # Add data parameter
        specialization: Optional[str] = None  # Add specialization parameter
    ):
        """Args:
            file_path (Optional[str]): Path to the dataset file.
            tokenizer: Tokenizer instance to process the text.
            max_length (Optional[int]): Maximum sequence length.
            file_format (Optional[str]): Format of the file; inferred from the extension if not provided.
            preprocessor (Optional[Preprocessor]): Preprocessor to apply to each sample.
            header_mapping (Optional[Dict[str, List[str]]]): Dictionary that maps standardized keys.
            data (Optional[List[Dict[str, Any]]]): Direct data input instead of loading from file.
            specialization (Optional[str]): Specialization field for the dataset."""

        self.file_path = file_path
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.preprocessor = preprocessor
        self.header_mapping = header_mapping
        self.specialization = specialization  # Store the specialization

        # Initialize samples either from data or file.
        # NOTE(review): when `data` is given, or both `data` and `file_path`
        # are None, self.file_format is never assigned — confirm no caller
        # reads it in those cases.
        if data is not None:
            self.samples = data
        else:
            # Determine the file format if not specified and file_path is provided
            if file_path is not None:
                if file_format is None:
                    _, ext = os.path.splitext(file_path)
                    ext = ext.lower()
                    if ext in ['.json']:
                        file_format = 'json'
                    elif ext in ['.csv']:
                        file_format = 'csv'
                    elif ext in ['.txt']:
                        file_format = 'txt'
                    else:
                        logger.error(f"Unsupported file extension: {ext}")
                        raise ValueError(f"Unsupported file extension: {ext}")

                self.file_format = file_format
                self.samples = self._load_file()
            else:
                # Neither direct data nor a file: start empty.
                self.samples = []

        # Auto-detection: Ensure all loaded samples are dicts.
        initial_sample_count = len(self.samples)
        self.samples = [sample for sample in self.samples if isinstance(sample, dict)]
        if len(self.samples) < initial_sample_count:
            logger.warning(f"Filtered out {initial_sample_count - len(self.samples)} samples that were not dicts.")

        # If a preprocessor is provided, apply preprocessing to each record.
        # Records that fail preprocessing are dropped (logged, not re-raised).
        if self.preprocessor:
            preprocessed_samples = []
            for sample in self.samples:
                try:
                    processed = self.preprocessor.preprocess_record(sample)
                    preprocessed_samples.append(processed)
                except Exception as e:
                    logger.error(f"Error preprocessing record {sample}: {e}")
            self.samples = preprocessed_samples

    def _load_file(self) -> List[Dict[str, Any]]:
        """Dispatch loading to the format-specific loader; log then re-raise failures."""
        try:
            if self.file_format == 'json':
                return self._load_json()
            elif self.file_format == 'csv':
                return self._load_csv()
            elif self.file_format == 'txt':
                return self._load_txt()
            else:
                logger.error(f"Unrecognized file format: {self.file_format}")
                raise ValueError(f"Unrecognized file format: {self.file_format}")
        except Exception as e:
            logger.error(f"Error loading file {self.file_path}: {e}")
            raise

    @safe_file_operation
    def _load_json(self) -> List[Dict[str, Any]]:
        """Load JSON file with better error handling and validation"""
        try:
            with open(self.file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            # Validate data structure
            if isinstance(data, list):
                # Keep only dict records; warn about anything else.
                valid_records = [record for record in data if isinstance(record, dict)]
                if len(valid_records) < len(data):
                    logger.warning(f"{len(data) - len(valid_records)} records were not dictionaries in {self.file_path}")
                return valid_records
            elif isinstance(data, dict):
                # Handle single record case
                logger.warning(f"JSON file contains a single dictionary, not a list: {self.file_path}")
                return [data]
            else:
                logger.error(f"JSON file does not contain a list or dictionary: {self.file_path}")
                return []
        except json.JSONDecodeError as e:
            line_col = f"line {e.lineno}, column {e.colno}"
            logger.error(f"JSON decode error at {line_col} in {self.file_path}: {e.msg}")
            # Try to recover partial content
            try:
                with open(self.file_path, 'r', encoding='utf-8') as f:
                    content = f.read()
                # Try parsing up to the error
                valid_part = content[:e.pos]
                import re
                # Find complete objects (rough approach).  NOTE: this regex
                # only matches flat, non-nested {...} spans, so nested records
                # are not recovered.
                matches = re.findall(r'\{[^{}]*\}', valid_part)
                if matches:
                    logger.info(f"Recovered {len(matches)} complete records from {self.file_path}")
                    parsed_records = []
                    for match in matches:
                        try:
                            parsed_records.append(json.loads(match))
                        except:
                            # Skip fragments that still fail to parse.
                            pass
                    return parsed_records
            except:
                # Recovery is best-effort only; fall through to empty result.
                pass
            return []

    @safe_file_operation
    def _load_csv(self) -> List[Dict[str, Any]]:
        """Load CSV with better error handling"""
        samples = []
        try:
            with open(self.file_path, 'r', encoding='utf-8') as csvfile:
                # Try detecting dialect first
                try:
                    dialect = csv.Sniffer().sniff(csvfile.read(1024))
                    csvfile.seek(0)
                    reader = csv.DictReader(csvfile, dialect=dialect)
                except:
                    # Fall back to excel dialect
                    csvfile.seek(0)
                    reader = csv.DictReader(csvfile, dialect='excel')

                for i, row in enumerate(reader):
                    if not isinstance(row, dict):
                        logger.warning(f"Row {i} is not a dict: {row} -- skipping.")
                        continue
                    samples.append(row)

            if not samples:
                logger.warning(f"No valid rows found in CSV file: {self.file_path}")

        except csv.Error as e:
            logger.error(f"Error reading CSV file {self.file_path}: {e}")
        return samples

    def _load_txt(self) -> List[Dict[str, Any]]:
        """Load a plain-text file: one non-empty line becomes one sample."""
        samples = []
        with open(self.file_path, 'r', encoding='utf-8') as txtfile:
            for i, line in enumerate(txtfile):
                line = line.strip()
                if line:
                    # Wrap each line in a dictionary.
                    samples.append({"text": line})
        return samples

    def _standardize_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Remaps the sample's keys to a set of standardized keys using self.header_mapping.
        For each standardized key, the first matching header from the sample is used.
        If none is found, a default empty string is assigned."""
        standardized = {}
        for std_field, possible_keys in self.header_mapping.items():
            for key in possible_keys:
                if key in sample:
                    standardized[std_field] = sample[key]
                    break
            if std_field not in standardized:
                # No candidate header matched: fall back to empty string.
                standardized[std_field] = ""
        return standardized

    def __len__(self) -> int:
        """Number of (filtered, preprocessed) samples."""
        return len(self.samples)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        """Tokenize sample at *index* and return model-ready tensors plus metadata."""
        sample = self.samples[index]

        # If a header mapping is provided, standardize the sample keys.
        if self.header_mapping is not None:
            sample = self._standardize_sample(sample)

        # Determine the text to tokenize:
        # If standardized keys "title" or "content" exist, combine them.
        if 'title' in sample or 'content' in sample:
            title = sample.get('title', '')
            content = sample.get('content', '')
            # Convert non-string fields to strings
            if not isinstance(title, str):
                title = str(title)
            if not isinstance(content, str):
                content = str(content)
            text = (title + " " + content).strip()
        elif "text" in sample:
            text = sample["text"] if isinstance(sample["text"], str) else str(sample["text"])
        else:
            # Fallback: join all values (cast to str)
            text = " ".join(str(v) for v in sample.values())

        # Tokenize the combined text.
        tokenized = self.tokenizer.encode_plus(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Get specialization from sample or use class default
        specialization = None
        if isinstance(sample, dict) and "specialization" in sample:
            specialization = sample["specialization"]
        elif self.specialization:
            specialization = self.specialization

        # Return a standardized dictionary for training.
        # token_type_ids defaults to zeros when the tokenizer does not emit them.
        result = {
            "input_ids": tokenized["input_ids"].squeeze(0),
            "attention_mask": tokenized["attention_mask"].squeeze(0),
            "token_type_ids": tokenized.get("token_type_ids", torch.zeros_like(tokenized["input_ids"])).squeeze(0),
        }

        # Add specialization if available
        if specialization:
            result["specialization"] = specialization

        # Optionally include standardized text fields if needed.
        # NOTE(review): the locals() check is a fragile way to test whether
        # the title/content branch ran above — it breaks if those names are
        # introduced elsewhere; consider explicit flags instead.
        if 'title' in locals():
            result["title"] = title
        if 'content' in locals():
            result["content"] = content

        return result
324
+
325
+ # dataset.py - Simple dataset module to fix initialization dependency issues
326
+ import logging
327
+ import os
328
+ import json
329
+ from typing import Dict, List, Any, Optional, Union
330
+
331
+ logger = logging.getLogger(__name__)
332
+
333
class DatasetManager:
    """
    Simple dataset manager providing JSON-file-backed datasets for
    model_manager without requiring external dataset dependencies.

    Datasets live as ``<name>.json`` files under ``data_dir`` and are cached
    in ``self.datasets`` after the first successful load.
    """

    def __init__(self, data_dir: Optional[str] = None):
        # Default to a "data" directory next to this module.
        self.data_dir = data_dir or os.path.join(os.path.dirname(__file__), "data")
        # In-memory cache: dataset name -> list of row dicts.
        self.datasets: Dict[str, List[Dict[str, Any]]] = {}
        self._ensure_data_dir()

    def _ensure_data_dir(self) -> None:
        """Ensure the data directory exists, falling back to a temp dir."""
        try:
            if not os.path.exists(self.data_dir):
                os.makedirs(self.data_dir, exist_ok=True)
                logging.getLogger(__name__).info(f"Created dataset directory at {self.data_dir}")
        except (PermissionError, OSError) as e:
            logging.getLogger(__name__).warning(f"Could not create data directory: {e}")
            # Fall back to the platform temp directory; the previous
            # hard-coded "/tmp" breaks on Windows.
            import tempfile
            self.data_dir = os.path.join(tempfile.gettempdir(), "wildnerve_data")
            os.makedirs(self.data_dir, exist_ok=True)
            logging.getLogger(__name__).info(f"Using fallback data directory at {self.data_dir}")

    def load_dataset(self, name: str) -> List[Dict[str, Any]]:
        """Load dataset *name* (cached); return [] if missing or unreadable."""
        if name in self.datasets:
            return self.datasets[name]

        # Check for dataset file on disk.
        filepath = os.path.join(self.data_dir, f"{name}.json")
        if os.path.exists(filepath):
            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                self.datasets[name] = data
                return data
            except Exception as e:
                logging.getLogger(__name__).error(f"Error loading dataset {name}: {e}")

        # Not found (or failed to parse): degrade gracefully.
        logging.getLogger(__name__).warning(f"Dataset {name} not found, returning empty dataset")
        return []

    def get_dataset_names(self) -> List[str]:
        """Return the stems of all ``*.json`` files in the data directory."""
        try:
            # splitext instead of split('.')[0] so dotted names such as
            # "v1.2.json" keep their full stem "v1.2".
            return [os.path.splitext(f)[0] for f in os.listdir(self.data_dir)
                    if f.endswith('.json')]
        except Exception as e:
            logging.getLogger(__name__).error(f"Error listing datasets: {e}")
            return []

    def create_sample_dataset(self, name: str, samples: int = 10) -> List[Dict[str, Any]]:
        """Create, persist, and return a toy dataset with binary labels."""
        data = [
            {
                "id": i,
                "text": f"Sample text {i} for model training",
                "label": i % 2  # Binary label
            }
            for i in range(samples)
        ]

        # Save to file; the generated rows are returned even if the write
        # fails (best effort, matching the original behavior).
        filepath = os.path.join(self.data_dir, f"{name}.json")
        try:
            with open(filepath, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2)
            self.datasets[name] = data
            logging.getLogger(__name__).info(f"Created sample dataset {name} with {samples} samples")
        except Exception as e:
            logging.getLogger(__name__).error(f"Error creating sample dataset: {e}")

        return data

    def _load_and_process_dataset(self, path_or_paths: Union[str, List[str]],
                                  specialization: str) -> "TensorDataset":
        """Read one or more JSON files into a single pandas DataFrame.

        The return annotation is quoted: ``TensorDataset`` is not imported at
        module scope, so the original unquoted annotation raised NameError at
        class-definition time and crashed module import.
        """
        import pandas as pd

        # Handle multiple JSON files by concatenation.
        if isinstance(path_or_paths, list):
            frames = [pd.read_json(p) for p in path_or_paths]
            data = pd.concat(frames, ignore_index=True)
        else:
            data = pd.read_json(path_or_paths)

        # TODO(review): the feature/label split and TensorDataset
        # construction were elided in this revision ("existing code" stubs);
        # restore them before calling this method.
+
421
# Create a default dataset manager instance.
# NOTE(review): constructed at import time, so importing this module creates
# the data directory (or its temp fallback) as a side effect.
dataset_manager = DatasetManager()

def get_dataset(name: str) -> List[Dict[str, Any]]:
    """Return the dataset stored under *name* via the module-level singleton.

    Returns an empty list when the dataset is missing or unreadable
    (delegates to DatasetManager.load_dataset).
    """
    return dataset_manager.load_dataset(name)
+
428
# Create some minimal sample data if running as main.
if __name__ == "__main__":
    # Smoke test: write a sample dataset, list the directory, read it back.
    logging.basicConfig(level=logging.INFO)
    dm = DatasetManager()
    dm.create_sample_dataset("test_dataset", samples=20)
    print(f"Available datasets: {dm.get_dataset_names()}")
    test_data = dm.load_dataset("test_dataset")
    print(f"Loaded {len(test_data)} samples from test_dataset")
dependency_helpers.py CHANGED
@@ -1,118 +1,124 @@
1
  """
2
- Helper utilities for handling dependencies in a graceful manner.
3
- This module provides functions to check for and load dependencies without crashing.
4
  """
5
- import importlib
6
- import logging
7
- import sys
8
  import os
9
- from typing import Optional, Any, Dict, Callable, List
 
 
10
 
11
  logger = logging.getLogger(__name__)
12
 
13
- def safely_import(module_name: str) -> Optional[Any]:
14
- """
15
- Safely import a module without crashing if it's not available.
16
-
17
- Args:
18
- module_name: Name of the module to import
19
-
20
- Returns:
21
- The imported module or None if import failed
22
- """
23
- try:
24
- return importlib.import_module(module_name)
25
- except ImportError as e:
26
- logger.warning(f"Failed to import {module_name}: {e}")
27
- return None
28
 
29
  def is_module_available(module_name: str) -> bool:
30
- """
31
- Check if a module is available without importing it.
32
-
33
- Args:
34
- module_name: Name of the module to check
35
-
36
- Returns:
37
- True if module is available, False otherwise
38
- """
39
- try:
40
- importlib.util.find_spec(module_name)
41
- return True
42
- except ImportError:
43
- return False
44
 
45
- def check_dependencies(dependencies: List[str]) -> Dict[str, bool]:
46
- """
47
- Check multiple dependencies at once.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
- Args:
50
- dependencies: List of module names to check
51
-
52
- Returns:
53
- Dictionary mapping module names to availability (True/False)
54
- """
55
- return {dep: is_module_available(dep) for dep in dependencies}
56
 
57
- def get_object_if_available(module_name: str, object_name: str) -> Optional[Any]:
58
- """
59
- Get an object from a module if the module is available.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
- Args:
62
- module_name: Name of the module containing the object
63
- object_name: Name of the object to get
64
-
65
- Returns:
66
- The requested object or None if not available
67
- """
68
- module = safely_import(module_name)
69
- if module and hasattr(module, object_name):
70
- return getattr(module, object_name)
71
- return None
 
72
 
73
- def with_fallback(primary_func: Callable, fallback_func: Callable, *args, **kwargs) -> Any:
74
- """
75
- Call primary_func with the given args/kwargs, falling back to fallback_func if it fails.
 
 
 
 
 
76
 
77
- Args:
78
- primary_func: Function to try first
79
- fallback_func: Function to use if primary_func fails
80
- args: Positional arguments to pass to both functions
81
- kwargs: Keyword arguments to pass to both functions
 
 
 
 
 
 
 
 
82
 
83
- Returns:
84
- Result from either primary_func or fallback_func
85
- """
86
- try:
87
- return primary_func(*args, **kwargs)
88
- except Exception as e:
89
- logger.warning(f"Primary function {primary_func.__name__} failed: {e}")
90
- return fallback_func(*args, **kwargs)
91
 
92
- def install_package(package_name: str) -> bool:
93
- """
94
- Attempt to install a package using pip.
95
- Note: This is generally not recommended in production code but can be useful for development.
 
 
 
 
96
 
97
- Args:
98
- package_name: Name of the package to install
 
 
99
 
100
- Returns:
101
- True if installation was successful, False otherwise
102
- """
103
- try:
104
- import subprocess
105
- logger.info(f"Attempting to install {package_name}")
106
- subprocess.check_call([sys.executable, "-m", "pip", "install", package_name])
107
- return True
108
- except Exception as e:
109
- logger.warning(f"Failed to install {package_name}: {e}")
110
- return False
111
 
112
- # Check critical dependencies used in the project
113
- CRITICAL_DEPENDENCIES = ["torch", "transformers", "sentencepiece", "pydantic", "nltk"]
114
- DEPENDENCY_STATUS = check_dependencies(CRITICAL_DEPENDENCIES)
115
 
116
- def get_dependency_status() -> Dict[str, bool]:
117
- """Get the status of critical dependencies."""
118
- return DEPENDENCY_STATUS
 
1
  """
2
+ Dependency helpers to make the model work even if some libraries are missing.
3
+ This file provides fallback implementations for missing dependencies.
4
  """
 
 
 
5
  import os
6
+ import logging
7
+ import importlib.util
8
+ from typing import Any, Dict, Optional, Type, Callable
9
 
10
  logger = logging.getLogger(__name__)
11
 
12
# Registry of fallback "modules": maps an importable name to the stand-in
# object that DependencyImportFinder serves from sys.meta_path. Populated
# by setup_dependency_fallbacks() below.
MOCK_MODULES = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
def is_module_available(module_name: str) -> bool:
    """Check if a module is available without importing it.

    find_spec does not always return None for unavailable modules: it raises
    ModuleNotFoundError for dotted names whose parent package is missing and
    ValueError for some malformed names, so any lookup failure is treated as
    "not available" instead of crashing the caller.
    """
    try:
        return importlib.util.find_spec(module_name) is not None
    except (ImportError, AttributeError, ValueError):
        return False
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
def create_mock_emissions_tracker() -> Type:
    """Build and return a no-op stand-in for codecarbon's EmissionsTracker."""

    class MockEmissionsTracker:
        # The real tracker takes configuration arguments; accept and ignore
        # them so call sites need no changes.
        def __init__(self, *args, **kwargs):
            logging.getLogger(__name__).info("Using mock EmissionsTracker")

        def start(self):
            return self

        def stop(self):
            # Report zero emissions.
            return 0.0

        # Context-manager protocol so `with EmissionsTracker(...):` works.
        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            pass

    return MockEmissionsTracker
 
 
 
 
 
 
38
 
39
def create_mock_pydantic_classes() -> Dict[str, Type]:
    """Build stand-ins for the pydantic names this project relies on."""

    class MockBaseModel:
        """Attribute-bag replacement for pydantic.BaseModel."""

        def __init__(self, **fields):
            for name, value in fields.items():
                setattr(self, name, value)

        def dict(self) -> Dict[str, Any]:
            # Expose only public attributes, mirroring BaseModel.dict().
            return {name: value
                    for name, value in vars(self).items()
                    if not name.startswith('_')}

        def json(self) -> str:
            import json
            return json.dumps(self.dict())

    def mock_field(*args, **kwargs) -> Any:
        """Mock of pydantic.Field: hand back the declared default, if any."""
        return kwargs.get('default', None)

    class MockValidationError(Exception):
        """Mock implementation of pydantic's ValidationError."""

    # ConfigDict degrades to a plain dict.
    return {
        "BaseModel": MockBaseModel,
        "Field": mock_field,
        "ValidationError": MockValidationError,
        "ConfigDict": dict,
    }
71
 
72
def setup_dependency_fallbacks():
    """Register mock replacements in MOCK_MODULES for missing dependencies.

    Only populates the module-level MOCK_MODULES dict; the actual import
    interception is performed by DependencyImportFinder, which serves these
    entries from sys.meta_path. Because entries are created only for modules
    that are *not* installed, real packages are never shadowed.
    """
    # Handle codecarbon.
    # NOTE(review): the mock "module" is a bare class whose attributes mimic
    # the real module's namespace - attribute access works, but it is not a
    # true types.ModuleType instance.
    if not is_module_available("codecarbon"):
        logger.warning("codecarbon not found, using mock implementation")
        MOCK_MODULES["codecarbon"] = type("MockCodecarbon", (), {
            "EmissionsTracker": create_mock_emissions_tracker()
        })

    # Handle pydantic: BaseModel/Field/ValidationError/ConfigDict stand-ins.
    if not is_module_available("pydantic"):
        logger.warning("pydantic not found, using mock implementation")
        mock_classes = create_mock_pydantic_classes()
        MOCK_MODULES["pydantic"] = type("MockPydantic", (), mock_classes)

    # Setup service_registry fallback if needed: a registry whose
    # register/get/has all degrade to no-ops.
    if not is_module_available("service_registry"):
        from types import SimpleNamespace
        registry_obj = SimpleNamespace()
        registry_obj.register = lambda *args, **kwargs: None
        registry_obj.get = lambda *args: None
        registry_obj.has = lambda *args: False

        MOCK_MODULES["service_registry"] = type("MockServiceRegistry", (), {
            "registry": registry_obj,
            "MODEL": "MODEL",
            "TOKENIZER": "TOKENIZER"
        })
 
 
 
100
 
101
# Custom import hook to provide mock implementations
class DependencyImportFinder:
    """sys.meta_path finder that serves mock modules from MOCK_MODULES.

    Implements the modern find_spec/create_module/exec_module protocol
    (PEP 451). The legacy find_module/load_module pair that this class
    originally relied on was removed from the import system in Python 3.12,
    which silently disabled the fallbacks on current interpreters; the
    legacy methods are kept for older tooling.
    """

    def __init__(self):
        self._mock_modules = MOCK_MODULES

    # --- modern protocol (required on Python 3.12+) ---
    def find_spec(self, fullname, path=None, target=None):
        if fullname in self._mock_modules:
            import importlib.util
            # Use this object as its own loader.
            return importlib.util.spec_from_loader(fullname, self)
        return None

    def create_module(self, spec):
        # Hand the pre-built mock object to the import machinery as-is.
        return self._mock_modules[spec.name]

    def exec_module(self, module):
        # Nothing to execute - the mock is fully constructed already.
        pass

    # --- legacy protocol (pre-3.12 interpreters / tooling) ---
    def find_module(self, fullname, path=None):
        if fullname in self._mock_modules:
            return self
        return None

    def load_module(self, fullname):
        import sys
        if fullname in sys.modules:
            return sys.modules[fullname]

        module = self._mock_modules[fullname]
        sys.modules[fullname] = module
        return module
 
 
 
 
 
 
 
 
118
 
119
# Initialize the fallbacks (populates MOCK_MODULES for missing dependencies).
setup_dependency_fallbacks()

# Install the custom import hook at the front of sys.meta_path so the mocks
# are consulted first. Only names absent from the environment were registered
# above, so installed packages are never shadowed.
import sys
sys.meta_path.insert(0, DependencyImportFinder())
handler.py CHANGED
@@ -17,13 +17,33 @@ logging.basicConfig(
17
  )
18
  logger = logging.getLogger(__name__)
19
 
20
- # Verify installed packages
21
  try:
22
  import pydantic
23
- import codecarbon
24
- print(f"Required dependencies are available - pydantic: {pydantic.__version__}, codecarbon: {codecarbon.__version__}")
25
- except ImportError as e:
26
- print(f"WARNING: Missing dependency: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  # Make sure adapter_layer.py is properly located
29
  try:
@@ -43,7 +63,14 @@ try:
43
 
44
  except ImportError as e:
45
  logger.error(f"Could not import adapter_layer: {e}")
46
- raise
 
 
 
 
 
 
 
47
 
48
  class EndpointHandler:
49
  def __init__(self, path=""):
 
17
  )
18
  logger = logging.getLogger(__name__)
19
 
20
# Safely check for required packages without crashing the endpoint at import.
try:
    import pydantic
    print(f"pydantic is available: {pydantic.__version__}")
except ImportError:
    print("pydantic is not available - continuing without it")

    # Minimal compatibility layer standing in for the real module.
    class pydantic:
        # A real module exposes __version__ as a *string* attribute. The
        # previous code declared it as a @staticmethod, so any later
        # f"{pydantic.__version__}" printed a function repr instead of a
        # version string.
        __version__ = "unavailable"

        class BaseModel:
            # Attribute-bag replacement for pydantic.BaseModel.
            def __init__(self, **kwargs):
                for k, v in kwargs.items():
                    setattr(self, k, v)

try:
    from codecarbon import EmissionsTracker
    print("codecarbon is available")
except ImportError:
    print("codecarbon is not available - continuing without carbon tracking")

    # Minimal no-op tracker with the same start/stop surface as the real one.
    class EmissionsTracker:
        def __init__(self, *args, **kwargs):
            pass

        def start(self):
            return self

        def stop(self):
            # Report zero emissions.
            return 0.0
 
48
  # Make sure adapter_layer.py is properly located
49
  try:
 
63
 
64
  except ImportError as e:
65
  logger.error(f"Could not import adapter_layer: {e}")
66
+ # Don't raise error - provide fallback adapter implementation
67
+ class WildnerveModelAdapter:
68
+ def __init__(self, path=""):
69
+ self.path = path
70
+ logger.info(f"Using fallback WildnerveModelAdapter with path: {path}")
71
+
72
+ def generate(self, text_input, **kwargs):
73
+ return f"Model adapter unavailable. Received input: {text_input[:30]}..."
74
 
75
  class EndpointHandler:
76
  def __init__(self, path=""):
model_Custm.py CHANGED
@@ -1,4 +1,4 @@
1
- # model_Custm.py
2
  import os
3
  import sys
4
  import math
@@ -8,14 +8,15 @@ import numpy as np
8
  import torch.nn as nn
9
  from typing import Optional, List, Dict, Union
10
 
 
 
 
11
  # Import the carbon tracker early - before transformers
12
  try:
13
  from codecarbon import EmissionsTracker
14
  except ImportError:
15
- class EmissionsTracker:
16
- def __init__(self, *args, **kwargs): pass
17
- def start(self): return self
18
- def stop(self): return 0.0
19
 
20
  # Apply patches before importing transformers
21
  import transformer_patches
 
1
+ # model_Custm.py - with dependency fallbacks
2
  import os
3
  import sys
4
  import math
 
8
  import torch.nn as nn
9
  from typing import Optional, List, Dict, Union
10
 
11
# Import dependency helpers first: importing it installs the sys.meta_path
# mock hook as a side effect, so later imports of missing libraries succeed.
import dependency_helpers

# Import the carbon tracker early - before transformers
try:
    from codecarbon import EmissionsTracker
except ImportError:
    # Use the mock from dependency_helpers (no-op tracker reporting 0.0).
    EmissionsTracker = dependency_helpers.create_mock_emissions_tracker()
 
 
20
 
21
  # Apply patches before importing transformers
22
  import transformer_patches
service_registry.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simple service registry for dependency injection
3
+ """
4
+ import logging
5
+ from typing import Any, Dict, Optional
6
+
7
logger = logging.getLogger(__name__)

# Constants used as registry keys for the two core services.
MODEL = "model"
TOKENIZER = "tokenizer"
12
+
13
class ServiceRegistry:
    """A simple service registry for dependency management.

    Services are arbitrary objects stored under string keys; lookups for
    unknown keys log a warning and yield None instead of raising.
    """

    def __init__(self):
        self._services = {}

    def register(self, key: str, service: Any, overwrite: bool = False) -> None:
        """Register *service* under *key*; refuse to clobber unless asked."""
        if not overwrite and key in self._services:
            logging.getLogger(__name__).warning(
                f"Service with key '{key}' already registered")
            return

        self._services[key] = service
        logging.getLogger(__name__).debug(f"Registered service with key: {key}")

    def get(self, key: str) -> Optional[Any]:
        """Look up a service, returning None (with a warning) when absent."""
        try:
            return self._services[key]
        except KeyError:
            logging.getLogger(__name__).warning(
                f"No service registered with key: {key}")
            return None

    def has(self, key: str) -> bool:
        """True if a service is registered under *key*."""
        return key in self._services

    def clear(self) -> None:
        """Drop every registered service."""
        self._services.clear()

# Create singleton instance
registry = ServiceRegistry()
transformer_patches.py CHANGED
@@ -211,3 +211,26 @@ if __name__ == "__main__":
211
  print("\nPatch status:")
212
  for patch, status in _patch_status.items():
213
  print(f" {'✓' if status else '✗'} {patch}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  print("\nPatch status:")
212
  for patch, status in _patch_status.items():
213
  print(f" {'✓' if status else '✗'} {patch}")
214
+
215
+ """
216
+ Transformer patches to make the model work better with HuggingFace transformers.
217
+ This file applies monkey patches to fix compatibility issues or add functionality.
218
+ """
219
+ import logging
220
+ from typing import Dict, Any, Optional
221
+
222
+ logger = logging.getLogger(__name__)
223
+
224
def apply_transformer_patches():
    """Apply monkey patches to the transformers library, when installed."""
    try:
        import transformers
    except ImportError:
        logging.getLogger(__name__).warning(
            "Transformers library not found, skipping patches")
        return

    logging.getLogger(__name__).info(
        f"Applying patches to transformers v{transformers.__version__}")

    # No patches needed currently, but you can add them here if needed in future
234
+
235
# Apply patches when imported (deliberate import-time side effect so callers
# importing this module get a patched transformers without extra steps).
apply_transformer_patches()