Upload 8 files
Browse files- config.py +49 -45
- dataloader.py +340 -0
- dataset.py +435 -0
- dependency_helpers.py +104 -98
- handler.py +33 -6
- model_Custm.py +6 -5
- service_registry.py +45 -0
- transformer_patches.py +23 -0
config.py
CHANGED
|
@@ -6,13 +6,17 @@ import argparse
|
|
| 6 |
from pathlib import Path
|
| 7 |
from typing import Optional, Dict, List, Literal, Any
|
| 8 |
|
|
|
|
|
|
|
|
|
|
| 9 |
# --- gracefully handle missing pydantic ---
|
| 10 |
try:
|
| 11 |
from pydantic import BaseModel, Field, ValidationError, ConfigDict
|
| 12 |
except ImportError:
|
| 13 |
-
import
|
|
|
|
| 14 |
logger = logging.getLogger(__name__)
|
| 15 |
-
logger.warning("pydantic not available, using dummy
|
| 16 |
class BaseModel:
|
| 17 |
def __init__(self, **kwargs):
|
| 18 |
for k, v in kwargs.items(): setattr(self, k, v)
|
|
@@ -103,82 +107,82 @@ SPECIALIZATIONS = [
|
|
| 103 |
# Define DATASET_PATHS so that each specialization is a string or a list of strings
|
| 104 |
DATASET_PATHS = {
|
| 105 |
"computer": [
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
],
|
| 116 |
|
| 117 |
"cpp": [
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
],
|
| 123 |
|
| 124 |
"java": [
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
],
|
| 130 |
|
| 131 |
"go": [
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
],
|
| 136 |
|
| 137 |
"javascript": [
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
],
|
| 143 |
|
| 144 |
"nim": [
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
],
|
| 151 |
|
| 152 |
"python": [
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
],
|
| 158 |
|
| 159 |
"rust": [
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
],
|
| 165 |
|
| 166 |
"solidity": [
|
| 167 |
-
|
| 168 |
],
|
| 169 |
|
| 170 |
"mathematics": [
|
| 171 |
-
|
| 172 |
-
|
| 173 |
],
|
| 174 |
|
| 175 |
"physics": [
|
| 176 |
-
|
| 177 |
-
|
| 178 |
],
|
| 179 |
|
| 180 |
"other_information": [
|
| 181 |
-
|
| 182 |
]
|
| 183 |
}
|
| 184 |
|
|
|
|
| 6 |
from pathlib import Path
|
| 7 |
from typing import Optional, Dict, List, Literal, Any
|
| 8 |
|
| 9 |
+
# Import dependency helpers first (installs fallbacks for missing packages).
import dependency_helpers

# --- gracefully handle missing pydantic ---
try:
    from pydantic import BaseModel, Field, ValidationError, ConfigDict
except ImportError:
    # The import error should normally be handled by dependency_helpers,
    # but keep a self-contained fallback just to be safe.
    logger = logging.getLogger(__name__)
    logger.warning("pydantic not available, using dummy implementation")

    class BaseModel:
        """Minimal stand-in for pydantic.BaseModel: stores kwargs as attributes."""
        def __init__(self, **kwargs):
            for k, v in kwargs.items():
                setattr(self, k, v)

    # BUGFIX: the try-import above pulls four names, but the fallback only
    # defined BaseModel -- any later use of Field/ValidationError/ConfigDict
    # raised NameError.  Provide minimal stand-ins for all of them.
    def Field(default=None, **kwargs):
        """Stand-in for pydantic.Field: simply returns the default value."""
        return default

    class ValidationError(Exception):
        """Stand-in raised in place of pydantic.ValidationError."""

    def ConfigDict(**kwargs):
        """Stand-in for pydantic.ConfigDict: returns the options as a dict."""
        return dict(**kwargs)
|
|
|
|
| 107 |
# DATASET_PATHS maps each specialization to the list of JSON files backing
# it.  Every entry is a plain string path so downstream loaders can consume
# it without caring about pathlib objects.
_DATA_ROOT = DATA_DIR / "data"

def _data_files(*names):
    """Resolve JSON file stems under the shared data directory to string paths."""
    return [str(_DATA_ROOT / f"{name}.json") for name in names]

DATASET_PATHS = {
    "computer": _data_files(
        "computer_advanced_debugging",
        "computer_agenticAI",
        "computer_architecture",
        "computer_cloud_security",
        "computer_creativity",
        "computer_crossplatform",
        "computer_cybersecurity",
        "computer_error_handling_examples",
        "computer_gitInstruct",
    ),
    "cpp": _data_files(
        "cpp_advanced_debugging",
        "cpp_blockchain",
        "cpp_mbcppp",
        "cpp_programming",
    ),
    "java": _data_files(
        "java_ai_language_model",
        "java_blockchain",
        "java_mbjp",
        "java_transformer_language_model",
    ),
    "go": _data_files(
        "golang_ai_language_model",
        "golang_mbgp",
        "golang_programming",
    ),
    "javascript": _data_files(
        "javascript_chatbot",
        "javascript_n_Typescript_frontend",
        "javascript_n_Typescript_backend",
        "javascript_programming",
    ),
    "nim": _data_files(
        "nim_ai_language_model",
        "nim_blockchain",
        "nim_chatbot",
        "nim_mbnp",
        "nim_programming",
    ),
    "python": _data_files(
        "python_chatbot_guide",
        "python_mbpp",
        "python_programming",
        "python_transformer_model",
    ),
    "rust": _data_files(
        "rust_ai_language_model",
        "rust_blockchain",
        "rust_mbrp",
        "rust_programming",
    ),
    "solidity": _data_files("solidity_programming"),
    "mathematics": _data_files("mathematics", "mathematics_training"),
    "physics": _data_files("physics_n_engineering", "physics_n_engineering_applied"),
    "other_information": _data_files("other_information"),
}
|
| 188 |
|
dataloader.py
ADDED
|
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
Data loader factory and utilities for transformer models.
"""
import json
import logging
import os
import zlib
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset

from config import app_config
from datagrower.Crawl4MyAI import AdvancedWebCrawler
from datagrower.Webconverter import WebConverter
from dataset import DatasetManager
from tokenizer import TokenizerWrapper
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
class TransformerDataset(Dataset):
    """Base dataset for transformer models that handles multiple input formats.

    Records are loaded eagerly from a CSV, JSON, or plain-text file and
    tokenized lazily in ``__getitem__``.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: "TokenizerWrapper",
        max_length: int = 512,
        format_type: Optional[str] = None
    ):
        """
        Initialize dataset.

        Args:
            data_path: Path to the data file
            tokenizer: Tokenizer to use for encoding
            max_length: Maximum sequence length
            format_type: Format of data file ('csv', 'json', 'txt');
                auto-detected from the file extension when omitted.
        """
        self.data_path = data_path
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.format_type = format_type or self._detect_format(data_path)

        # Load all records up front; tokenization happens per item.
        self.data = self._load_data()
        logger.info(f"Loaded {len(self.data)} samples from {data_path}")

    def _detect_format(self, path: str) -> str:
        """Detect file format from extension, defaulting to CSV."""
        ext = os.path.splitext(path)[1].lower().lstrip('.')
        if ext in ['csv']:
            return 'csv'
        elif ext in ['json']:
            return 'json'
        elif ext in ['txt', 'text']:
            return 'txt'
        else:
            logger.warning(f"Unknown file extension: {ext}, defaulting to CSV")
            return 'csv'

    def _load_data(self) -> List[Dict[str, Any]]:
        """Load data based on format type.

        Raises:
            FileNotFoundError: If ``data_path`` does not exist.
            ValueError: If ``format_type`` is not supported.
        """
        if not os.path.exists(self.data_path):
            raise FileNotFoundError(f"Data file not found: {self.data_path}")

        try:
            if self.format_type == 'csv':
                return self._load_csv()
            elif self.format_type == 'json':
                return self._load_json()
            elif self.format_type == 'txt':
                return self._load_txt()
            else:
                raise ValueError(f"Unsupported format type: {self.format_type}")
        except Exception as e:
            logger.error(f"Error loading data from {self.data_path}: {e}")
            raise

    def _load_csv(self) -> List[Dict[str, Any]]:
        """Load data from CSV file, coercing columns to 'text'/'label'."""
        df = pd.read_csv(self.data_path)
        if 'text' not in df.columns:
            # Prefer any column whose header mentions text/content.
            text_cols = [col for col in df.columns if 'text' in col.lower() or 'content' in col.lower()]
            if text_cols:
                df = df.rename(columns={text_cols[0]: 'text'})
            else:
                # Fall back to treating the first column as text.
                df = df.rename(columns={df.columns[0]: 'text'})

        if 'label' not in df.columns and len(df.columns) > 1:
            # Use the second column as label if present.
            df = df.rename(columns={df.columns[1]: 'label'})

        return df.to_dict('records')

    def _load_json(self) -> List[Dict[str, Any]]:
        """Load data from JSON file, normalizing several container shapes."""
        with open(self.data_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        if isinstance(data, list):
            # Already in list format.
            return data
        elif isinstance(data, dict):
            # Known wrapper keys used by our data files.
            if 'data' in data:
                return data['data']
            elif 'examples' in data:
                return data['examples']
            elif 'user_inputs' in data:
                return data['user_inputs']
            else:
                # Flat mapping: one record per key/value pair.
                return [{'text': str(value), 'id': key} for key, value in data.items()]
        else:
            raise ValueError(f"Unsupported JSON data structure: {type(data)}")

    def _load_txt(self) -> List[Dict[str, Any]]:
        """Load data from text file, one sample per non-empty line."""
        with open(self.data_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        return [{'text': line.strip(), 'id': i} for i, line in enumerate(lines) if line.strip()]

    def __len__(self) -> int:
        """Get dataset length."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Tokenize and return the sample at ``idx``."""
        item = self.data[idx]
        text = item.get('text', '')

        if not text:
            text = " "  # Use space to avoid tokenizer errors

        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        # Squeeze away the batch dimension added by return_tensors="pt".
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        label = item.get('label', 0)
        if isinstance(label, str):
            try:
                label = float(label)
            except ValueError:
                # BUGFIX: hash() on strings is salted per process
                # (PYTHONHASHSEED), so categorical labels would map to
                # different ids on every run.  Use CRC32 for a stable id,
                # still limited to 100 categories.
                label = zlib.crc32(label.encode('utf-8')) % 100

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': torch.tensor(label, dtype=torch.long)
        }
|
| 171 |
+
|
| 172 |
+
def prepare_data_loaders_extended(
    data_path: Union[str, Dict[str, str]],
    tokenizer: Any,
    batch_size: int = 16,
    max_length: int = 512,
    val_split: float = 0.1,
    format_type: Optional[str] = None,
    num_workers: int = 0
) -> Dict[str, DataLoader]:
    """
    Create data loaders for training and validation.

    Args:
        data_path: Path to data file or dictionary mapping split to path
        tokenizer: Tokenizer to use for encoding
        batch_size: Batch size
        max_length: Maximum sequence length
        val_split: Validation split ratio when only one path is provided
        format_type: Format of data file
        num_workers: Number of workers for DataLoader

    Returns:
        Dictionary mapping split names to DataLoaders
    """
    def _make_dataset(path):
        # Single place to construct datasets so both branches stay in sync.
        return TransformerDataset(
            data_path=path,
            tokenizer=tokenizer,
            max_length=max_length,
            format_type=format_type,
        )

    def _make_loader(ds, shuffle):
        return DataLoader(
            ds,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
        )

    if isinstance(data_path, dict):
        # One file per split: shuffle only the training split.
        return {
            split_name: _make_loader(_make_dataset(path), shuffle=(split_name == 'train'))
            for split_name, path in data_path.items()
        }

    # Single path: derive a train/validation split from one dataset.
    dataset = _make_dataset(data_path)
    val_size = int(len(dataset) * val_split)
    train_size = len(dataset) - val_size

    if val_size <= 0:
        # Not enough samples for a validation split: train on everything.
        return {'train': _make_loader(dataset, shuffle=True)}

    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size]
    )
    return {
        'train': _make_loader(train_dataset, shuffle=True),
        'validation': _make_loader(val_dataset, shuffle=False),
    }
