import sys

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class OVModelManager:
    def __init__(self, model_name, device=None):
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        print(f"⏳ Loading Model: {model_name} on {self.device}...")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                trust_remote_code=True,  # Needed for Mamba/Jamba
            ).to(self.device)
            self.model.eval()
            print("✅ Model Loaded Successfully.")
        except Exception as e:
            print(f"❌ Error Loading Model: {e}")
            sys.exit(1)

        self.hooks = []
        self.memory_context = None

    def attach_ov_hooks(self):
        """
        Detects the model architecture and marks the layers that OV steering
        hooks would attach to. Hook registration itself is stubbed out in this
        Python wrapper; generate() uses context injection as the fallback.
        """
        print("🔧 Inspecting Model Architecture...")
        layers_hooked = 0
        for name, module in self.model.named_modules():
            # 1. Catch Transformers (Attention)
            if "self_attn" in name or "attention" in name:
                # We would register a forward hook here.
                # Note: in PyTorch, modifying an output tuple is tricky,
                # so we often use a pre-hook or modify hidden states.
                # For this user-friendly version, we use a generation-time
                # logit bias or a simplified hidden-state bias if accessible.
                # handle = module.register_forward_hook(self._transformer_hook)
                layers_hooked += 1
            # 2. Catch Mamba (SSM)
            elif "mixer" in name or "ssm" in name:
                # handle = module.register_forward_hook(self._mamba_hook)
                layers_hooked += 1
        print(f"✅ OV-Memory active on {layers_hooked} layers.")

    def generate(self, prompt, memory_context=None, max_new_tokens=100):
        """
        Generate text with OV-Steering.
        """
        # If we have memory, we build a specialized input.
        # In the C++ engine, we do internal steering. In this Python wrapper,
        # we use "Context Injection" + "Logit Bias" as a universal fallback
        # that works on ALL models without complex C++ compilation.
        final_prompt = prompt
        if memory_context:
            print("🧠 Injecting OV-Memory Context...")
            # P = S * C * R * W logic happens here (simulated)
            best_fact = memory_context.get("text", "")
            final_prompt = f"Context: {best_fact}\n\nQuestion: {prompt}\nAnswer:"

        inputs = self.tokenizer(final_prompt, return_tensors="pt").to(self.device)

        # Ensure a pad token is set (fall back to EOS)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        outputs = self.model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,          # Balanced creativity
            top_p=0.9,                # Nucleus sampling
            repetition_penalty=1.2,   # FIX: penalize repeats
            no_repeat_ngram_size=3,   # FIX: prevent phrase looping
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
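
# ---------------------------------------------------------------------------
# Illustrative extras (not part of the core class above).
#
# 1) make_steering_hook(): a minimal sketch of what the commented-out
#    register_forward_hook(...) lines could attach. The bias_vector / alpha
#    names are assumptions made for illustration only; the actual OV steering
#    math (P = S * C * R * W) lives in the C++ engine.
# 2) __main__ block: a usage sketch. "gpt2" is only a placeholder checkpoint;
#    any causal LM accepted by AutoModelForCausalLM should work the same way.
# ---------------------------------------------------------------------------

def make_steering_hook(bias_vector, alpha=0.1):
    """Build a forward hook that nudges a layer's hidden states toward
    `bias_vector`. Returning a value from a PyTorch forward hook replaces
    the module's output, so we rebuild the output tuple when needed."""
    def hook(module, inputs, output):
        if isinstance(output, tuple):
            # Attention/mixer blocks usually return (hidden_states, ...).
            hidden = output[0]
            steered = hidden + alpha * bias_vector.to(hidden.device, hidden.dtype)
            return (steered,) + output[1:]
        return output + alpha * bias_vector.to(output.device, output.dtype)
    return hook


if __name__ == "__main__":
    manager = OVModelManager("gpt2")  # placeholder model name
    manager.attach_ov_hooks()         # currently only counts steerable layers

    # Plain generation, no memory injected
    print(manager.generate("What is the capital of France?", max_new_tokens=30))

    # Generation with a retrieved fact injected as context
    memory = {"text": "The user's favorite city is Lyon."}
    print(manager.generate("Which city should I visit?",
                           memory_context=memory, max_new_tokens=50))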