import sys

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class OVModelManager:
    def __init__(self, model_name, device=None):
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        print(f"⏳ Loading Model: {model_name} on {self.device}...")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                trust_remote_code=True,
            ).to(self.device)
            self.model.eval()
            print("✅ Model Loaded Successfully.")
        except Exception as e:
            print(f"❌ Error Loading Model: {e}")
            sys.exit(1)

        # Hook handles (kept so they can be detached later) and the memory
        # context that attached hooks read at generation time.
        self.hooks = []
        self.memory_context = None
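
    def _ov_steering_hook(self, module, inputs, output):
        """
        Forward hook that blends a memory-derived vector into a block's output.

        A minimal sketch of OV-style steering: the "vector" (a 1-D tensor of
        the model's hidden size) and "alpha" (blend strength) keys of
        memory_context are assumptions of this sketch, not a fixed API —
        adapt them to however your memory store is actually shaped.
        """
        if not self.memory_context:
            return output

        vector = self.memory_context.get("vector")
        if vector is None:
            return output
        alpha = self.memory_context.get("alpha", 0.1)

        # Attention blocks typically return a tuple (hidden_states, ...);
        # SSM/mixer blocks usually return a plain tensor.
        if isinstance(output, tuple):
            hidden = output[0]
            steered = hidden + alpha * vector.to(device=hidden.device, dtype=hidden.dtype)
            return (steered,) + output[1:]
        return output + alpha * vector.to(device=output.device, dtype=output.dtype)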
    def attach_ov_hooks(self):
        """
        Inspects the model architecture and attaches steering hooks to the
        right blocks: attention blocks for transformers, mixer/SSM blocks
        for state-space (Mamba-style) models.
        """
        print("🔧 Inspecting Model Architecture...")
        layers_hooked = 0

        for name, module in self.model.named_modules():
            # Transformer attention blocks. Match the block itself, not its
            # sub-projections (e.g. "...self_attn" but not "...self_attn.q_proj").
            if name.endswith(("self_attn", "attention")):
                self.hooks.append(module.register_forward_hook(self._ov_steering_hook))
                layers_hooked += 1

            # State-space blocks (Mamba-style models name these "mixer"/"ssm").
            elif name.endswith(("mixer", "ssm")):
                self.hooks.append(module.register_forward_hook(self._ov_steering_hook))
                layers_hooked += 1

        print(f"✅ OV-Memory active on {layers_hooked} layers.")
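
    # self.hooks keeps the handles returned by register_forward_hook, so
    # steering can be switched off again; this helper restores vanilla behavior.
    def remove_ov_hooks(self):
        """Detach all steering hooks."""
        for handle in self.hooks:
            handle.remove()
        self.hooks = []
        print("🔌 OV-Memory hooks removed.")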
    def generate(self, prompt, memory_context=None, max_new_tokens=100):
        """
        Generate text with OV-Steering. If a memory context is supplied, its
        best fact is injected into the prompt and the context is exposed to
        any attached steering hooks.
        """
        # Make the context visible to the forward hooks for this call.
        self.memory_context = memory_context

        final_prompt = prompt
        if memory_context:
            print("🧠 Injecting OV-Memory Context...")
            best_fact = memory_context.get("text", "")
            final_prompt = f"Context: {best_fact}\n\nQuestion: {prompt}\nAnswer:"

        # Tokenize once, after the final prompt has been assembled.
        inputs = self.tokenizer(final_prompt, return_tensors="pt").to(self.device)

        # Some tokenizers (e.g. GPT/Llama-style) ship without a pad token.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        outputs = self.model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.2,
            no_repeat_ngram_size=3,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )

        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
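

# Minimal usage sketch. The checkpoint below is just an example of a small
# chat model — swap in whichever causal LM you actually use. The memory dict
# follows the keys assumed above: "text" for prompt injection, optionally
# "vector"/"alpha" for hook-level steering.
if __name__ == "__main__":
    manager = OVModelManager("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    manager.attach_ov_hooks()

    memory = {"text": "The Eiffel Tower is located in Paris, France."}
    print(manager.generate("Where is the Eiffel Tower?", memory_context=memory))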