# mvi-ai-engine / core / humanizer_encoder.py
# Provenance (from the Hugging Face Hub page this was copied from):
# uploaded by Musombi — commit "Update core/humanizer_encoder.py", revision 9d695a9 (verified)
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
import os
class HumanizerEncoder:
    """T5-based text rewriter ("humanizer").

    Loads a fine-tuned t5-small checkpoint from disk and rewrites input text
    with beam search.  Every failure path degrades gracefully: if the
    checkpoint is missing or fails to load, :meth:`generate` returns the
    input text unchanged instead of raising.
    """

    def __init__(self, model_path: str, device) -> None:
        """Load the tokenizer and model weights.

        Args:
            model_path: Path to a ``torch.save``-d checkpoint — either a
                state_dict (dict) or a whole pickled model object.
            device: torch device (or device string, e.g. ``"cpu"``) to place
                the model on.
        """
        self.device = device
        self.model = None
        self.tokenizer = None

        if not os.path.exists(model_path):
            print("[HUMANIZER] No pretrained weights found")
            return

        # ===== Initialize T5 tokenizer & model =====
        try:
            self.tokenizer = T5Tokenizer.from_pretrained("t5-small")
            # SECURITY: torch.load unpickles arbitrary objects — only load
            # checkpoints from trusted sources (pickle can execute code).
            checkpoint = torch.load(model_path, map_location=device)
            if isinstance(checkpoint, dict):
                # Checkpoint is a state_dict: rebuild the t5-small
                # architecture, then load the fine-tuned weights.
                # strict=False tolerates missing/unexpected keys.
                self.model = T5ForConditionalGeneration.from_pretrained("t5-small")
                self.model.load_state_dict(checkpoint, strict=False)
            else:
                self.model = checkpoint  # full model saved
            self.model = self.model.to(device)
            self.model.eval()  # inference mode: disables dropout etc.
            print("[HUMANIZER] Humanizer loaded →", model_path)
        except Exception as e:
            # Deliberate best-effort boundary: a broken checkpoint must not
            # crash the engine; generate() falls back to identity instead.
            print("[HUMANIZER] Failed to load model:", e)
            self.model = None

    # ===== Real text rewriting =====
    def generate(self, text: str, max_length: int = 64) -> str:
        """Rewrite *text*; return it unchanged when no model is available.

        Args:
            text: Input text to rewrite.
            max_length: Token budget used both to truncate the input and to
                cap the generated output length.

        Returns:
            The beam-search rewrite, or the original *text* on any
            fallback path (model not loaded, or empty input).
        """
        if self.model is None or self.tokenizer is None:
            return text  # fallback: just return original text
        if not text.strip():
            return text  # robustness: nothing to rewrite for empty input

        # Tokenize input (truncated to max_length tokens, batch of 1)
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            max_length=max_length,
            truncation=True,
        ).to(self.device)

        # Generate output without tracking gradients (inference only)
        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=4,
                early_stopping=True,
            )

        # Decode generated text (first — and only — sequence in the batch)
        return self.tokenizer.decode(output_ids[0], skip_special_tokens=True)