WildnerveAI committed on
Commit
b671111
·
verified ·
1 Parent(s): 1a8d9bc

Upload 3 files

Browse files
Files changed (3) hide show
  1. model_Custm.py +1 -0
  2. model_List.py +49 -89
  3. transformer_patches.py +44 -0
model_Custm.py CHANGED
@@ -158,6 +158,7 @@ class Wildnerve_tlm01(nn.Module, AbstractModel):
158
  super().__init__()
159
  # Set device once at the start
160
  object.__setattr__(self, "device", torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
 
161
  self.specialization = specialization
162
  self.dataset_path = dataset_path
163
  self.model_name = model_name
 
158
  super().__init__()
159
  # Set device once at the start
160
  object.__setattr__(self, "device", torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
161
+ logger.info(f"Model initialized on device: {torch.device('cuda' if torch.cuda.is_available() else 'cpu')}")
162
  self.specialization = specialization
163
  self.dataset_path = dataset_path
164
  self.model_name = model_name
model_List.py CHANGED
@@ -32,104 +32,64 @@ class PromptAnalyzer:
32
  - Provides candidate model identifiers or a single best match.
33
  """
34
  def __init__(self):
35
- # Predefined topics with keyword sets for topic understanding
36
- self.predefined_topics: Dict[str, List[str]] = {
37
- "general": ["general", "overview", "basic", "introduction"],
38
- "programming": ["code", "programming", "debug", "software", "algorithm", "bug"],
39
- "science": ["research", "experiment", "science", "physics", "biology", "chemistry"],
40
- "history": ["history", "ancient", "modern", "civilization", "war"],
41
- "mathematics": ["math", "algebra", "calculus", "geometry", "statistics"]
 
42
  }
43
- # Initialize a lightweight transformer encoder for embeddings
44
- self.tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
45
- self.encoder = AutoModel.from_pretrained("distilbert-base-uncased")
46
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
47
- self.encoder.to(self.device)
48
- # Initialize SmartHybridAttention for refined representations
49
- attention_config = get_hybrid_attention_config()
50
- self.attention = SmartHybridAttention(attention_config)
51
- self.attention.to(self.device)
52
- logger.info("PromptAnalyzer initialized with DistilBERT and SmartHybridAttention.")
53
-
54
- def _encode_text(self, text: str) -> np.ndarray:
55
- """
56
- Encode text into an embedding vector.
57
- First, obtain token embeddings using DistilBERT.
58
- Then refine these embeddings with SmartHybridAttention.
59
- Finally, average-pool to produce a single vector.
60
- """
61
- inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
62
- inputs = {k: v.to(self.device) for k, v in inputs.items()}
63
- with torch.no_grad():
64
- outputs = self.encoder(**inputs) # shape: [batch, seq_len, dim]
65
- token_embeds = outputs.last_hidden_state # [1, seq_len, dim]
66
- # Transpose for attention: [seq_len, batch, dim]
67
- token_embeds = token_embeds.transpose(0, 1)
68
- attended, _ = self.attention(query=token_embeds, key=token_embeds, value=token_embeds)
69
- # Transpose back and pool over tokens: [batch, seq_len, dim] -> [batch, dim]
70
- attended = attended.transpose(0, 1)
71
- pooled = attended.mean(dim=1)
72
- return pooled.squeeze().cpu().numpy()
73
-
74
- def analyze_prompt(self, prompt: str) -> Tuple[str, List[str]]:
75
- """
76
- Analyze the given prompt:
77
- - Compute its refined embedding.
78
- - For each predefined topic, encode its keyword string.
79
- - Compute cosine similarity between prompt and topic embeddings.
80
- - Return the primary topic (highest similarity) and any subtopics
81
- with similarity above 80% of the top score.
82
- """
83
- prompt_embedding = self._encode_text(prompt)
84
  topic_scores = {}
 
85
  for topic, keywords in self.predefined_topics.items():
86
- topic_text = " ".join(keywords)
87
- topic_embedding = self._encode_text(topic_text)
88
- similarity = cosine_similarity(
89
- prompt_embedding.reshape(1, -1),
90
- topic_embedding.reshape(1, -1)
91
- )[0][0]
92
- topic_scores[topic] = similarity
93
- sorted_topics = sorted(topic_scores.items(), key=lambda x: x[1], reverse=True)
94
- primary_topic = sorted_topics[0][0] if sorted_topics else "general"
95
- threshold = sorted_topics[0][1] * 0.8 if sorted_topics else 0.0
96
- subtopics = [topic for topic, score in sorted_topics if score >= threshold and topic != primary_topic]
97
- logger.debug(f"Prompt analyzed (first 30 chars): '{prompt[:30]}...' -> Primary: {primary_topic}, Subtopics: {subtopics}")
98
- return primary_topic, subtopics
99
-
 
100
  def get_selected_models(self):
101
  """Return the list of selected models, always with model_Custm as primary."""
102
- # Always prioritize model_Custm for all specializations
103
  return ["model_Custm.py", "model_PrTr.py"]
104
 
105
  def choose_model(self, prompt=None):
106
- """Choose model_Custm regardless of prompt content."""
 
 
 
 
107
  try:
108
- # Ensure model_Custm is imported and registered
109
- import importlib.util
110
- import os
111
-
112
- # Get the directory containing this file
113
- this_dir = os.path.dirname(os.path.abspath(__file__))
114
-
115
- # Load model_Custm
116
- model_path = os.path.join(this_dir, "model_Custm.py")
117
- if os.path.exists(model_path):
118
- spec = importlib.util.spec_from_file_location("model_custm", model_path)
119
- model_module = importlib.util.module_from_spec(spec)
120
- spec.loader.exec_module(model_module)
121
-
122
- # Register in service registry
123
- from service_registry import registry, MODEL, ensure_models_registered
124
- ensure_models_registered() # Make sure it's registered
125
-
126
- # Return the model class
127
- return model_module.Wildnerve_tlm01
128
- else:
129
- self.logger.error(f"model_Custm.py not found at {model_path}")
130
- return None
131
- except Exception as e:
132
- self.logger.error(f"Error in choose_model: {e}")
133
  return None
134
 
135
  # Register the PromptAnalyzer in the service registry to resolve dependencies.
 
32
  - Provides candidate model identifiers or a single best match.
33
  """
34
  def __init__(self):
35
+ self.logger = logging.getLogger(__name__)
36
+
37
+ # Define topic keywords
38
+ self.predefined_topics = {
39
+ "programming": ["code", "function", "class", "algorithm", "programming", "python", "javascript", "java", "c++", "developer", "api"],
40
+ "science": ["science", "physics", "chemistry", "biology", "scientific", "experiment", "hypothesis", "theory"],
41
+ "mathematics": ["math", "equation", "calculus", "algebra", "geometry", "theorem", "mathematical"],
42
+ "history": ["history", "historical", "ancient", "century", "war", "civilization", "empire"]
43
  }
44
+
45
+ # IMPORTANT CHANGE: Don't load AutoModel, directly use model_Custm.Wildnerve_tlm01
46
+ try:
47
+ # Import the Wildnerve model directly - no AutoModel usage
48
+ from model_Custm import Wildnerve_tlm01
49
+ self.model_class = Wildnerve_tlm01
50
+ self.logger.info("Successfully imported Wildnerve_tlm01 from model_Custm")
51
+ except Exception as e:
52
+ self.logger.warning(f"Failed to import Wildnerve_tlm01: {e}")
53
+ self.model_class = None
54
+
55
+ def analyze_prompt(self, prompt):
56
+ """Analyze prompt to determine primary and secondary topics"""
57
+ # Simple keyword-based classification
58
+ prompt_lower = prompt.lower()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  topic_scores = {}
60
+
61
  for topic, keywords in self.predefined_topics.items():
62
+ score = sum(1 for keyword in keywords if keyword in prompt_lower)
63
+ topic_scores[topic] = score
64
+
65
+ # Find the topic with the highest score
66
+ if not topic_scores or max(topic_scores.values()) == 0:
67
+ return "general", []
68
+
69
+ primary_topic = max(topic_scores.items(), key=lambda x: x[1])[0]
70
+
71
+ # Get secondary topics (any with non-zero scores except primary)
72
+ secondary_topics = [t for t, s in topic_scores.items()
73
+ if s > 0 and t != primary_topic]
74
+
75
+ return primary_topic, secondary_topics
76
+
77
  def get_selected_models(self):
78
  """Return the list of selected models, always with model_Custm as primary."""
79
+ # Always use model_Custm.py as the primary model
80
  return ["model_Custm.py", "model_PrTr.py"]
81
 
82
  def choose_model(self, prompt=None):
83
+ """Always choose model_Custm regardless of prompt content"""
84
+ if self.model_class:
85
+ return self.model_class
86
+
87
+ # Try importing again if initial import failed
88
  try:
89
+ from model_Custm import Wildnerve_tlm01
90
+ return Wildnerve_tlm01
91
+ except ImportError as e:
92
+ self.logger.error(f"Failed to import Wildnerve_tlm01: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  return None
94
 
95
  # Register the PromptAnalyzer in the service registry to resolve dependencies.
transformer_patches.py CHANGED
@@ -260,3 +260,47 @@ def apply_patch_to_layer(layer):
260
  return out
261
 
262
  layer.forward = forward_with_debug
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  return out
261
 
262
  layer.forward = forward_with_debug
263
+
264
+ """
265
+ Patches for the transformers library to ensure compatibility
266
+ """
267
+ import logging
268
+ from types import FunctionType
269
+
270
+ logger = logging.getLogger(__name__)
271
+
272
+ def apply_transformers_patches():
273
+ """Apply patches to transformers library"""
274
+ try:
275
+ import torch
276
+ import transformers
277
+
278
+ # Only apply safe patches that don't interfere with GPU usage
279
+ # Don't replace torch.device with a CPU-only version!
280
+
281
+ # Fix AutoModel.from_pretrained to handle device mapping safely
282
+ if hasattr(transformers, 'AutoModel'):
283
+ original_from_pretrained = transformers.AutoModel.from_pretrained
284
+
285
+ def safe_from_pretrained(*args, **kwargs):
286
+ # Keep any device_map parameter but handle it safely
287
+ if 'device_map' in kwargs and not isinstance(kwargs['device_map'], (str, dict)):
288
+ logger.info("Fixing invalid device_map parameter")
289
+ kwargs['device_map'] = "auto" if torch.cuda.is_available() else None
290
+
291
+ # Use cuda for faster performance if available
292
+ if 'torch_dtype' not in kwargs:
293
+ kwargs['torch_dtype'] = torch.float16 if torch.cuda.is_available() else torch.float32
294
+
295
+ return original_from_pretrained(*args, **kwargs)
296
+
297
+ transformers.AutoModel.from_pretrained = safe_from_pretrained
298
+ logger.info("Applied patch to AutoModel.from_pretrained that preserves GPU usage")
299
+
300
+ return True
301
+ except Exception as e:
302
+ logger.error(f"Failed to apply transformers patches: {e}")
303
+ return False
304
+
305
+ # Apply patches when module is imported
306
+ apply_transformers_patches()