# Scraped from a Hugging Face Spaces file view (status: Running; file size: 1,978 bytes;
# commits: 5c47583, 9d695a9, 70511f8, 01678a2). The line-number gutter was dropped.
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
import os
class HumanizerEncoder:
    """Rewrite text with a T5 model, falling back to pass-through.

    The tokenizer always comes from the ``"t5-small"`` pretrained files; the
    model weights come from ``model_path``. If the path does not exist, or
    loading fails for any reason, ``model`` and ``tokenizer`` stay ``None``
    and :meth:`generate` returns its input unchanged.
    """

    # Keys under which a checkpoint dict may nest the real state_dict.
    _WRAPPER_KEYS = ("state_dict", "model_state_dict")

    def __init__(self, model_path, device):
        """Load the tokenizer and model weights.

        Args:
            model_path: Filesystem path to a file readable by ``torch.load``.
                It may hold a state_dict, a dict wrapping a state_dict under
                one of ``_WRAPPER_KEYS``, or a full pickled model object.
            device: Passed to ``torch.load(map_location=...)`` and to
                ``model.to(...)``; also used for the tokenized inputs.
        """
        self.device = device
        self.model = None
        self.tokenizer = None

        if not os.path.exists(model_path):
            print("[HUMANIZER] No pretrained weights found")
            return

        # ===== Initialize T5 tokenizer & model =====
        # Deliberately best-effort: any failure leaves the instance in
        # pass-through mode instead of raising to the caller.
        try:
            self.tokenizer = T5Tokenizer.from_pretrained("t5-small")

            # NOTE(review): torch.load can unpickle arbitrary objects
            # (depending on the installed torch version's `weights_only`
            # default). Only point this at checkpoints from a trusted source.
            checkpoint = torch.load(model_path, map_location=device)

            if isinstance(checkpoint, dict):
                self.model = T5ForConditionalGeneration.from_pretrained("t5-small")
                state_dict = self._unwrap_state_dict(checkpoint)
                # strict=False tolerates key mismatches; report them so a
                # checkpoint that matched nothing is not mistaken for success.
                result = self.model.load_state_dict(state_dict, strict=False)
                if result.missing_keys or result.unexpected_keys:
                    print(
                        "[HUMANIZER] Warning: checkpoint keys ignored:",
                        len(result.missing_keys), "missing,",
                        len(result.unexpected_keys), "unexpected",
                    )
            else:
                self.model = checkpoint  # full model saved

            self.model = self.model.to(device)
            self.model.eval()
            print("[HUMANIZER] Humanizer loaded →", model_path)
        except Exception as e:
            print("[HUMANIZER] Failed to load model:", e)
            # Reset both so the instance is consistently in fallback mode.
            self.model = None
            self.tokenizer = None

    @staticmethod
    def _unwrap_state_dict(checkpoint):
        """Return the nested state_dict if ``checkpoint`` wraps one, else ``checkpoint``."""
        for key in HumanizerEncoder._WRAPPER_KEYS:
            nested = checkpoint.get(key)
            if isinstance(nested, dict):
                return nested
        return checkpoint

    # ===== Real text rewriting =====
    def generate(self, text: str, max_length=64):
        """Return a rewritten version of ``text``.

        Args:
            text: Input string to rewrite.
            max_length: Used both as the truncation limit for the tokenized
                input and as the ``max_length`` passed to ``model.generate``.

        Returns:
            The decoded first beam-search output, or ``text`` unchanged when
            no model/tokenizer is loaded or ``text`` is empty/whitespace.
        """
        if self.model is None or self.tokenizer is None:
            return text  # fallback: just return original text

        # Nothing to rewrite; avoid generating output for blank input.
        if not text or not text.strip():
            return text

        # Tokenize input
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            max_length=max_length,
            truncation=True,
        ).to(self.device)

        # Generate output
        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=4,
                early_stopping=True,
            )

        # Decode generated text
        return self.tokenizer.decode(output_ids[0], skip_special_tokens=True)