Upload 8 files
Browse files- adapter_layer.py +105 -58
- config.json +12 -1
- config.py +29 -6
- model_Custm.py +24 -14
- model_List.py +212 -180
- model_PrTr.py +104 -25
- service_registry.py +3 -3
adapter_layer.py
CHANGED
|
@@ -208,65 +208,112 @@ class WildnerveModelAdapter:
|
|
| 208 |
gen_list.append(s)
|
| 209 |
return " ".join(tech_list).strip(), " ".join(gen_list).strip()
|
| 210 |
|
| 211 |
-
def generate(self,
|
| 212 |
-
"""Generate
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
# Set appropriate max_length to prevent length errors
|
| 217 |
-
if 'max_length' in kwargs and isinstance(kwargs['max_length'], int):
|
| 218 |
-
if kwargs['max_length'] < 512: # If max_length is too small
|
| 219 |
-
kwargs['max_length'] = 512 # Use a reasonable default
|
| 220 |
-
else:
|
| 221 |
-
kwargs['max_length'] = 1024 # Set a default if not provided
|
| 222 |
-
|
| 223 |
-
# Try using the pretrained GPT-2 model first for generation
|
| 224 |
-
pre = registry.get(PRETRAINED_MODEL)
|
| 225 |
-
if pre:
|
| 226 |
try:
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
import inspect
|
| 232 |
-
sig = inspect.signature(pre.generate)
|
| 233 |
-
if "prompt" in sig.parameters:
|
| 234 |
-
return pre.generate(prompt=prompt, **kwargs)
|
| 235 |
-
else:
|
| 236 |
-
# If no prompt parameter, try tokenizing first
|
| 237 |
-
inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, padding=True)
|
| 238 |
-
return pre.generate(input_ids=inputs.input_ids, **kwargs) # Explicitly pass as input_ids
|
| 239 |
-
else:
|
| 240 |
-
logger.warning("Pretrained model doesn't have generate method")
|
| 241 |
except Exception as e:
|
| 242 |
-
logger.error(f"Error using
|
| 243 |
-
|
| 244 |
-
# Fall back to using the custom model if needed
|
| 245 |
-
if self.model:
|
| 246 |
-
try:
|
| 247 |
-
logger.info("Using custom model for generation")
|
| 248 |
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
gen_list.append(s)
|
| 209 |
return " ".join(tech_list).strip(), " ".join(gen_list).strip()
|
| 210 |
|
| 211 |
+
def generate(self, text_input, max_length=None, **kwargs):
    """Generate text using the model - centralized generation point.

    Chooses a backend via PromptAnalyzer, assembles the generation
    parameters, then tries the selected registry model and finally
    ``self.model``. Exceptions are caught and converted into a message
    string, so callers always receive a return value.

    Args:
        text_input: Prompt text to generate from.
        max_length: Forwarded as ``max_length``; 150 is used when falsy.
        **kwargs: Overrides for the generation parameters. ``prompt`` and
            ``context`` entries are dropped; every other entry is
            forwarded to the backend's ``generate``.

    Returns:
        Whatever the chosen backend's ``generate`` returns, or an
        apology/error string when no backend produced output.
    """
    try:
        model_type = self._select_model_type(text_input)

        # Enhanced generation parameters with strong repetition prevention
        generation_kwargs = {
            'max_length': max_length or 150,
            'temperature': kwargs.get('temperature', 0.7),
            'top_p': kwargs.get('top_p', 0.95),
            'top_k': kwargs.get('top_k', 50),
            'repetition_penalty': kwargs.get('repetition_penalty', 1.3),
            'no_repeat_ngram_size': kwargs.get('no_repeat_ngram_size', 3),
            'do_sample': kwargs.get('do_sample', True),
            'num_return_sequences': kwargs.get('num_return_sequences', 1),
            'early_stopping': kwargs.get('early_stopping', True),
            'bad_words_ids': kwargs.get('bad_words_ids', None),
            'min_length': kwargs.get('min_length', 10),
        }

        # NOTE(review): in Hugging Face ``generate`` penalty_alpha belongs to
        # contrastive search; it is presumably inert while do_sample=True --
        # confirm before relying on it to prevent looping.
        if 'penalty_alpha' not in kwargs:
            generation_kwargs['penalty_alpha'] = 0.6

        # Explicit caller kwargs win; 'prompt'/'context' are not generation
        # parameters and must not be forwarded twice.
        generation_kwargs.update(
            {k: v for k, v in kwargs.items() if k not in ('prompt', 'context')}
        )

        if model_type == "model_Custm":
            # Use the Custom Wildnerve model for technical topics
            custom_model = registry.get(MODEL)
            if custom_model:
                try:
                    logger.info("Using custom Wildnerve-tlm01_Hybrid_Model for technical prompt")
                    if hasattr(custom_model, "generate"):
                        return self._invoke_generate(custom_model, text_input, generation_kwargs)
                    # NOTE(review): the next step is actually self.model,
                    # not the pretrained model, despite the message text.
                    logger.warning("Custom model doesn't have generate method, falling back to pretrained")
                except Exception as e:
                    logger.error(f"Error using custom model: {e}")
        else:
            # Use the Pretrained model (GPT-2) for general topics
            pre = registry.get(PRETRAINED_MODEL)
            if pre:
                try:
                    logger.info("Using GPT-2 pretrained model for general prompt")
                    if hasattr(pre, "generate"):
                        return self._invoke_generate(pre, text_input, generation_kwargs)
                    logger.warning("Pretrained model doesn't have generate method")
                except Exception as e:
                    logger.error(f"Error using pretrained model: {e}")

        # Fall back to using the adapter's own model if needed
        if self.model:
            try:
                logger.info("Using custom model for generation")
                if hasattr(self.model, "generate"):
                    return self._invoke_generate(
                        self.model, text_input, generation_kwargs, announce_tokenizing=True
                    )
                logger.error("Model has no generate method")
                # Simple fallback for models without generate
                return f"I'm processing your request about '{text_input[:30]}...'"
            except Exception as e:
                logger.error(f"Error using custom model: {e}")
                # Last-chance fallback with generic response
                return f"I apologize, but I'm experiencing some technical difficulties processing your request about '{text_input[:30]}...'. (Error: {str(e)})"

        # Final fallback
        return f"I apologize, but I'm unable to process your request about '{text_input[:30]}...' at this time."
    except Exception as e:
        logger.error(f"Error in generate method: {e}")
        return f"An error occurred while generating text: {str(e)}"


def _select_model_type(self, text_input):
    """Return the backend name PromptAnalyzer picks, or "model_Custm" on error.

    The analyzer is created once and kept on the instance: constructing
    PromptAnalyzer loads embedding and GPT-2 models, which is far too
    expensive to repeat on every generate() call. A failed construction is
    not cached, so the next call retries.
    """
    try:
        analyzer = getattr(self, "_prompt_analyzer", None)
        if analyzer is None:
            from model_List import PromptAnalyzer
            analyzer = PromptAnalyzer()
            self._prompt_analyzer = analyzer
        model_type, confidence = analyzer.analyze_prompt(text_input)
        logger.info(f"PromptAnalyzer selected {model_type} with confidence {confidence:.2f}")
        return model_type
    except Exception as e:
        logger.error(f"Error using PromptAnalyzer: {e}")
        return "model_Custm"  # Default to custom model on error


def _invoke_generate(self, model, text_input, generation_kwargs, announce_tokenizing=False):
    """Call ``model.generate`` with the argument style its signature accepts.

    Models whose ``generate`` declares a ``prompt`` parameter get the raw
    text; all others get ``input_ids`` produced by ``self.tokenizer``.
    Exceptions propagate to the caller, which decides how to fall back.
    """
    import inspect

    sig = inspect.signature(model.generate)
    if "prompt" in sig.parameters:
        return model.generate(prompt=text_input, **generation_kwargs)

    if announce_tokenizing:
        logger.info("Model expects tokenized input - converting prompt to input_ids")
    inputs = self.tokenizer(text_input, return_tensors="pt", truncation=True, padding=True)
    return model.generate(input_ids=inputs.input_ids, **generation_kwargs)
|
config.json
CHANGED
|
@@ -241,8 +241,19 @@
|
|
| 241 |
"HIDDEN_DIM": 768,
|
| 242 |
"MAX_CACHE_SIZE": 10
|
| 243 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
"MAX_ACTIVE_MODELS": 5,
|
| 245 |
"MODEL_IDLE_THRESHOLD": 600,
|
| 246 |
"MAX_MEMORY_USAGE": 0.8,
|
| 247 |
-
"TOP_K": 3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
}
|
|
|
|
| 241 |
"HIDDEN_DIM": 768,
|
| 242 |
"MAX_CACHE_SIZE": 10
|
| 243 |
},
|
| 244 |
+
"MODEL_PRIORITY": {
|
| 245 |
+
"PRIMARY": "model_Custm",
|
| 246 |
+
"SECONDARY": "model_PrTr",
|
| 247 |
+
"USE_PRETRAINED_FALLBACK": true
|
| 248 |
+
},
|
| 249 |
"MAX_ACTIVE_MODELS": 5,
|
| 250 |
"MODEL_IDLE_THRESHOLD": 600,
|
| 251 |
"MAX_MEMORY_USAGE": 0.8,
|
| 252 |
+
"TOP_K": 3,
|
| 253 |
+
"TOPIC_KEYWORDS": {
|
| 254 |
+
"programming": ["python", "java", "javascript"],
|
| 255 |
+
"computer_science": ["algorithm", "complexity"],
|
| 256 |
+
"software_engineering": ["design pattern", "architecture"],
|
| 257 |
+
"web_development": ["frontend", "backend"]
|
| 258 |
+
}
|
| 259 |
}
|
config.py
CHANGED
|
@@ -388,6 +388,12 @@ class AppConfig(BaseModel):
|
|
| 388 |
MAX_ACTIVE_MODELS: int = Field(default=2)
|
| 389 |
MODEL_IDLE_THRESHOLD: int = Field(default=600)
|
| 390 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 391 |
def load_config() -> AppConfig:
|
| 392 |
config_path = os.path.join(os.path.dirname(__file__), "config.json")
|
| 393 |
logger.info(f"Loading config from {config_path}")
|
|
@@ -395,14 +401,31 @@ def load_config() -> AppConfig:
|
|
| 395 |
with open(config_path, "r") as f:
|
| 396 |
raw = json.load(f)
|
| 397 |
|
| 398 |
-
#
|
| 399 |
-
class AttrDict(dict):
|
| 400 |
-
__getattr__ = dict.get
|
| 401 |
-
__setattr__ = dict.__setitem__
|
| 402 |
-
|
| 403 |
-
# wrap TRANSFORMER_CONFIG if it's a dict
|
| 404 |
if isinstance(raw.get("TRANSFORMER_CONFIG"), dict):
|
| 405 |
raw["TRANSFORMER_CONFIG"] = AttrDict(raw["TRANSFORMER_CONFIG"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 406 |
except Exception as e:
|
| 407 |
logger.error(f"Failed to read config.json: {e}", exc_info=True)
|
| 408 |
raise
|
|
|
|
| 388 |
MAX_ACTIVE_MODELS: int = Field(default=2)
|
| 389 |
MODEL_IDLE_THRESHOLD: int = Field(default=600)
|
| 390 |
|
| 391 |
+
class AttrDict(dict):
    """Dictionary whose entries can also be read, set and deleted as attributes.

    Attribute reads use ``dict.get`` semantics: a name that is not a key
    yields ``None`` instead of raising ``AttributeError``. Deleting a name
    that is not a key raises ``KeyError``.
    """

    def __getattr__(self, name):
        # Only invoked after normal attribute lookup fails, so real dict
        # methods (keys, items, ...) are never shadowed by entries.
        return dict.get(self, name)

    def __setattr__(self, name, value):
        # Attribute assignment is stored as a dictionary item.
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        # Attribute deletion removes the dictionary item.
        dict.__delitem__(self, name)
|
| 396 |
+
|
| 397 |
def load_config() -> AppConfig:
|
| 398 |
config_path = os.path.join(os.path.dirname(__file__), "config.json")
|
| 399 |
logger.info(f"Loading config from {config_path}")
|
|
|
|
| 401 |
with open(config_path, "r") as f:
|
| 402 |
raw = json.load(f)
|
| 403 |
|
| 404 |
+
# Always wrap TRANSFORMER_CONFIG for attribute access
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 405 |
if isinstance(raw.get("TRANSFORMER_CONFIG"), dict):
|
| 406 |
raw["TRANSFORMER_CONFIG"] = AttrDict(raw["TRANSFORMER_CONFIG"])
|
| 407 |
+
|
| 408 |
+
# Ensure GPT-2 parameters
|
| 409 |
+
if not isinstance(raw["TRANSFORMER_CONFIG"].get("VOCAB_SIZE"), int) or raw["TRANSFORMER_CONFIG"]["VOCAB_SIZE"] != 50257:
|
| 410 |
+
raw["TRANSFORMER_CONFIG"]["VOCAB_SIZE"] = 50257 # Standard GPT-2 vocab size
|
| 411 |
+
|
| 412 |
+
if raw["TRANSFORMER_CONFIG"].get("MODEL_NAME") != "gpt2":
|
| 413 |
+
raw["TRANSFORMER_CONFIG"]["MODEL_NAME"] = "gpt2"
|
| 414 |
+
|
| 415 |
+
# Ensure OUTPUT_SIZE matches VOCAB_SIZE
|
| 416 |
+
raw["TRANSFORMER_CONFIG"]["OUTPUT_SIZE"] = raw["TRANSFORMER_CONFIG"]["VOCAB_SIZE"]
|
| 417 |
+
|
| 418 |
+
# Add generation parameters if missing
|
| 419 |
+
if "GENERATION_CONFIG" not in raw:
|
| 420 |
+
raw["GENERATION_CONFIG"] = {
|
| 421 |
+
"temperature": 0.7,
|
| 422 |
+
"top_p": 0.95,
|
| 423 |
+
"top_k": 50,
|
| 424 |
+
"repetition_penalty": 1.3,
|
| 425 |
+
"no_repeat_ngram_size": 3,
|
| 426 |
+
"do_sample": True,
|
| 427 |
+
"penalty_alpha": 0.6
|
| 428 |
+
}
|
| 429 |
except Exception as e:
|
| 430 |
logger.error(f"Failed to read config.json: {e}", exc_info=True)
|
| 431 |
raise
|
model_Custm.py
CHANGED
|
@@ -81,15 +81,15 @@ class Wildnerve_tlm01(nn.Module, AbstractModel):
|
|
| 81 |
- SmartHybridAttention for better context handling"""
|
| 82 |
def __init__(
|
| 83 |
self,
|
| 84 |
-
vocab_size=
|
| 85 |
specialization="general",
|
| 86 |
dataset_path=None,
|
| 87 |
-
model_name="
|
| 88 |
embedding_dim=768,
|
| 89 |
num_heads=12,
|
| 90 |
hidden_dim=768,
|
| 91 |
num_layers=6,
|
| 92 |
-
output_size=
|
| 93 |
dropout=0.1,
|
| 94 |
max_seq_length=512,
|
| 95 |
pooling_mode="mean",
|
|
@@ -123,18 +123,16 @@ class Wildnerve_tlm01(nn.Module, AbstractModel):
|
|
| 123 |
self.tokenizer = registry.get(TOKENIZER)
|
| 124 |
else:
|
| 125 |
try:
|
| 126 |
-
from transformers import
|
| 127 |
-
self.tokenizer =
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
| 129 |
except Exception as e:
|
| 130 |
-
logger.warning(f"
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
| 134 |
-
logger.info("Loaded fallback tokenizer: bert-base-uncased")
|
| 135 |
-
except Exception as e2:
|
| 136 |
-
logger.error(f"Fallback tokenizer load failed: {e2}")
|
| 137 |
-
self.tokenizer = None
|
| 138 |
registry.register(TOKENIZER, self.tokenizer, overwrite=True)
|
| 139 |
|
| 140 |
# Register this model instance in the registry by specialization
|
|
@@ -363,6 +361,7 @@ class Wildnerve_tlm01(nn.Module, AbstractModel):
|
|
| 363 |
embeddings = self.encode_sentences([sentence1, sentence2])
|
| 364 |
return np.dot(embeddings[0], embeddings[1]) / (np.linalg.norm(embeddings[0]) * np.linalg.norm(embeddings[1]))
|
| 365 |
|
|
|
|
| 366 |
def generate(
|
| 367 |
self,
|
| 368 |
prompt=None,
|
|
@@ -373,6 +372,17 @@ class Wildnerve_tlm01(nn.Module, AbstractModel):
|
|
| 373 |
**kwargs
|
| 374 |
) -> str:
|
| 375 |
"""Generate text using the model, supporting either prompt string or input_ids."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
# Log what we're working with
|
| 377 |
logger.info(f"Generate called with: prompt={type(prompt).__name__ if prompt else None}, input_ids={type(input_ids).__name__ if input_ids else None}")
|
| 378 |
|
|
|
|
| 81 |
- SmartHybridAttention for better context handling"""
|
| 82 |
def __init__(
|
| 83 |
self,
|
| 84 |
+
vocab_size=50257, # Updated to GPT-2 vocab size
|
| 85 |
specialization="general",
|
| 86 |
dataset_path=None,
|
| 87 |
+
model_name="gpt2", # Standardized to GPT-2
|
| 88 |
embedding_dim=768,
|
| 89 |
num_heads=12,
|
| 90 |
hidden_dim=768,
|
| 91 |
num_layers=6,
|
| 92 |
+
output_size=50257, # Updated to GPT-2 vocab size
|
| 93 |
dropout=0.1,
|
| 94 |
max_seq_length=512,
|
| 95 |
pooling_mode="mean",
|
|
|
|
| 123 |
self.tokenizer = registry.get(TOKENIZER)
|
| 124 |
else:
|
| 125 |
try:
|
| 126 |
+
from transformers import GPT2Tokenizer
|
| 127 |
+
self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
| 128 |
+
# Ensure pad_token is set for GPT-2
|
| 129 |
+
if self.tokenizer.pad_token_id is None:
|
| 130 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 131 |
+
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
|
| 132 |
except Exception as e:
|
| 133 |
+
logger.warning(f"Failed to load GPT-2 tokenizer: {e}")
|
| 134 |
+
from utils.transformer_utils import get_tokenizer
|
| 135 |
+
self.tokenizer = get_tokenizer()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
registry.register(TOKENIZER, self.tokenizer, overwrite=True)
|
| 137 |
|
| 138 |
# Register this model instance in the registry by specialization
|
|
|
|
| 361 |
embeddings = self.encode_sentences([sentence1, sentence2])
|
| 362 |
return np.dot(embeddings[0], embeddings[1]) / (np.linalg.norm(embeddings[0]) * np.linalg.norm(embeddings[1]))
|
| 363 |
|
| 364 |
+
# Update generate to use adapter_layer as the primary generation point
|
| 365 |
def generate(
|
| 366 |
self,
|
| 367 |
prompt=None,
|
|
|
|
| 372 |
**kwargs
|
| 373 |
) -> str:
|
| 374 |
"""Generate text using the model, supporting either prompt string or input_ids."""
|
| 375 |
+
# Try to use adapter_layer.generate if available
|
| 376 |
+
adapter_layer = registry.get("adapter_layer")
|
| 377 |
+
if adapter_layer and hasattr(adapter_layer, "generate"):
|
| 378 |
+
if prompt:
|
| 379 |
+
return adapter_layer.generate(prompt, max_length=max_length, temperature=temperature, **kwargs)
|
| 380 |
+
elif input_ids is not None and self.tokenizer:
|
| 381 |
+
# Convert input_ids back to text to use centralized generation
|
| 382 |
+
decoded_prompt = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
|
| 383 |
+
return adapter_layer.generate(decoded_prompt, max_length=max_length, temperature=temperature, **kwargs)
|
| 384 |
+
|
| 385 |
+
# Fall back to direct generation if adapter_layer is not available
|
| 386 |
# Log what we're working with
|
| 387 |
logger.info(f"Generate called with: prompt={type(prompt).__name__ if prompt else None}, input_ids={type(input_ids).__name__ if input_ids else None}")
|
| 388 |
|
model_List.py
CHANGED
|
@@ -32,22 +32,46 @@ class PromptAnalyzer:
|
|
| 32 |
- SmartHybridAttention for analyzing complex or long prompts
|
| 33 |
- Performance tracking and caching for efficiency
|
| 34 |
"""
|
| 35 |
-
def __init__(self):
|
| 36 |
self.logger = logging.getLogger(__name__)
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
# For caching and performance tracking
|
| 39 |
self._model_cache = {}
|
| 40 |
self._performance_metrics = {}
|
| 41 |
|
| 42 |
-
# Define topic keywords for the simple approach
|
| 43 |
-
self.predefined_topics = {
|
| 44 |
-
"programming": ["code", "function", "class", "algorithm", "programming", "python", "javascript", "java", "c++", "developer", "api"],
|
| 45 |
-
"science": ["science", "physics", "chemistry", "biology", "scientific", "experiment", "hypothesis", "theory"],
|
| 46 |
-
"mathematics": ["math", "equation", "calculus", "algebra", "geometry", "theorem", "mathematical"],
|
| 47 |
-
"history": ["history", "historical", "ancient", "century", "war", "civilization", "empire"],
|
| 48 |
-
"general": ["how", "what", "when", "where", "why", "who", "can you", "please", "thanks", "hello"]
|
| 49 |
-
}
|
| 50 |
-
|
| 51 |
# Initialize model_class attribute
|
| 52 |
self.model_class = None
|
| 53 |
|
|
@@ -67,30 +91,88 @@ class PromptAnalyzer:
|
|
| 67 |
except Exception:
|
| 68 |
pass
|
| 69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
def _init_advanced_tools(self):
|
| 71 |
-
"""Initialize advanced analysis tools with proper error handling"""
|
| 72 |
self.sentence_model = None
|
| 73 |
self.gpt2_model = None
|
| 74 |
self.gpt2_tokenizer = None
|
| 75 |
|
| 76 |
-
#
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
self.logger.info("Loaded GPT-2 model for perplexity calculation")
|
| 91 |
-
except Exception as e:
|
| 92 |
-
self.logger.warning(f"Failed to load GPT-2: {e}")
|
| 93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
# Initialize SmartHybridAttention
|
| 95 |
try:
|
| 96 |
attention_config = get_hybrid_attention_config()
|
|
@@ -253,173 +335,123 @@ class PromptAnalyzer:
|
|
| 253 |
self.logger.error(f"Error in attention-based analysis: {e}")
|
| 254 |
return None
|
| 255 |
|
| 256 |
-
def
|
| 257 |
-
"""
|
| 258 |
-
Enhanced prompt analysis with SmartHybridAttention for complex prompts
|
| 259 |
-
"""
|
| 260 |
-
# Start with simple keyword-based classification
|
| 261 |
prompt_lower = prompt.lower()
|
| 262 |
-
|
|
|
|
| 263 |
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
|
|
|
|
|
|
| 267 |
|
| 268 |
-
#
|
| 269 |
-
|
| 270 |
|
| 271 |
-
if
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
|
|
|
| 280 |
|
| 281 |
-
#
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
# We could have reference embeddings for each topic and calculate similarity
|
| 288 |
-
# For now, we'll just use the embedding magnitude as a complexity measure
|
| 289 |
-
complexity = np.linalg.norm(embedding)
|
| 290 |
-
|
| 291 |
-
# Adjust scores based on complexity
|
| 292 |
-
if complexity > 15: # High complexity
|
| 293 |
-
if topic_scores.get("programming", 0) > 0:
|
| 294 |
-
topic_scores["programming"] *= 1.5
|
| 295 |
-
if topic_scores.get("science", 0) > 0:
|
| 296 |
-
topic_scores["science"] *= 1.4
|
| 297 |
-
if topic_scores.get("mathematics", 0) > 0:
|
| 298 |
-
topic_scores["mathematics"] *= 1.3
|
| 299 |
-
|
| 300 |
-
if self.gpt2_model and self.gpt2_tokenizer:
|
| 301 |
-
# Calculate perplexity for another dimension of analysis
|
| 302 |
-
try:
|
| 303 |
-
inputs = self.gpt2_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
|
| 304 |
-
with torch.no_grad():
|
| 305 |
-
outputs = self.gpt2_model(**inputs, labels=inputs["input_ids"])
|
| 306 |
-
|
| 307 |
-
loss = outputs.loss.item()
|
| 308 |
-
perplexity = math.exp(loss)
|
| 309 |
-
|
| 310 |
-
# Adjust scores based on perplexity
|
| 311 |
-
if perplexity > 100: # Very specialized/technical content
|
| 312 |
-
if topic_scores.get("programming", 0) > 0:
|
| 313 |
-
topic_scores["programming"] *= 1.4
|
| 314 |
-
if topic_scores.get("science", 0) > 0:
|
| 315 |
-
topic_scores["science"] *= 1.3
|
| 316 |
-
if topic_scores.get("mathematics", 0) > 0:
|
| 317 |
-
topic_scores["mathematics"] *= 1.2
|
| 318 |
-
except Exception as e:
|
| 319 |
-
logger.warning(f"Error in perplexity calculation: {e}")
|
| 320 |
-
except Exception as e:
|
| 321 |
-
logger.warning(f"Advanced analysis failed: {e}")
|
| 322 |
|
| 323 |
-
#
|
| 324 |
-
|
| 325 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 326 |
|
| 327 |
-
|
|
|
|
|
|
|
|
|
|
| 328 |
|
| 329 |
-
#
|
| 330 |
-
|
| 331 |
-
if s > 0 and t != primary_topic]
|
| 332 |
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
"""
|
| 337 |
-
Analyze prompt complexity with attention-enhanced analysis
|
| 338 |
-
"""
|
| 339 |
-
# First check if we can use attention-based analysis for complex prompts
|
| 340 |
-
if self.attention and len(prompt) > 150:
|
| 341 |
try:
|
| 342 |
-
|
| 343 |
-
sentences = nltk.sent_tokenize(prompt)
|
| 344 |
-
if len(sentences) > 1:
|
| 345 |
-
# Apply attention to understand cross-sentence relationships
|
| 346 |
-
sentence_embeddings = [self.sentence_model.encode(s) for s in sentences]
|
| 347 |
-
embeddings_tensor = torch.tensor(sentence_embeddings).unsqueeze(1)
|
| 348 |
-
|
| 349 |
-
# Use attention to focus on important parts of the prompt
|
| 350 |
-
attended_embeddings, _ = self.attention(
|
| 351 |
-
query=embeddings_tensor,
|
| 352 |
-
key=embeddings_tensor,
|
| 353 |
-
value=embeddings_tensor
|
| 354 |
-
)
|
| 355 |
-
|
| 356 |
-
# Calculate complexity based on attention-weighted embeddings
|
| 357 |
-
complexity = torch.norm(attended_embeddings.mean(dim=0)).item()
|
| 358 |
-
logger.info(f"Computed attention-weighted complexity: {complexity}")
|
| 359 |
-
|
| 360 |
-
# Return candidate index based on complexity
|
| 361 |
-
if complexity < 12:
|
| 362 |
-
return 0 # Simpler model
|
| 363 |
-
elif complexity < 24:
|
| 364 |
-
return 1 # Moderate model
|
| 365 |
-
else:
|
| 366 |
-
return 2 # Complex model
|
| 367 |
except Exception as e:
|
| 368 |
-
logger.warning(f"
|
| 369 |
-
|
| 370 |
-
# Use embeddings if available for complexity analysis
|
| 371 |
-
if self.sentence_model:
|
| 372 |
-
try:
|
| 373 |
-
# Get embedding and calculate complexity based on vector properties
|
| 374 |
-
embedding = self.sentence_model.encode(prompt)
|
| 375 |
-
# Calculate complexity (vector magnitude)
|
| 376 |
-
complexity = np.linalg.norm(embedding)
|
| 377 |
-
logger.info(f"Computed embedding complexity: {complexity}")
|
| 378 |
-
|
| 379 |
-
# Return appropriate index based on complexity
|
| 380 |
-
if complexity < 10:
|
| 381 |
-
return 0 # Less complex, use simpler model
|
| 382 |
-
elif complexity < 20:
|
| 383 |
-
return 1 # Moderate complexity
|
| 384 |
-
else:
|
| 385 |
-
return 2 # High complexity, use specialized model
|
| 386 |
-
except Exception as e:
|
| 387 |
-
logger.warning(f"Embedding-based analysis failed: {e}")
|
| 388 |
|
| 389 |
-
# Use
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 398 |
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
# Map topics to model indices
|
| 413 |
-
topic_to_index = {
|
| 414 |
-
"general": 0,
|
| 415 |
-
"history": 0,
|
| 416 |
-
"programming": 1,
|
| 417 |
-
"science": 1,
|
| 418 |
-
"mathematics": 2
|
| 419 |
-
}
|
| 420 |
|
| 421 |
-
#
|
| 422 |
-
|
|
|
|
|
|
|
|
|
|
| 423 |
|
| 424 |
def choose_model(self, prompt: str = None) -> Type:
|
| 425 |
"""Enhanced model selection that combines config and analysis"""
|
|
|
|
| 32 |
- SmartHybridAttention for analyzing complex or long prompts
|
| 33 |
- Performance tracking and caching for efficiency
|
| 34 |
"""
|
| 35 |
+
def __init__(self, model_name=None, dataset_path=None, specialization=None, hidden_dim=None):
|
| 36 |
self.logger = logging.getLogger(__name__)
|
| 37 |
|
| 38 |
+
# Load config
|
| 39 |
+
self.config = load_config(config_file="config.json")
|
| 40 |
+
|
| 41 |
+
# Use provided values or config values
|
| 42 |
+
self.model_name = model_name or self.config.PROMPT_ANALYZER_CONFIG.MODEL_NAME
|
| 43 |
+
self.dataset_path = dataset_path or self.config.PROMPT_ANALYZER_CONFIG.DATASET_PATH
|
| 44 |
+
self.specialization = specialization or self.config.PROMPT_ANALYZER_CONFIG.SPECIALIZATION
|
| 45 |
+
self.hidden_dim = hidden_dim or self.config.PROMPT_ANALYZER_CONFIG.HIDDEN_DIM
|
| 46 |
+
|
| 47 |
+
self.logger.info(f"Initialized PromptAnalyzer with {self.model_name}")
|
| 48 |
+
self._model_cache: Dict[str, Type] = {}
|
| 49 |
+
self._performance_metrics: Dict[str, Dict[str, float]] = {}
|
| 50 |
+
|
| 51 |
+
# Load predefined topics from config or fall back to defaults
|
| 52 |
+
self._load_predefined_topics()
|
| 53 |
+
|
| 54 |
+
# Always use a proper SentenceTransformer model - fix this to avoid warnings
|
| 55 |
+
if hasattr(self, 'sentence_model'):
|
| 56 |
+
del self.sentence_model # Remove any existing instance
|
| 57 |
+
|
| 58 |
+
# Use a proper SentenceTransformer model
|
| 59 |
+
self.sentence_model = get_sentence_transformer('sentence-transformers/all-MiniLM-L6-v2')
|
| 60 |
+
self.logger.info(f"Using SentenceTransformer model: sentence-transformers/all-MiniLM-L6-v2")
|
| 61 |
+
|
| 62 |
+
# Use GPT-2 for perplexity calculation
|
| 63 |
+
self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
| 64 |
+
self.model = AutoModelForCausalLM.from_pretrained("gpt2")
|
| 65 |
+
self.model.eval()
|
| 66 |
+
|
| 67 |
+
logger.info(f"Initialized PromptAnalyzer with {self.model_name}, specialization: {self.specialization}, hidden_dim: {self.hidden_dim}")
|
| 68 |
+
if self.dataset_path:
|
| 69 |
+
logger.info(f"Using dataset from: {self.dataset_path}")
|
| 70 |
+
|
| 71 |
# For caching and performance tracking
|
| 72 |
self._model_cache = {}
|
| 73 |
self._performance_metrics = {}
|
| 74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
# Initialize model_class attribute
|
| 76 |
self.model_class = None
|
| 77 |
|
|
|
|
| 91 |
except Exception:
|
| 92 |
pass
|
| 93 |
|
| 94 |
+
def _load_predefined_topics(self):
    """Fill ``self.predefined_topics`` from the first source that provides it.

    Sources are tried in order: ``app_config.TOPIC_KEYWORDS`` (when set and
    non-empty), then ``topic_keywords.json`` under ``app_config.DATA_DIR``,
    then the built-in defaults below. When the defaults are used they are
    also written to that JSON file on a best-effort basis.
    """
    try:
        configured_topics = getattr(app_config, 'TOPIC_KEYWORDS', None)
        if configured_topics:
            logger.info("Loading topic keywords from config")
            self.predefined_topics = configured_topics
            return

        # No usable config entry: look for a JSON file in the data directory.
        topic_file = os.path.join(app_config.DATA_DIR, "topic_keywords.json")
        if os.path.exists(topic_file):
            with open(topic_file, 'r') as f:
                loaded_topics = json.load(f)
            self.predefined_topics = loaded_topics
            logger.info(f"Loaded {len(self.predefined_topics)} topic categories from {topic_file}")
            return
    except Exception as e:
        # Any failure above (missing attribute, unreadable or invalid file)
        # drops through to the hardcoded defaults.
        logger.warning(f"Error loading topic keywords: {e}, using defaults")

    logger.info("Using default hardcoded topic keywords")
    # NOTE(review): only "programming" carries keywords in this revision;
    # the remaining categories appear to have had their lists elided.
    default_topics = {
        "programming": [
            "python", "java", "javascript", "typescript", "rust", "go", "golang",
        ],
        "computer_science": [],
        "software_engineering": [],
        "web_development": [],
    }
    self.predefined_topics = default_topics

    # Persist the defaults for later runs; failure here is non-fatal.
    try:
        os.makedirs(app_config.DATA_DIR, exist_ok=True)
        cache_path = os.path.join(app_config.DATA_DIR, "topic_keywords.json")
        with open(cache_path, 'w') as f:
            json.dump(self.predefined_topics, f, indent=2)
    except Exception as e:
        logger.debug(f"Could not cache topic keywords: {e}")
|
| 138 |
+
|
| 139 |
def _init_advanced_tools(self):
|
| 140 |
+
"""Initialize advanced analysis tools with proper error handling and fallbacks"""
|
| 141 |
self.sentence_model = None
|
| 142 |
self.gpt2_model = None
|
| 143 |
self.gpt2_tokenizer = None
|
| 144 |
|
| 145 |
+
# For embedding model, implement multiple fallbacks
|
| 146 |
+
MAX_RETRIES = 3
|
| 147 |
+
embedding_models = [
|
| 148 |
+
'sentence-transformers/all-MiniLM-L6-v2', # Primary choice
|
| 149 |
+
'sentence-transformers/paraphrase-MiniLM-L3-v2', # Smaller fallback
|
| 150 |
+
'sentence-transformers/distilbert-base-nli-mean-tokens' # Last resort
|
| 151 |
+
]
|
| 152 |
+
|
| 153 |
+
for retry in range(MAX_RETRIES):
|
| 154 |
+
for model_name in embedding_models:
|
| 155 |
+
try:
|
| 156 |
+
from utils.transformer_utils import get_sentence_transformer
|
| 157 |
+
self.sentence_model = get_sentence_transformer(model_name)
|
| 158 |
+
self.logger.info(f"Successfully loaded SentenceTransformer: {model_name}")
|
| 159 |
+
break
|
| 160 |
+
except Exception as e:
|
| 161 |
+
self.logger.warning(f"Failed to load embedding model {model_name}: {e}")
|
| 162 |
|
| 163 |
+
if self.sentence_model:
|
| 164 |
+
break
|
| 165 |
+
|
| 166 |
+
# Wait before retry
|
| 167 |
+
time.sleep(2)
|
|
|
|
|
|
|
|
|
|
| 168 |
|
| 169 |
+
# Create keyword-based fallback if embedding loading completely fails
|
| 170 |
+
if not self.sentence_model:
|
| 171 |
+
self.logger.warning("All embedding models failed to load - using keyword fallback")
|
| 172 |
+
self._use_keyword_fallback = True
|
| 173 |
+
else:
|
| 174 |
+
self._use_keyword_fallback = False
|
| 175 |
+
|
| 176 |
# Initialize SmartHybridAttention
|
| 177 |
try:
|
| 178 |
attention_config = get_hybrid_attention_config()
|
|
|
|
| 335 |
self.logger.error(f"Error in attention-based analysis: {e}")
|
| 336 |
return None
|
| 337 |
|
| 338 |
+
def _analyze_with_keywords(self, prompt: str) -> Tuple[str, float]:
    """Classify *prompt* using keyword matching only.

    This is the path taken when no embedding model could be loaded.

    Single-word keywords are matched against whole words of the prompt, so
    that e.g. "rust" does not fire on "trust" and "go" does not fire on
    "good".  Multi-word keywords (those containing a space) are matched as
    phrases anywhere in the lower-cased prompt.

    Args:
        prompt: The raw user prompt.

    Returns:
        ``("model_Custm", ratio)`` when the share of keyword hits exceeds
        0.1, otherwise ``("model_PrTr", 0.7)``.  ``ratio`` is capped at 1.0
        so it can be read as a confidence value.
    """
    prompt_lower = prompt.lower()
    raw_tokens = prompt_lower.split()
    total_words = len(raw_tokens)

    # Index every token both verbatim and with surrounding punctuation
    # removed: "python?" must match "python", while keywords that carry
    # their own punctuation (e.g. ".net") still match the verbatim form.
    edge_punct = ".,;:!?\"'()[]{}<>"
    prompt_words = set(raw_tokens)
    prompt_words.update(token.strip(edge_punct) for token in raw_tokens)
    prompt_words.discard("")

    # Count matches across all technical categories.
    technical_matches = 0
    for keywords in self.predefined_topics.values():
        for keyword in keywords:
            if " " in keyword:
                # Phrase keyword: can never equal a single token.
                if keyword in prompt_lower:
                    technical_matches += 1
            elif keyword in prompt_words:
                technical_matches += 1

    # Normalise by prompt length, capped at 15 words so that long prompts
    # are not penalised for containing many non-technical words.
    match_ratio = technical_matches / max(1, min(15, total_words))

    if match_ratio > 0.1:  # Even a single match in a short query is significant
        return "model_Custm", min(1.0, match_ratio)
    return "model_PrTr", 0.7
|
| 357 |
+
|
| 358 |
+
def analyze_prompt(self, prompt: str) -> Tuple[str, float]:
    """Decide whether *prompt* is technical or general.

    Three signals are combined into a technical score:

    1. keyword matching against ``self.predefined_topics`` (weight 0.3),
    2. embedding similarity to a technical reference sentence (weight 0.4),
    3. attention-based category scores, only for prompts longer than 100
       characters and only when ``self.attention`` is set (weight 0.3).

    Args:
        prompt: The raw user prompt.

    Returns:
        ``(model_type, confidence)`` where ``model_type`` is
        ``"model_Custm"`` for technical prompts and ``"model_PrTr"``
        otherwise.  ``confidence`` is clamped to ``[0.0, 1.0]``.  If the
        embedding step raises, the result falls back to keyword matching
        with a fixed confidence of 0.7.
    """
    # Embeddings are known to be unavailable: use the keyword-only analyser.
    if getattr(self, '_use_keyword_fallback', False):
        return self._analyze_with_keywords(prompt)

    # Lower-case once for case-insensitive matching.
    prompt_lower = prompt.lower()
    raw_tokens = prompt_lower.split()
    word_count = len(raw_tokens)

    # Index every token both verbatim and with surrounding punctuation
    # removed; otherwise "python?" or "(rust)" would never match a keyword.
    edge_punct = ".,;:!?\"'()[]{}<>"
    prompt_words = set(raw_tokens)
    prompt_words.update(token.strip(edge_punct) for token in raw_tokens)
    prompt_words.discard("")

    technical_matches = 0
    for keywords in self.predefined_topics.values():
        # Whole-word hits via set intersection.
        technical_matches += len(prompt_words.intersection(keywords))
        # Multi-word keywords can never equal a single token, so look for
        # them as phrases in the full prompt.
        technical_matches += sum(
            1 for keyword in keywords
            if " " in keyword and keyword in prompt_lower
        )

    # Keyword match ratio, normalised by word count (capped at 20 words).
    keyword_ratio = technical_matches / max(1, min(20, word_count))

    # Attention-based analysis is only attempted for longer prompts.
    attention_scores = None
    if len(prompt) > 100 and self.attention:
        try:
            attention_scores = self._analyze_with_attention(prompt)
        except Exception as e:
            self.logger.warning(f"Error in attention analysis: {e}")

    try:
        # Semantic signal: similarity of the prompt to a technical reference.
        # Only the technical reference feeds the score, so no "general"
        # reference embedding is computed.
        prompt_embedding = self.sentence_model.encode(prompt)
        technical_reference = "Write code to solve a programming problem using algorithms and data structures."
        technical_embedding = self.sentence_model.encode(technical_reference)
        technical_similarity = cosine_similarity([prompt_embedding], [technical_embedding])[0][0]

        technical_score = 0.3 * keyword_ratio + 0.4 * technical_similarity

        if attention_scores:
            # attention_scores is read as a mapping of category name to
            # score; average the four technical categories.
            tech_attention_score = (
                attention_scores.get("programming", 0) +
                attention_scores.get("computer_science", 0) +
                attention_scores.get("software_engineering", 0) +
                attention_scores.get("web_development", 0)
            ) / 4.0
            technical_score += 0.3 * tech_attention_score

        # The routing decision uses the raw score; only the reported
        # confidence is clamped, because the keyword ratio may exceed 1 and
        # cosine similarity may be negative.
        if technical_score > 0.3:  # Threshold - tune this as needed
            confidence = technical_score
            model_type = "model_Custm"
        else:
            confidence = 1.0 - technical_score
            model_type = "model_PrTr"
        return model_type, float(min(1.0, max(0.0, confidence)))

    except Exception as e:
        self.logger.error(f"Error in prompt analysis: {e}")

    # Embedding path failed: fall back to simple keyword matching.
    if technical_matches > 0:
        return "model_Custm", 0.7
    return "model_PrTr", 0.7
|
| 445 |
+
|
| 446 |
+
def analyze(self, prompt: str) -> int:
    """Legacy compatibility method that returns a candidate index.

    Args:
        prompt: The raw user prompt, forwarded to ``analyze_prompt``.

    Returns:
        ``0`` when ``analyze_prompt`` selects ``"model_Custm"``, ``1`` for
        anything else (i.e. ``"model_PrTr"``).
    """
    # The confidence value is not part of the legacy contract; discard it.
    model_type, _ = self.analyze_prompt(prompt)

    # Index 0 corresponds to model_Custm, index 1 to model_PrTr.
    return 0 if model_type == "model_Custm" else 1
|
| 455 |
|
| 456 |
def choose_model(self, prompt: str = None) -> Type:
|
| 457 |
"""Enhanced model selection that combines config and analysis"""
|
model_PrTr.py
CHANGED
|
@@ -58,15 +58,15 @@ class Wildnerve_tlm01(nn.Module, AbstractModel):
|
|
| 58 |
The model uses the GPT-2 tokenizer for consistent tokenization."""
|
| 59 |
def __init__(
|
| 60 |
self,
|
| 61 |
-
vocab_size: int = 50257, #
|
| 62 |
specialization: str = "general",
|
| 63 |
dataset_path: str = None,
|
| 64 |
-
model_name: str = "gpt2", #
|
| 65 |
embedding_dim: int = 768,
|
| 66 |
num_heads: int = 12,
|
| 67 |
hidden_dim: int = 768,
|
| 68 |
num_layers: int = 6,
|
| 69 |
-
output_size: int = 50257, #
|
| 70 |
dropout: float = 0.1,
|
| 71 |
max_seq_length: int = 1024, # GPT-2 supports longer contexts
|
| 72 |
pooling_mode: str = "last", # GPT-2 typically uses last token
|
|
@@ -99,15 +99,18 @@ class Wildnerve_tlm01(nn.Module, AbstractModel):
|
|
| 99 |
# Initialize the model and tokenizer
|
| 100 |
self.gpt2_model = GPT2LMHeadModel.from_pretrained(model_name)
|
| 101 |
|
| 102 |
-
#
|
| 103 |
if tokenizer is not None:
|
| 104 |
self.tokenizer = tokenizer
|
| 105 |
elif registry.has(TOKENIZER):
|
| 106 |
self.tokenizer = registry.get(TOKENIZER)
|
| 107 |
else:
|
| 108 |
self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
|
| 109 |
-
|
| 110 |
-
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
logger.info(f"Successfully loaded GPT-2 model: {model_name}")
|
| 113 |
|
|
@@ -135,36 +138,54 @@ class Wildnerve_tlm01(nn.Module, AbstractModel):
|
|
| 135 |
return outputs.logits
|
| 136 |
|
| 137 |
# Update generate to handle both direct prompt and tokenized input
|
| 138 |
-
def generate(self, prompt=None, input_ids=None, **kwargs):
|
| 139 |
"""Generate text using the GPT-2 model"""
|
| 140 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
# Handle either string prompt or direct input_ids
|
| 142 |
if isinstance(prompt, str) and input_ids is None:
|
| 143 |
inputs = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
|
| 144 |
input_ids = inputs.input_ids
|
| 145 |
elif input_ids is None:
|
| 146 |
raise ValueError("Either prompt or input_ids must be provided")
|
| 147 |
-
|
| 148 |
-
#
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
"temperature": kwargs.get("temperature", 0.7),
|
| 153 |
-
"top_p": kwargs.get("top_p", 0.9),
|
| 154 |
-
"top_k": kwargs.get("top_k", 50),
|
| 155 |
-
"repetition_penalty": kwargs.get("repetition_penalty", 1.0),
|
| 156 |
-
"do_sample": kwargs.get("do_sample", True),
|
| 157 |
-
"num_return_sequences": kwargs.get("num_return_sequences", 1),
|
| 158 |
-
"pad_token_id": self.tokenizer.pad_token_id
|
| 159 |
-
}
|
| 160 |
|
| 161 |
# Use max_new_tokens instead of max_length if input is longer than max_length-50
|
| 162 |
-
if input_ids.shape[1] > (
|
| 163 |
logger.info(f"Input length {input_ids.shape[1]} is close to max_length, using max_new_tokens instead")
|
| 164 |
-
del
|
| 165 |
|
| 166 |
# Generate output using the full GPT-2 model
|
| 167 |
-
output_ids = self.gpt2_model.generate(input_ids, **
|
| 168 |
|
| 169 |
# Decode the output and ensure it's a string, not a tensor
|
| 170 |
if torch.is_tensor(output_ids):
|
|
@@ -358,7 +379,7 @@ class Wildnerve_tlm01:
|
|
| 358 |
"""
|
| 359 |
def __init__(
|
| 360 |
self,
|
| 361 |
-
model_name="
|
| 362 |
tokenizer=None,
|
| 363 |
device=None,
|
| 364 |
**kwargs
|
|
@@ -408,4 +429,62 @@ try:
|
|
| 408 |
registry.register(PRETRAINED_MODEL, model, overwrite=True)
|
| 409 |
logger.info("Registered pretrained model in service registry")
|
| 410 |
except Exception as e:
|
| 411 |
-
logger.error(f"Failed to register pretrained model: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
The model uses the GPT-2 tokenizer for consistent tokenization."""
|
| 59 |
def __init__(
|
| 60 |
self,
|
| 61 |
+
vocab_size: int = 50257, # Standardized GPT-2 vocab size
|
| 62 |
specialization: str = "general",
|
| 63 |
dataset_path: str = None,
|
| 64 |
+
model_name: str = "gpt2", # Standardized to GPT-2
|
| 65 |
embedding_dim: int = 768,
|
| 66 |
num_heads: int = 12,
|
| 67 |
hidden_dim: int = 768,
|
| 68 |
num_layers: int = 6,
|
| 69 |
+
output_size: int = 50257, # Standardized GPT-2 vocab size
|
| 70 |
dropout: float = 0.1,
|
| 71 |
max_seq_length: int = 1024, # GPT-2 supports longer contexts
|
| 72 |
pooling_mode: str = "last", # GPT-2 typically uses last token
|
|
|
|
| 99 |
# Initialize the model and tokenizer
|
| 100 |
self.gpt2_model = GPT2LMHeadModel.from_pretrained(model_name)
|
| 101 |
|
| 102 |
+
# Ensure proper tokenizer setup for GPT-2
|
| 103 |
if tokenizer is not None:
|
| 104 |
self.tokenizer = tokenizer
|
| 105 |
elif registry.has(TOKENIZER):
|
| 106 |
self.tokenizer = registry.get(TOKENIZER)
|
| 107 |
else:
|
| 108 |
self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
|
| 109 |
+
|
| 110 |
+
# Ensure GPT-2 tokenizer has pad_token set (critical fix)
|
| 111 |
+
if self.tokenizer.pad_token_id is None:
|
| 112 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 113 |
+
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
|
| 114 |
|
| 115 |
logger.info(f"Successfully loaded GPT-2 model: {model_name}")
|
| 116 |
|
|
|
|
| 138 |
return outputs.logits
|
| 139 |
|
| 140 |
# Update generate to handle both direct prompt and tokenized input
|
| 141 |
+
def generate(self, prompt=None, input_ids=None, max_length=None, **kwargs):
|
| 142 |
"""Generate text using the GPT-2 model"""
|
| 143 |
try:
|
| 144 |
+
# Try to use adapter_layer.generate if available (consolidate generation paths)
|
| 145 |
+
adapter_layer = registry.get("adapter_layer")
|
| 146 |
+
if adapter_layer and hasattr(adapter_layer, "generate"):
|
| 147 |
+
if prompt:
|
| 148 |
+
return adapter_layer.generate(prompt, max_length=max_length, **kwargs)
|
| 149 |
+
elif input_ids is not None and self.tokenizer:
|
| 150 |
+
# Convert input_ids back to text
|
| 151 |
+
prompt = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
|
| 152 |
+
return adapter_layer.generate(prompt, max_length=max_length, **kwargs)
|
| 153 |
+
|
| 154 |
+
# Continue with direct generation if adapter_layer not available
|
| 155 |
+
# Enhanced generation parameters
|
| 156 |
+
generation_config = {
|
| 157 |
+
"max_length": max_length or 150,
|
| 158 |
+
"temperature": kwargs.get('temperature', 0.7),
|
| 159 |
+
"top_p": kwargs.get('top_p', 0.95),
|
| 160 |
+
"top_k": kwargs.get('top_k', 50),
|
| 161 |
+
"repetition_penalty": kwargs.get('repetition_penalty', 1.3),
|
| 162 |
+
"no_repeat_ngram_size": kwargs.get('no_repeat_ngram_size', 3),
|
| 163 |
+
"do_sample": True,
|
| 164 |
+
"pad_token_id": self.tokenizer.pad_token_id,
|
| 165 |
+
"eos_token_id": self.tokenizer.eos_token_id,
|
| 166 |
+
"early_stopping": True,
|
| 167 |
+
"penalty_alpha": 0.6 # Add penalty alpha for better response quality
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
# Handle either string prompt or direct input_ids
|
| 171 |
if isinstance(prompt, str) and input_ids is None:
|
| 172 |
inputs = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
|
| 173 |
input_ids = inputs.input_ids
|
| 174 |
elif input_ids is None:
|
| 175 |
raise ValueError("Either prompt or input_ids must be provided")
|
| 176 |
+
|
| 177 |
+
# Add user-provided kwargs that we didn't explicitly set
|
| 178 |
+
for k, v in kwargs.items():
|
| 179 |
+
if k not in generation_config and k not in ('prompt', 'context'):
|
| 180 |
+
generation_config[k] = v
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
|
| 182 |
# Use max_new_tokens instead of max_length if input is longer than max_length-50
|
| 183 |
+
if input_ids.shape[1] > (generation_config["max_length"] - 50):
|
| 184 |
logger.info(f"Input length {input_ids.shape[1]} is close to max_length, using max_new_tokens instead")
|
| 185 |
+
del generation_config["max_length"]
|
| 186 |
|
| 187 |
# Generate output using the full GPT-2 model
|
| 188 |
+
output_ids = self.gpt2_model.generate(input_ids, **generation_config)
|
| 189 |
|
| 190 |
# Decode the output and ensure it's a string, not a tensor
|
| 191 |
if torch.is_tensor(output_ids):
|
|
|
|
| 379 |
"""
|
| 380 |
def __init__(
|
| 381 |
self,
|
| 382 |
+
model_name="gpt2",
|
| 383 |
tokenizer=None,
|
| 384 |
device=None,
|
| 385 |
**kwargs
|
|
|
|
| 429 |
registry.register(PRETRAINED_MODEL, model, overwrite=True)
|
| 430 |
logger.info("Registered pretrained model in service registry")
|
| 431 |
except Exception as e:
|
| 432 |
+
logger.error(f"Failed to register pretrained model: {e}")
|
| 433 |
+
|
| 434 |
+
def initialize_system():
    """Initialize tokenizer, pretrained model and custom model, in that order.

    Each component is stored in the service registry with ``overwrite=True``.
    A failure while building the GPT-2 pretrained model is logged and does
    not stop initialization.

    Returns:
        ``True`` if the custom model was built and registered, ``False`` if
        that step raised.
    """
    logger.info("Starting system initialization")

    # Step 1: tokenizer. Prefer the GPT-2 tokenizer, fall back to the wrapper.
    try:
        from transformers import GPT2Tokenizer
        tok = GPT2Tokenizer.from_pretrained("gpt2")
        # GPT-2 ships without a pad token; reuse EOS for padding.
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token
    except Exception as tokenizer_error:
        logger.warning(f"Could not load GPT-2 tokenizer, falling back to wrapper: {tokenizer_error}")
        from tokenizer import TokenizerWrapper
        tok = TokenizerWrapper(model_name="gpt2")

    # Step 2: publish the tokenizer so later components can look it up.
    from service_registry import registry, TOKENIZER, PRETRAINED_MODEL
    registry.register(TOKENIZER, tok, overwrite=True)
    logger.info("Tokenizer registered")

    # Step 3: pretrained GPT-2 model (best effort).
    try:
        from model_PrTr import Wildnerve_tlm01 as PretrainedModel
        gpt2_instance = PretrainedModel(model_name="gpt2", tokenizer=tok)
        registry.register(PRETRAINED_MODEL, gpt2_instance, overwrite=True)
        logger.info("GPT-2 pretrained model registered")
    except Exception as pretrained_error:
        logger.error(f"Failed to initialize GPT-2 model: {pretrained_error}", exc_info=True)

    # Step 4: custom model; its outcome determines the return value.
    custom_settings = {
        "vocab_size": 50257,      # Match GPT-2 vocab size
        "specialization": "general",
        "dataset_path": None,
        "model_name": "gpt2",     # Use GPT-2 compatibility
        "embedding_dim": 768,
        "num_heads": 12,
        "hidden_dim": 768,
        "num_layers": 2,
        "output_size": 50257,     # Match GPT-2 vocab
        "dropout": 0.1,
        "max_seq_length": 128,
        "pooling_mode": "mean",
        "tokenizer": tok,
    }
    try:
        from model_Custm import Wildnerve_tlm01 as CustomModel
        custom_model = CustomModel(**custom_settings)

        from service_registry import MODEL
        registry.register(MODEL, custom_model, overwrite=True)
    except Exception as custom_error:
        logger.error(f"Failed to initialize custom model: {custom_error}", exc_info=True)
        return False
    else:
        logger.info("Custom model registered successfully")
        return True
|
service_registry.py
CHANGED
|
@@ -137,18 +137,18 @@ def ensure_models_registered():
|
|
| 137 |
tok = registry.get(TOKENIZER)
|
| 138 |
if not tok:
|
| 139 |
from tokenizer import TokenizerWrapper
|
| 140 |
-
tok = TokenizerWrapper(model_name="gpt2")
|
| 141 |
registry.register(TOKENIZER, tok, overwrite=True)
|
| 142 |
|
| 143 |
# Create pretrained model
|
| 144 |
model = model_class(
|
| 145 |
-
model_name="gpt2",
|
| 146 |
tokenizer=tok
|
| 147 |
)
|
| 148 |
|
| 149 |
# Register as pretrained model
|
| 150 |
registry.register(PRETRAINED_MODEL, model, overwrite=True)
|
| 151 |
-
logger.info("Successfully registered pretrained model")
|
| 152 |
return True
|
| 153 |
|
| 154 |
logger.error(f"model_PrTr.py not found at {model_path}")
|
|
|
|
| 137 |
tok = registry.get(TOKENIZER)
|
| 138 |
if not tok:
|
| 139 |
from tokenizer import TokenizerWrapper
|
| 140 |
+
tok = TokenizerWrapper(model_name="gpt2") # Changed from bert-base-uncased
|
| 141 |
registry.register(TOKENIZER, tok, overwrite=True)
|
| 142 |
|
| 143 |
# Create pretrained model
|
| 144 |
model = model_class(
|
| 145 |
+
model_name="gpt2", # Explicitly use gpt2
|
| 146 |
tokenizer=tok
|
| 147 |
)
|
| 148 |
|
| 149 |
# Register as pretrained model
|
| 150 |
registry.register(PRETRAINED_MODEL, model, overwrite=True)
|
| 151 |
+
logger.info("Successfully registered GPT-2 pretrained model")
|
| 152 |
return True
|
| 153 |
|
| 154 |
logger.error(f"model_PrTr.py not found at {model_path}")
|