|
| 256 |
+
|
| 257 |
+
def prepare_data_loaders(
    data_path: str,
    tokenizer: Any,
    batch_size: int = 16,
    val_split: float = 0.1
) -> Tuple[DataLoader, Optional[DataLoader]]:
    """
    Simplified version that returns train and validation loaders directly.

    Args:
        data_path: Path to data file
        tokenizer: Tokenizer to use for encoding
        batch_size: Batch size
        val_split: Validation split ratio

    Returns:
        Tuple of (train_loader, val_loader); val_loader is None when the
        split produced no validation set.
    """
    loaders = prepare_data_loaders_extended(
        data_path=data_path,
        tokenizer=tokenizer,
        batch_size=batch_size,
        val_split=val_split,
    )
    # Missing splits simply come back as None.
    return loaders.get('train'), loaders.get('validation')
|
| 286 |
+
|
| 287 |
+
def load_dataset(
    specialization: str,
    tokenizer: Any = None,
    split: str = 'train'
) -> Dataset:
    """
    Load a dataset for a specific specialization.

    Args:
        specialization: Name of the specialization
        tokenizer: Tokenizer to use (optional; a default TokenizerWrapper
            is created when omitted)
        split: Dataset split to load (currently unused; kept for API
            compatibility)

    Returns:
        Dataset instance
    """
    # Resolve the dataset location from config, falling back to <spec>.csv.
    if hasattr(app_config, 'DATASET_PATHS') and specialization in app_config.DATASET_PATHS:
        data_path = app_config.DATASET_PATHS[specialization]
    else:
        data_path = os.path.join(app_config.BASE_DATA_DIR, f"{specialization}.csv")

    # BUGFIX: DATASET_PATHS values are lists of files; the previous code
    # called str methods on the list and crashed.  Use the first entry.
    if isinstance(data_path, (list, tuple)):
        if not data_path:
            raise ValueError(f"No dataset paths configured for {specialization!r}")
        data_path = data_path[0]

    # Get or create tokenizer.
    if tokenizer is None:
        from tokenizer import TokenizerWrapper
        tokenizer = TokenizerWrapper()

    max_length = app_config.TRANSFORMER_CONFIG.MAX_SEQ_LENGTH

    # Handle URL paths first via crawler + converter.
    if data_path.startswith(("http://", "https://")):
        crawler = AdvancedWebCrawler()
        converter = WebConverter(crawler=crawler)
        raw_entries = converter.get_converted_web_data([data_path])
        # BUGFIX: TransformerDataset has no _process_records method, so the
        # old code raised AttributeError here.  Build the dataset without
        # touching the filesystem and inject the crawled records directly
        # (assumed to be dicts with "text"/"label" keys -- TODO confirm
        # against WebConverter's output).
        dataset = TransformerDataset.__new__(TransformerDataset)
        dataset.data_path = data_path
        dataset.tokenizer = tokenizer
        dataset.max_length = max_length
        dataset.format_type = 'json'
        dataset.data = raw_entries
        return dataset

    # Create dataset from the local file.
    return TransformerDataset(
        data_path=data_path,
        tokenizer=tokenizer,
        max_length=max_length
    )
