File size: 3,866 Bytes
92ede4e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
# hybrid_module.py
import torch
import pickle
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from huggingface_hub import hf_hub_download
# ---------- Load Bigram ----------
def load_bigram(repo_id="bayan10/AutoComplete", filename="bigram_model_v4.pkl"):
    """Fetch the pickled bigram model from the Hugging Face Hub and unpack it.

    Args:
        repo_id: Hub repository holding the model file.
        filename: Name of the pickle file inside ``repo_id``.

    Returns:
        A ``(unigrams, bigrams)`` tuple, read from the ``"unigrams"`` and
        ``"bigrams"`` keys of the unpickled object.

    Security:
        ``pickle.load`` can execute arbitrary code embedded in the file. Only
        point this at repositories you trust.
    """
    local_path = hf_hub_download(repo_id=repo_id, filename=filename)
    with open(local_path, "rb") as handle:
        model_data = pickle.load(handle)  # NOTE(review): untrusted-pickle risk, see docstring
    unigram_counts = model_data["unigrams"]
    bigram_counts = model_data["bigrams"]
    return unigram_counts, bigram_counts
# ---------- Load GPT-2 ----------
def load_gpt2(model_name="aubmindlab/aragpt2-base"):
    """Load a GPT-2 tokenizer/model pair and prepare it for inference.

    Args:
        model_name: Hugging Face model id or local path passed to
            ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple. The tokenizer's pad token is set to its
        EOS token, and the model is switched to eval mode.
    """
    tok = GPT2Tokenizer.from_pretrained(model_name)
    lm = GPT2LMHeadModel.from_pretrained(model_name)

    # GPT-2 tokenizers typically ship without a pad token; reuse EOS for padding
    # and mirror that id in the model config.
    tok.pad_token = tok.eos_token
    lm.config.pad_token_id = tok.eos_token_id

    lm.eval()  # inference mode (disables dropout)
    return tok, lm
# ---------- GPT-2 scoring ----------
def gpt2_next_token_probs(prefix, tokenizer, model, top_k=50, max_length=1024):
    """Return the most likely next tokens after ``prefix`` with their probabilities.

    Args:
        prefix: Text to condition on.
        tokenizer: Tokenizer called as ``tokenizer(text, return_tensors="pt")``
            and providing ``decode(list_of_ids)``.
        model: Causal LM whose output exposes ``.logits``. The code indexes it as
            ``[0, -1]``, so it presumably has shape (batch, seq, vocab).
        top_k: Number of highest-probability tokens to consider. It is clamped
            to the vocabulary size.
        max_length: Maximum number of context tokens fed to the model. For a
            longer prefix, the most recent ``max_length`` tokens are kept.

    Returns:
        A dict mapping each decoded, whitespace-stripped token to its
        probability. Tokens that decode to whitespace only are dropped. When
        several token ids decode to the same text, the highest probability is
        kept. An empty dict is returned for an empty prefix.
    """
    encoded = tokenizer(prefix, return_tensors="pt")

    # Keep the tail of the context. The tokens nearest the cursor matter most
    # for next-token prediction, and tokenizer-side truncation cuts from the
    # right by default, which would discard exactly those tokens.
    seq_len = encoded["input_ids"].shape[-1]
    start = max(seq_len - max_length, 0)
    inputs = {name: tensor[:, start:] for name, tensor in encoded.items()}
    if inputs["input_ids"].shape[-1] == 0:
        return {}

    device = getattr(model, "device", None)
    if device is not None:
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits[0, -1]
    probs = torch.softmax(logits, dim=-1)

    k = min(top_k, probs.shape[-1])
    top_probs, top_ids = torch.topk(probs, k)

    prob_dict = {}
    for token_id, prob in zip(top_ids.tolist(), top_probs.tolist()):
        word = tokenizer.decode([token_id]).strip()
        # topk is sorted in descending order, so the first hit for a given text
        # is its highest probability. Later duplicates (e.g. "w" vs " w") must
        # not overwrite it.
        if word and word not in prob_dict:
            prob_dict[word] = prob
    return prob_dict
# ---------- Statistical autocomplete ----------
def statistical_autocomplete(text, unigrams, bigrams, top_k=20):
    """Suggest next words from bigram counts, falling back to unigram counts.

    Args:
        text: Input text. Only its last whitespace-separated word is used as
            context.
        unigrams: Mapping ``word -> count``, used when the bigram lookup yields
            no candidates.
        bigrams: Mapping ``word -> {next_word: count}``.
        top_k: Maximum number of suggestions to return.

    Returns:
        Up to ``top_k`` ``(word, probability)`` pairs, best first. Probabilities
        are counts normalized over the retained candidates, then combined
        across spelling variants by ``merge_similar_predictions``. The list is
        empty if ``text`` contains no words.
    """
    context = text.strip().split()
    if not context:
        return []
    anchor = context[-1]

    # Bigram continuations, skipping very short words (fewer than 3 chars) and
    # immediate repetitions of the anchor word.
    pool = []
    if anchor in bigrams:
        pool = [
            (word, count)
            for word, count in bigrams[anchor].items()
            if len(word) >= 3 and word != anchor
        ]

    # Unknown anchor, or every continuation was filtered out: use unigrams.
    if not pool:
        pool = [(word, count) for word, count in unigrams.items() if len(word) >= 3]

    norm = sum(count for _, count in pool)
    ranked = sorted(
        ((word, count / norm) for word, count in pool),
        key=lambda pair: pair[1],
        reverse=True,
    )

    # Sorting by probability first means each merged group's first-seen
    # variant is also its most probable one.
    merged = merge_similar_predictions(ranked, top_k=top_k)
    return merged[:top_k]