|
| 330 |
+
|
| 331 |
+
def load_for_specialization(spec: str):
    """Load the dataset(s) configured for *spec* via DatasetManager.

    Args:
        spec: Specialization key into ``DATASET_PATHS``.

    Returns:
        Whatever ``DatasetManager.load_dataset`` produces for the paths.
    """
    # BUGFIX: app_config is accessed with attribute lookups everywhere else
    # in this module (hasattr/attribute access in load_dataset), so calling
    # app_config.get(...) would raise AttributeError on a config object.
    paths = getattr(app_config, "DATASET_PATHS", {}).get(spec, [])
    # Normalize a single path string to a one-element list.
    if isinstance(paths, str):
        paths = [paths]
    manager = DatasetManager()
    return manager.load_dataset(paths, spec)

# Short alias for common use case
get_dataloader = prepare_data_loaders
|
dataset.py
ADDED
|
@@ -0,0 +1,435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# dataset.py
|
| 2 |
+
import os
|
| 3 |
+
import csv
|
| 4 |
+
import json
|
| 5 |
+
import torch
|
| 6 |
+
import logging
|
| 7 |
+
from preprocess import Preprocessor
|
| 8 |
+
from torch.utils.data import Dataset
|
| 9 |
+
from typing import List, Dict, Any, Optional, Union
|
| 10 |
+
from functools import wraps
|
| 11 |
+
from time import time
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
def safe_file_operation(func):
    """Decorator to safely handle file operations with timeout"""
    # NOTE(review): this decorator assumes it wraps *methods* of an object
    # exposing ``self.file_path`` -- the JSON/CSV handlers below log it.
    # The "timeout" is a post-hoc warning only; the call is never interrupted.
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        start_time = time()
        timeout_seconds = 300  # 5-minute timeout

        try:
            # Try to perform the operation
            result = func(self, *args, **kwargs)

            # Check if operation took too long (warn only, do not fail)
            if time() - start_time > timeout_seconds:
                logger.warning(f"File operation {func.__name__} took more than {timeout_seconds} seconds")

            return result
        except (IOError, OSError) as e:
            logger.error(f"File operation error in {func.__name__}: {str(e)}")
            # Return empty result based on function type: loaders degrade to
            # an empty record list, everything else re-raises.
            if func.__name__.startswith('_load_'):
                return []
            raise
        except json.JSONDecodeError as e:
            # Parse errors are swallowed into an empty result by design.
            logger.error(f"JSON decode error in {self.file_path}: {str(e)}")
            return []
        except csv.Error as e:
            logger.error(f"CSV error in {self.file_path}: {str(e)}")
            return []
        except Exception as e:
            # Anything unexpected is logged and propagated to the caller.
            logger.error(f"Unexpected error in {func.__name__}: {str(e)}")
            raise

    return wrapper
|
| 48 |
+
|
| 49 |
+
class TensorDataset(Dataset):
    """Dataset pairing aligned feature and label tensors.

    Indexing returns the ``(feature, label)`` pair at that position; the
    length is driven by the feature container.

    Args:
        features (Tensor): Feature tensors.
        labels (Tensor): Label tensors.
    """

    def __init__(self, features, labels):
        self.features, self.labels = features, labels

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]
|
| 67 |
+
|
| 68 |
+
class CustomDataset(Dataset):
|
| 69 |
+
"""A dataset that supports loading JSON, CSV, and TXT formats.
|
| 70 |
+
It auto-detects the file type (if not specified) and filters out any
|
| 71 |
+
records that are not dictionaries. If a preprocessor is provided, it
|
| 72 |
+
applies it to each record. Additionally, it can standardize sample keys
|
| 73 |
+
dynamically using a provided header mapping. For example, you can define a
|
| 74 |
+
mapping like:
|
| 75 |
+
mapping = {
|
| 76 |
+
"title": ["Title", "Headline", "Article Title"],
|
| 77 |
+
"content": ["Content", "Body", "Text"],
|
| 78 |
+
}
|
| 79 |
+
so that regardless of the CSV's header names your trainer always sees a
|
| 80 |
+
standardized set of keys."""
|
| 81 |
+
    def __init__(
        self,
        file_path: Optional[str] = None,
        tokenizer = None,
        max_length: Optional[int] = None,
        file_format: Optional[str] = None,
        preprocessor: Optional[Preprocessor] = None,
        header_mapping: Optional[Dict[str, List[str]]] = None,
        data: Optional[List[Dict[str, Any]]] = None,  # Add data parameter
        specialization: Optional[str] = None  # Add specialization parameter
    ):
        """Initialize the dataset from a file or from in-memory records.

        Args:
            file_path (Optional[str]): Path to the dataset file.
            tokenizer: Tokenizer instance to process the text.
            max_length (Optional[int]): Maximum sequence length.
            file_format (Optional[str]): Format of the file ('json', 'csv',
                'txt'); inferred from the extension if not provided.
            preprocessor (Optional[Preprocessor]): Preprocessor applied to
                each sample after loading.
            header_mapping (Optional[Dict[str, List[str]]]): Mapping from
                standardized keys to possible source header names.
                NOTE(review): stored here but not applied in this method.
            data (Optional[List[Dict[str, Any]]]): Direct data input instead
                of loading from file; takes precedence over ``file_path``.
            specialization (Optional[str]): Specialization tag for the dataset.
        """

        self.file_path = file_path
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.preprocessor = preprocessor
        self.header_mapping = header_mapping
        self.specialization = specialization  # Store the specialization

        # Initialize samples either from data or file
        if data is not None:
            self.samples = data
        else:
            # Determine the file format if not specified and file_path is provided
            if file_path is not None:
                if file_format is None:
                    _, ext = os.path.splitext(file_path)
                    ext = ext.lower()
                    if ext in ['.json']:
                        file_format = 'json'
                    elif ext in ['.csv']:
                        file_format = 'csv'
                    elif ext in ['.txt']:
                        file_format = 'txt'
                    else:
                        logger.error(f"Unsupported file extension: {ext}")
                        raise ValueError(f"Unsupported file extension: {ext}")

                # NOTE(review): self.file_format is only assigned on this
                # branch; instances built from ``data`` never define it.
                self.file_format = file_format
                self.samples = self._load_file()
            else:
                self.samples = []

        # Auto-detection: Ensure all loaded samples are dicts.
        initial_sample_count = len(self.samples)
        self.samples = [sample for sample in self.samples if isinstance(sample, dict)]
        if len(self.samples) < initial_sample_count:
            logger.warning(f"Filtered out {initial_sample_count - len(self.samples)} samples that were not dicts.")

        # If a preprocessor is provided, apply preprocessing to each record.
        # Records that fail preprocessing are logged and dropped.
        if self.preprocessor:
            preprocessed_samples = []
            for sample in self.samples:
                try:
                    processed = self.preprocessor.preprocess_record(sample)
                    preprocessed_samples.append(processed)
                except Exception as e:
                    logger.error(f"Error preprocessing record {sample}: {e}")
            self.samples = preprocessed_samples