# ---------- Hybrid autocomplete ----------
def hybrid_autocomplete(prefix, unigrams, bigrams, tokenizer, model, alpha=0.6, k=5,
                        stat_top_k=20, neural_top_k=50):
    """Rank next-word suggestions by blending bigram and GPT-2 probabilities.

    Each candidate produced by ``statistical_autocomplete`` is scored as
    ``alpha * p_stat + (1 - alpha) * p_gpt2``. ``p_gpt2`` is looked up by exact
    text among GPT-2's ``neural_top_k`` next tokens, with a 1e-8 floor when
    absent.

    Args:
        prefix: Text typed so far.
        unigrams: Unigram counts forwarded to ``statistical_autocomplete``.
        bigrams: Bigram counts forwarded to ``statistical_autocomplete``.
        tokenizer: GPT-2 tokenizer.
        model: GPT-2 language model.
        alpha: Weight given to the statistical probability.
        k: Maximum number of suggestions returned.
        stat_top_k: Size of the statistical candidate pool (previously fixed
            at 20).
        neural_top_k: Number of GPT-2 next tokens considered (previously fixed
            at 50).

    Returns:
        Up to ``k`` ``(word, score)`` pairs, highest score first. The list is
        empty if ``prefix`` has no words, if its last word has no bigram entry,
        or if no statistical candidates exist.
    """
    words = prefix.strip().split()
    if not words:
        return []
    # Only suggest after words with bigram statistics. For unknown last words
    # the unigram fallback inside statistical_autocomplete is never reached.
    if words[-1] not in bigrams:
        return []

    stat_candidates = statistical_autocomplete(
        prefix, unigrams, bigrams, top_k=stat_top_k
    )
    if not stat_candidates:
        # Nothing to rescore; skip the comparatively expensive GPT-2 pass.
        return []

    # A single GPT-2 forward pass serves every candidate.
    gpt2_probs = gpt2_next_token_probs(prefix, tokenizer, model, top_k=neural_top_k)

    results = [
        (word, alpha * stat_p + (1 - alpha) * gpt2_probs.get(word, 1e-8))
        for word, stat_p in stat_candidates
    ]
    results.sort(key=lambda item: item[1], reverse=True)
    return results[:k]
import re
from collections import defaultdict
# Character-level table for canonical_form. It folds common Arabic spelling
# variants together and deletes diacritics and tatweel in a single pass.
_CANONICAL_TABLE = str.maketrans({
    "\u0623": "\u0627",  # ALEF WITH HAMZA ABOVE -> ALEF
    "\u0625": "\u0627",  # ALEF WITH HAMZA BELOW -> ALEF
    "\u0622": "\u0627",  # ALEF WITH MADDA ABOVE -> ALEF
    "\u0649": "\u064A",  # ALEF MAKSURA -> YEH
    "\u0624": "\u0648",  # WAW WITH HAMZA ABOVE -> WAW
    "\u0626": "\u064A",  # YEH WITH HAMZA ABOVE -> YEH
    "\u0629": "\u0647",  # TEH MARBUTA -> HEH
    "\u0640": None,      # TATWEEL (decorative elongation)
    "\u0670": None,      # SUPERSCRIPT ALEF
    **{chr(cp): None for cp in range(0x064B, 0x0653)},  # FATHATAN .. SUKUN
})


def canonical_form(word):
    """Return a normalized spelling of ``word`` for grouping similar predictions.

    The normalization:
      * maps hamza/madda alef forms to bare alef;
      * maps alef maksura and yeh-with-hamza to yeh;
      * maps waw-with-hamza to waw;
      * maps teh marbuta to heh;
      * removes short-vowel diacritics (U+064B..U+0652), superscript alef, and
        tatweel.

    All other characters pass through unchanged.
    """
    return word.translate(_CANONICAL_TABLE)
def merge_similar_predictions(preds, top_k=20, key=None):
    """Collapse spelling variants of the same word into one scored prediction.

    Args:
        preds: Iterable of ``(word, probability)`` pairs.
        top_k: Maximum number of merged predictions to return. ``None`` returns
            all of them.
        key: Callable mapping a word to its grouping key. Defaults to
            ``canonical_form``.

    Returns:
        A list of ``(word, score)`` pairs sorted by score, highest first.
        ``score`` is the summed probability of every variant in the group.
        ``word`` is the group's highest-probability variant; ties go to the
        variant seen first. Groups with equal scores keep first-seen order.
    """
    normalize = canonical_form if key is None else key

    groups = defaultdict(lambda: {"score": 0.0, "word": None, "best": float("-inf")})
    for word, prob in preds:
        group = groups[normalize(word)]
        group["score"] += prob
        # Pick the most probable spelling as the representative, so the result
        # does not depend on the caller having pre-sorted ``preds``.
        if prob > group["best"]:
            group["best"] = prob
            group["word"] = word

    merged = sorted(groups.values(), key=lambda g: g["score"], reverse=True)
    return [(group["word"], group["score"]) for group in merged[:top_k]]
|