|
| 149 |
+
|
| 150 |
+
def _load_file(self) -> List[Dict[str, Any]]:
|
| 151 |
+
try:
|
| 152 |
+
if self.file_format == 'json':
|
| 153 |
+
return self._load_json()
|
| 154 |
+
elif self.file_format == 'csv':
|
| 155 |
+
return self._load_csv()
|
| 156 |
+
elif self.file_format == 'txt':
|
| 157 |
+
return self._load_txt()
|
| 158 |
+
else:
|
| 159 |
+
logger.error(f"Unrecognized file format: {self.file_format}")
|
| 160 |
+
raise ValueError(f"Unrecognized file format: {self.file_format}")
|
| 161 |
+
except Exception as e:
|
| 162 |
+
logger.error(f"Error loading file {self.file_path}: {e}")
|
| 163 |
+
raise
|
| 164 |
+
|
| 165 |
+
    @safe_file_operation
    def _load_json(self) -> List[Dict[str, Any]]:
        """Load JSON file with better error handling and validation"""
        try:
            with open(self.file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            # Validate data structure
            if isinstance(data, list):
                # Keep only dict records; warn about anything filtered out.
                valid_records = [record for record in data if isinstance(record, dict)]
                if len(valid_records) < len(data):
                    logger.warning(f"{len(data) - len(valid_records)} records were not dictionaries in {self.file_path}")
                return valid_records
            elif isinstance(data, dict):
                # Handle single record case
                logger.warning(f"JSON file contains a single dictionary, not a list: {self.file_path}")
                return [data]
            else:
                logger.error(f"JSON file does not contain a list or dictionary: {self.file_path}")
                return []
        except json.JSONDecodeError as e:
            line_col = f"line {e.lineno}, column {e.colno}"
            logger.error(f"JSON decode error at {line_col} in {self.file_path}: {e.msg}")
            # Try to recover partial content
            try:
                with open(self.file_path, 'r', encoding='utf-8') as f:
                    content = f.read()
                # Try parsing up to the error
                valid_part = content[:e.pos]
                import re
                # Find complete objects (rough approach)
                # NOTE(review): this pattern only matches objects without
                # nested braces, so records containing sub-objects are lost.
                matches = re.findall(r'\{[^{}]*\}', valid_part)
                if matches:
                    logger.info(f"Recovered {len(matches)} complete records from {self.file_path}")
                    parsed_records = []
                    for match in matches:
                        try:
                            parsed_records.append(json.loads(match))
                        except:
                            # Skip fragments that are not valid JSON alone.
                            pass
                    return parsed_records
            except:
                # Recovery is best-effort only; fall through to empty result.
                pass
            return []
|
| 209 |
+
|
| 210 |
+
@safe_file_operation
def _load_csv(self) -> List[Dict[str, Any]]:
    """Load CSV with better error handling.

    Sniffs the CSV dialect from the first 1 KB, falling back to the
    'excel' dialect when detection fails; non-dict rows are skipped with
    a warning and an empty result is logged.
    """
    samples = []
    try:
        with open(self.file_path, 'r', encoding='utf-8') as csvfile:
            # Try detecting dialect first
            try:
                dialect = csv.Sniffer().sniff(csvfile.read(1024))
                csvfile.seek(0)
                reader = csv.DictReader(csvfile, dialect=dialect)
            except csv.Error:
                # BUG FIX: was a bare ``except:`` that also swallowed
                # KeyboardInterrupt; Sniffer.sniff raises csv.Error when it
                # cannot determine the dialect, so catch exactly that.
                # Fall back to excel dialect
                csvfile.seek(0)
                reader = csv.DictReader(csvfile, dialect='excel')

            for i, row in enumerate(reader):
                if not isinstance(row, dict):
                    logger.warning(f"Row {i} is not a dict: {row} -- skipping.")
                    continue
                samples.append(row)

            if not samples:
                logger.warning(f"No valid rows found in CSV file: {self.file_path}")

    except csv.Error as e:
        logger.error(f"Error reading CSV file {self.file_path}: {e}")
    return samples
|
| 238 |
+
|
| 239 |
+
def _load_txt(self) -> List[Dict[str, Any]]:
    """Load a plain-text file, one record per non-empty line.

    Each line is stripped and wrapped as ``{"text": line}`` so downstream
    code can treat txt input like the JSON/CSV record formats.
    """
    samples = []
    with open(self.file_path, 'r', encoding='utf-8') as txtfile:
        # FIX: the previous version enumerated lines but never used the index.
        for line in txtfile:
            line = line.strip()
            if line:
                # Wrap each line in a dictionary.
                samples.append({"text": line})
    return samples
|
| 248 |
+
|
| 249 |
+
def _standardize_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
|
| 250 |
+
"""Remaps the sample's keys to a set of standardized keys using self.header_mapping.
|
| 251 |
+
For each standardized key, the first matching header from the sample is used.
|
| 252 |
+
If none is found, a default empty string is assigned."""
|
| 253 |
+
standardized = {}
|
| 254 |
+
for std_field, possible_keys in self.header_mapping.items():
|
| 255 |
+
for key in possible_keys:
|
| 256 |
+
if key in sample:
|
| 257 |
+
standardized[std_field] = sample[key]
|
| 258 |
+
break
|
| 259 |
+
if std_field not in standardized:
|
| 260 |
+
standardized[std_field] = ""
|
| 261 |
+
return standardized
|
| 262 |
+
|
| 263 |
+
def __len__(self) -> int:
    """Return the number of loaded samples (len() protocol, used by DataLoader)."""
    return len(self.samples)
|
| 265 |
+
|
| 266 |
+
def __getitem__(self, index: int) -> Dict[str, Any]:
    """Return one tokenized training example for *index*.

    Builds the text from standardized title/content fields (or a "text"
    field, or all values joined as a last resort), tokenizes it to a fixed
    length, and returns input tensors plus optional metadata.
    """
    sample = self.samples[index]

    # If a header mapping is provided, standardize the sample keys.
    if self.header_mapping is not None:
        sample = self._standardize_sample(sample)

    # Determine the text to tokenize:
    # If standardized keys "title" or "content" exist, combine them.
    if 'title' in sample or 'content' in sample:
        title = sample.get('title', '')
        content = sample.get('content', '')
        # Convert non-string fields to strings
        if not isinstance(title, str):
            title = str(title)
        if not isinstance(content, str):
            content = str(content)
        text = (title + " " + content).strip()
    elif "text" in sample:
        text = sample["text"] if isinstance(sample["text"], str) else str(sample["text"])
    else:
        # Fallback: join all values (cast to str)
        text = " ".join(str(v) for v in sample.values())

    # Tokenize the combined text.
    # NOTE(review): assumes self.tokenizer follows the HuggingFace
    # ``encode_plus`` API and self.max_length is set — confirm at call sites.
    tokenized = self.tokenizer.encode_plus(
        text,
        max_length=self.max_length,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )

    # Get specialization from sample or use class default
    specialization = None
    if isinstance(sample, dict) and "specialization" in sample:
        specialization = sample["specialization"]
    elif self.specialization:
        specialization = self.specialization

    # Return a standardized dictionary for training.
    # token_type_ids falls back to a zero tensor when the tokenizer does
    # not emit them (e.g. single-segment tokenizers).
    result = {
        "input_ids": tokenized["input_ids"].squeeze(0),
        "attention_mask": tokenized["attention_mask"].squeeze(0),
        "token_type_ids": tokenized.get("token_type_ids", torch.zeros_like(tokenized["input_ids"])).squeeze(0),
    }

    # Add specialization if available
    if specialization:
        result["specialization"] = specialization

    # Optionally include standardized text fields if needed
    # NOTE(review): ``'title' in locals()`` is a fragile way to detect that
    # the title/content branch ran — an explicit flag would be clearer.
    if 'title' in locals():
        result["title"] = title
    if 'content' in locals():
        result["content"] = content

    return result
|
| 324 |
+
|
| 325 |
+
# dataset.py - Simple dataset module to fix initialization dependency issues
import logging
import os
import json
from typing import Dict, List, Any, Optional, Union

logger = logging.getLogger(__name__)

class DatasetManager:
    """
    Simple dataset manager to provide basic functionality for model_manager
    without requiring external dataset dependencies.

    Datasets are JSON files (a list of dict records) stored under
    ``data_dir`` and cached in memory after the first load.
    """
    def __init__(self, data_dir: Optional[str] = None):
        # Default to a ./data directory next to this file.
        self.data_dir = data_dir or os.path.join(os.path.dirname(__file__), "data")
        # In-memory cache: dataset name -> list of records.
        self.datasets: Dict[str, List[Dict[str, Any]]] = {}
        self._ensure_data_dir()

    def _ensure_data_dir(self):
        """Ensure data directory exists, falling back to /tmp when unwritable."""
        try:
            if not os.path.exists(self.data_dir):
                os.makedirs(self.data_dir, exist_ok=True)
                logger.info(f"Created dataset directory at {self.data_dir}")
        except (PermissionError, OSError) as e:
            logger.warning(f"Could not create data directory: {e}")
            # Fall back to temp directory
            self.data_dir = os.path.join("/tmp", "wildnerve_data")
            os.makedirs(self.data_dir, exist_ok=True)
            logger.info(f"Using fallback data directory at {self.data_dir}")

    def load_dataset(self, name: str) -> List[Dict[str, Any]]:
        """Load dataset by name, serving from the in-memory cache when possible.

        Returns an empty list when the dataset file is missing or unreadable.
        """
        if name in self.datasets:
            return self.datasets[name]

        # Check for dataset file
        filepath = os.path.join(self.data_dir, f"{name}.json")
        if os.path.exists(filepath):
            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                self.datasets[name] = data
                return data
            except (OSError, json.JSONDecodeError) as e:
                # BUG FIX: narrowed from ``except Exception`` to the errors
                # file loading can actually raise, so real bugs surface.
                logger.error(f"Error loading dataset {name}: {e}")

        # Return empty dataset if not found
        logger.warning(f"Dataset {name} not found, returning empty dataset")
        return []

    def get_dataset_names(self) -> List[str]:
        """Get list of available datasets (stems of *.json files in data_dir)."""
        try:
            # BUG FIX: ``f.split('.')[0]`` truncated names containing dots
            # (e.g. "corpus.v1.json" -> "corpus"); use os.path.splitext.
            return [os.path.splitext(f)[0] for f in os.listdir(self.data_dir)
                    if f.endswith('.json')]
        except OSError as e:
            logger.error(f"Error listing datasets: {e}")
            return []

    def create_sample_dataset(self, name: str, samples: int = 10) -> List[Dict[str, Any]]:
        """Create a sample dataset for testing, persist it, and return it."""
        data = [
            {
                "id": i,
                "text": f"Sample text {i} for model training",
                "label": i % 2  # Binary label
            }
            for i in range(samples)
        ]

        # Save to file
        filepath = os.path.join(self.data_dir, f"{name}.json")
        try:
            with open(filepath, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2)
            self.datasets[name] = data
            logger.info(f"Created sample dataset {name} with {samples} samples")
        except OSError as e:
            logger.error(f"Error creating sample dataset: {e}")

        return data

    def _load_and_process_dataset(self, path_or_paths: Union[str, List[str]], specialization: str) -> "TensorDataset":
        # BUG FIX: the return annotation was an unquoted ``TensorDataset``,
        # a name never imported in this module, so it raised NameError at
        # class-definition time and prevented the module from importing at
        # all; quoting it defers evaluation.
        # NOTE(review): this method is an unfinished placeholder (the
        # "existing code" markers below) and currently returns None —
        # confirm intended behavior before relying on it.
        # …existing code up to reading the file…
        import pandas as pd

        # Handle multiple JSON files by concatenation
        if isinstance(path_or_paths, list):
            frames = [pd.read_json(p) for p in path_or_paths]
            data = pd.concat(frames, ignore_index=True)
        else:
            data = pd.read_json(path_or_paths)

        # …existing code that splits into features/labels and returns TensorDataset…
| 420 |
+
|
| 421 |
+
# Create a default dataset manager instance
# NOTE: instantiating at import time touches the filesystem (creates the
# data directory, with a /tmp fallback) — see DatasetManager._ensure_data_dir.
dataset_manager = DatasetManager()

def get_dataset(name: str) -> List[Dict[str, Any]]:
    """Helper function to get a dataset by name.

    Delegates to the module-level ``dataset_manager`` singleton; returns an
    empty list when the dataset does not exist.
    """
    return dataset_manager.load_dataset(name)

# Create some minimal sample data if running as main
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    dm = DatasetManager()
    dm.create_sample_dataset("test_dataset", samples=20)
    print(f"Available datasets: {dm.get_dataset_names()}")
    test_data = dm.load_dataset("test_dataset")
    print(f"Loaded {len(test_data)} samples from test_dataset")
|
dependency_helpers.py
CHANGED
|
@@ -1,118 +1,124 @@
|
|
| 1 |
"""
|
| 2 |
-
|
| 3 |
-
This
|
| 4 |
"""
|
| 5 |
-
import importlib
|
| 6 |
-
import logging
|
| 7 |
-
import sys
|
| 8 |
import os
|
| 9 |
-
|
|
|
|
|
|
|
| 10 |
|
| 11 |
logger = logging.getLogger(__name__)
|
| 12 |
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
Safely import a module without crashing if it's not available.
|
| 16 |
-
|
| 17 |
-
Args:
|
| 18 |
-
module_name: Name of the module to import
|
| 19 |
-
|
| 20 |
-
Returns:
|
| 21 |
-
The imported module or None if import failed
|
| 22 |
-
"""
|
| 23 |
-
try:
|
| 24 |
-
return importlib.import_module(module_name)
|
| 25 |
-
except ImportError as e:
|
| 26 |
-
logger.warning(f"Failed to import {module_name}: {e}")
|
| 27 |
-
return None
|
| 28 |
|
| 29 |
def is_module_available(module_name: str) -> bool:
|
| 30 |
-
"""
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
Args:
|
| 34 |
-
module_name: Name of the module to check
|
| 35 |
-
|
| 36 |
-
Returns:
|
| 37 |
-
True if module is available, False otherwise
|
| 38 |
-
"""
|
| 39 |
-
try:
|
| 40 |
-
importlib.util.find_spec(module_name)
|
| 41 |
-
return True
|
| 42 |
-
except ImportError:
|
| 43 |
-
return False
|
| 44 |
|
| 45 |
-
def
|
| 46 |
-
"""
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
-
|
| 50 |
-
dependencies: List of module names to check
|
| 51 |
-
|
| 52 |
-
Returns:
|
| 53 |
-
Dictionary mapping module names to availability (True/False)
|
| 54 |
-
"""
|
| 55 |
-
return {dep: is_module_available(dep) for dep in dependencies}
|
| 56 |
|
| 57 |
-
def
|
| 58 |
-
"""
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
|
|
|
| 72 |
|
| 73 |
-
def
|
| 74 |
-
"""
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
except Exception as e:
|
| 89 |
-
logger.warning(f"Primary function {primary_func.__name__} failed: {e}")
|
| 90 |
-
return fallback_func(*args, **kwargs)
|
| 91 |
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
|
| 97 |
-
|
| 98 |
-
|
|
|
|
|
|
|
| 99 |
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
try:
|
| 104 |
-
import subprocess
|
| 105 |
-
logger.info(f"Attempting to install {package_name}")
|
| 106 |
-
subprocess.check_call([sys.executable, "-m", "pip", "install", package_name])
|
| 107 |
-
return True
|
| 108 |
-
except Exception as e:
|
| 109 |
-
logger.warning(f"Failed to install {package_name}: {e}")
|
| 110 |
-
return False
|
| 111 |
|
| 112 |
-
#
|
| 113 |
-
|
| 114 |
-
DEPENDENCY_STATUS = check_dependencies(CRITICAL_DEPENDENCIES)
|
| 115 |
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
|
|
|
| 1 |
"""
Dependency helpers to make the model work even if some libraries are missing.
This file provides fallback implementations for missing dependencies.
"""
import os
import logging
import importlib.util
from typing import Any, Dict, Optional, Type, Callable

logger = logging.getLogger(__name__)

# Dictionary to track mock implementations.
# Maps a module name (e.g. "codecarbon") to the mock object that the
# DependencyImportFinder meta-path hook serves instead of the real module.
MOCK_MODULES = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
def is_module_available(module_name: str) -> bool:
    """Return True when *module_name* is importable, without importing it."""
    spec = importlib.util.find_spec(module_name)
    return spec is not None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
+
def create_mock_emissions_tracker() -> Type:
    """Build a drop-in stand-in for codecarbon's EmissionsTracker.

    The returned class supports both the context-manager protocol and the
    explicit start()/stop() API; stop() always reports zero emissions.
    """
    class MockEmissionsTracker:
        def __init__(self, *args, **kwargs):
            logger.info("Using mock EmissionsTracker")

        def start(self):
            return self

        def stop(self):
            # Nothing was tracked, so report zero emissions.
            return 0.0

        # Entering the context behaves exactly like start().
        __enter__ = start

        def __exit__(self, *exc_info):
            return None

    return MockEmissionsTracker
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
+
def create_mock_pydantic_classes() -> Dict[str, Type]:
    """Create mock implementations of pydantic classes.

    Returns a mapping providing stand-ins for BaseModel, Field,
    ValidationError and ConfigDict with the minimal API the codebase uses.
    """
    class MockBaseModel:
        """Mock implementation of pydantic's BaseModel"""

        def __init__(self, **kwargs):
            self.__dict__.update(kwargs)

        def dict(self) -> Dict[str, Any]:
            return {name: value
                    for name, value in self.__dict__.items()
                    if not name.startswith('_')}

        def json(self) -> str:
            import json
            return json.dumps(self.dict())

    class MockValidationError(Exception):
        """Mock implementation of pydantic's ValidationError"""

    def mock_field(*args, **kwargs) -> Any:
        """Mock implementation of pydantic's Field"""
        return kwargs.get('default', None)

    return {
        "BaseModel": MockBaseModel,
        "Field": mock_field,
        "ValidationError": MockValidationError,
        # ConfigDict degrades gracefully to a plain dict.
        "ConfigDict": dict,
    }
|
| 71 |
|
| 72 |
+
def setup_dependency_fallbacks():
    """Setup fallbacks for all required dependencies.

    Populates MOCK_MODULES with stand-in module objects for any missing
    third-party package so the DependencyImportFinder hook can serve them.
    """
    # Handle codecarbon
    if not is_module_available("codecarbon"):
        logger.warning("codecarbon not found, using mock implementation")
        MOCK_MODULES["codecarbon"] = type("MockCodecarbon", (), {
            "EmissionsTracker": create_mock_emissions_tracker()
        })

    # Handle pydantic
    if not is_module_available("pydantic"):
        logger.warning("pydantic not found, using mock implementation")
        mock_classes = create_mock_pydantic_classes()
        MOCK_MODULES["pydantic"] = type("MockPydantic", (), mock_classes)

    # Setup service_registry fallback if needed
    if not is_module_available("service_registry"):
        from types import SimpleNamespace
        registry_obj = SimpleNamespace()
        registry_obj.register = lambda *args, **kwargs: None
        registry_obj.get = lambda *args: None
        registry_obj.has = lambda *args: False

        MOCK_MODULES["service_registry"] = type("MockServiceRegistry", (), {
            "registry": registry_obj,
            # BUG FIX: the real service_registry module defines
            # MODEL = "model" and TOKENIZER = "tokenizer" (lowercase);
            # the mock previously used "MODEL"/"TOKENIZER", so any code
            # keying services on these constants behaved differently
            # under the fallback than with the real module.
            "MODEL": "model",
            "TOKENIZER": "tokenizer"
        })
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
+
# Custom import hook to provide mock implementations
|
| 102 |
+
class DependencyImportFinder:
    """Meta-path finder serving mock modules registered in MOCK_MODULES.

    BUG FIX: the original implemented only the legacy find_module/
    load_module protocol, which was removed from the import system in
    Python 3.12 — the hook silently stopped working there. The modern
    find_spec/create_module/exec_module protocol is added; the legacy
    pair is kept for older interpreters.
    """

    def __init__(self):
        self._mock_modules = MOCK_MODULES

    # --- modern protocol (required on Python 3.12+) ---
    def find_spec(self, fullname, path=None, target=None):
        if fullname not in self._mock_modules:
            return None
        import importlib.util
        return importlib.util.spec_from_loader(fullname, self)

    def create_module(self, spec):
        # Serve the pre-built mock object instead of a fresh module.
        # NOTE(review): mocks are plain classes, not ModuleType — the import
        # machinery tolerates this, but confirm attribute access patterns.
        return self._mock_modules[spec.name]

    def exec_module(self, module):
        # Mock objects are fully initialized already; nothing to execute.
        pass

    # --- legacy protocol (Python < 3.12) ---
    def find_module(self, fullname, path=None):
        if fullname in self._mock_modules:
            return self

    def load_module(self, fullname):
        import sys
        if fullname in sys.modules:
            return sys.modules[fullname]

        module = self._mock_modules[fullname]
        sys.modules[fullname] = module
        return module
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
+
# Initialize the fallbacks
# (populates MOCK_MODULES only for packages that are actually missing,
# so the hook below never shadows a real installed module)
setup_dependency_fallbacks()

# Install the custom import hook
# Inserted at the front of sys.meta_path so mock lookups win; harmless for
# any module not present in MOCK_MODULES (find_module returns None).
import sys
sys.meta_path.insert(0, DependencyImportFinder())
|
handler.py
CHANGED
|
@@ -17,13 +17,33 @@ logging.basicConfig(
|
|
| 17 |
)
|
| 18 |
logger = logging.getLogger(__name__)
|
| 19 |
|
| 20 |
-
#
|
| 21 |
try:
|
| 22 |
import pydantic
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
# Make sure adapter_layer.py is properly located
|
| 29 |
try:
|
|
@@ -43,7 +63,14 @@ try:
|
|
| 43 |
|
| 44 |
except ImportError as e:
|
| 45 |
logger.error(f"Could not import adapter_layer: {e}")
|
| 46 |
-
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
class EndpointHandler:
|
| 49 |
def __init__(self, path=""):
|
|
|
|
| 17 |
)
|
| 18 |
logger = logging.getLogger(__name__)
|
| 19 |
|
| 20 |
+
# Safely check for required packages without crashing
try:
    import pydantic
    print(f"pydantic is available: {pydantic.__version__}")
except ImportError:
    print("pydantic is not available - continuing without it")
    # Create minimal compatibility layer mimicking the pydantic module surface
    class pydantic:
        # BUG FIX: ``__version__`` was a zero-argument staticmethod, but the
        # real pydantic module exposes it as a plain string attribute — code
        # logging ``pydantic.__version__`` would print a function object
        # under the fallback. A string attribute keeps both paths consistent.
        __version__ = "unavailable"

        class BaseModel:
            """Minimal stand-in: stores keyword arguments as attributes."""
            def __init__(self, **kwargs):
                for k, v in kwargs.items():
                    setattr(self, k, v)
|
| 36 |
+
|
| 37 |
+
try:
    from codecarbon import EmissionsTracker
    print(f"codecarbon is available")
except ImportError:
    print("codecarbon is not available - continuing without carbon tracking")
    # Create minimal compatibility class
    class EmissionsTracker:
        """No-op tracker: start() is chainable, stop() reports zero emissions."""

        def __init__(self, *args, **kwargs):
            pass

        def start(self):
            return self

        def stop(self):
            return 0.0
|
| 47 |
|
| 48 |
# Make sure adapter_layer.py is properly located
|
| 49 |
try:
|
|
|
|
| 63 |
|
| 64 |
except ImportError as e:
|
| 65 |
logger.error(f"Could not import adapter_layer: {e}")
|
| 66 |
+
# Don't raise error - provide fallback adapter implementation
|
| 67 |
+
class WildnerveModelAdapter:
|
| 68 |
+
def __init__(self, path=""):
|
| 69 |
+
self.path = path
|
| 70 |
+
logger.info(f"Using fallback WildnerveModelAdapter with path: {path}")
|
| 71 |
+
|
| 72 |
+
def generate(self, text_input, **kwargs):
|
| 73 |
+
return f"Model adapter unavailable. Received input: {text_input[:30]}..."
|
| 74 |
|
| 75 |
class EndpointHandler:
|
| 76 |
def __init__(self, path=""):
|
model_Custm.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
# model_Custm.py
|
| 2 |
import os
|
| 3 |
import sys
|
| 4 |
import math
|
|
@@ -8,14 +8,15 @@ import numpy as np
|
|
| 8 |
import torch.nn as nn
|
| 9 |
from typing import Optional, List, Dict, Union
|
| 10 |
|
|
|
|
|
|
|
|
|
|
| 11 |
# Import the carbon tracker early - before transformers
|
| 12 |
try:
|
| 13 |
from codecarbon import EmissionsTracker
|
| 14 |
except ImportError:
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
def start(self): return self
|
| 18 |
-
def stop(self): return 0.0
|
| 19 |
|
| 20 |
# Apply patches before importing transformers
|
| 21 |
import transformer_patches
|
|
|
|
| 1 |
+
# model_Custm.py - with dependency fallbacks
|
| 2 |
import os
|
| 3 |
import sys
|
| 4 |
import math
|
|
|
|
| 8 |
import torch.nn as nn
|
| 9 |
from typing import Optional, List, Dict, Union
|
| 10 |
|
| 11 |
+
# Import dependency helpers first
# (importing it installs the mock-module meta-path hook as a side effect)
import dependency_helpers

# Import the carbon tracker early - before transformers
try:
    from codecarbon import EmissionsTracker
except ImportError:
    # Use the mock from dependency_helpers
    # (a no-op class: start() chains, stop() reports 0.0 emissions)
    EmissionsTracker = dependency_helpers.create_mock_emissions_tracker()
|
|
|
|
|
|
|
| 20 |
|
| 21 |
# Apply patches before importing transformers
|
| 22 |
import transformer_patches
|
service_registry.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
Simple service registry for dependency injection
"""
import logging
from typing import Any, Dict, Optional

logger = logging.getLogger(__name__)

# Constants used as keys
MODEL = "model"
TOKENIZER = "tokenizer"

class ServiceRegistry:
    """A simple service registry for dependency management.

    Services live in a plain dict keyed by string; lookups for unknown
    keys log a warning and return None rather than raising.
    """

    def __init__(self):
        # key -> registered service instance
        self._services: Dict[str, Any] = {}

    def register(self, key: str, service: Any, overwrite: bool = False) -> None:
        """Register *service* under *key*; keep the existing entry unless overwrite=True."""
        if not overwrite and key in self._services:
            logger.warning(f"Service with key '{key}' already registered")
            return
        self._services[key] = service
        logger.debug(f"Registered service with key: {key}")

    def get(self, key: str) -> Optional[Any]:
        """Return the service stored under *key*, or None when absent."""
        try:
            return self._services[key]
        except KeyError:
            logger.warning(f"No service registered with key: {key}")
            return None

    def has(self, key: str) -> bool:
        """True when a service is registered under *key*."""
        return key in self._services

    def clear(self) -> None:
        """Drop every registered service."""
        self._services.clear()

# Create singleton instance
registry = ServiceRegistry()
|
transformer_patches.py
CHANGED
|
@@ -211,3 +211,26 @@ if __name__ == "__main__":
|
|
| 211 |
print("\nPatch status:")
|
| 212 |
for patch, status in _patch_status.items():
|
| 213 |
print(f" {'✓' if status else '✗'} {patch}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
print("\nPatch status:")
|
| 212 |
for patch, status in _patch_status.items():
|
| 213 |
print(f" {'✓' if status else '✗'} {patch}")
|
| 214 |
+
|
| 215 |
+
"""
Transformer patches to make the model work better with HuggingFace transformers.
This file applies monkey patches to fix compatibility issues or add functionality.
"""
import logging
from typing import Dict, Any, Optional

logger = logging.getLogger(__name__)

def apply_transformer_patches():
    """Apply monkey patches to the transformers library when it is importable.

    Silently skips (with a warning) when transformers is not installed.
    """
    try:
        import transformers
    except ImportError:
        logger.warning("Transformers library not found, skipping patches")
        return
    logger.info(f"Applying patches to transformers v{transformers.__version__}")
    # No patches needed currently, but you can add them here if needed in future

# Apply patches when imported
apply_transformer_patches()
|