Nour611 committed on
Commit
92ede4e
·
verified ·
1 Parent(s): 0dd54c3

Upload hybrid_module.py

Browse files
Files changed (1) hide show
  1. hybrid_module.py +147 -0
hybrid_module.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # hybrid_module.py
3
+
4
+ import torch
5
+ import pickle
6
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
7
+ from huggingface_hub import hf_hub_download
8
+
9
+ # ---------- Load Bigram ----------
10
+
11
def load_bigram(repo_id="bayan10/AutoComplete", filename="bigram_model_v4.pkl"):
    """Download the pickled n-gram data from the Hugging Face Hub and load it.

    Args:
        repo_id: Hub repository that hosts the pickle file.
        filename: Name of the pickle file inside that repository.

    Returns:
        A ``(unigrams, bigrams)`` tuple read from the ``"unigrams"`` and
        ``"bigrams"`` entries of the unpickled object.

    Raises:
        KeyError: If the unpickled object is a dict without one of those
            two keys.
    """
    path = hf_hub_download(repo_id=repo_id, filename=filename)

    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only point ``repo_id`` at repositories whose contents you trust.
    with open(path, "rb") as f:
        data = pickle.load(f)

    return data["unigrams"], data["bigrams"]
16
+
17
+ # ---------- Load GPT-2 ----------
18
def load_gpt2(model_name="aubmindlab/aragpt2-base"):
    """Load a GPT-2 tokenizer and language-model head for inference.

    Args:
        model_name: Identifier passed to ``from_pretrained`` for both the
            tokenizer and the model.

    Returns:
        A ``(tokenizer, model)`` tuple. The end-of-sequence token is reused
        as the padding token on both objects, and ``model.eval()`` has been
        called before returning.
    """
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    gpt2_model = GPT2LMHeadModel.from_pretrained(model_name)

    # Reuse EOS as the padding token, on the tokenizer and in the model config.
    eos_token, eos_token_id = gpt2_tokenizer.eos_token, gpt2_tokenizer.eos_token_id
    gpt2_tokenizer.pad_token = eos_token
    gpt2_model.config.pad_token_id = eos_token_id

    gpt2_model.eval()
    return gpt2_tokenizer, gpt2_model
25
+
26
+ # ---------- GPT-2 scoring ----------
27
def gpt2_next_token_probs(prefix, tokenizer, model, top_k=50):
    """Return the model's most probable next tokens after ``prefix``.

    Args:
        prefix: Text to condition on.
        tokenizer: Callable tokenizer that returns a mapping containing
            ``"input_ids"`` and offers ``decode(list_of_ids)``.
        model: Callable model whose output exposes ``logits``; the logits of
            the first sequence at its last position are used.
        top_k: Maximum number of token ids to consider. Values larger than
            the vocabulary size are clamped instead of raising.

    Returns:
        A dict mapping each decoded, whitespace-stripped token text to its
        softmax probability (a Python float). Tokens that decode to an empty
        string are dropped. When several token ids decode to the same text,
        the highest probability is kept. An empty dict is returned when the
        prefix produces no tokens.
    """
    inputs = tokenizer(
        prefix,
        return_tensors="pt",
        truncation=True,
        max_length=1024,
    )

    # With no tokens (e.g. an empty prefix) there is no last position to
    # read logits from, so there is nothing to predict.
    if inputs["input_ids"].numel() == 0:
        return {}

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits[0, -1]

    probs = torch.softmax(logits, dim=-1)

    # torch.topk raises if k is larger than the dimension size, so clamp it.
    k = max(0, min(top_k, probs.shape[-1]))
    top_probs, top_ids = torch.topk(probs, k)

    prob_dict = {}
    for token_id, prob in zip(top_ids.tolist(), top_probs.tolist()):
        word = tokenizer.decode([token_id]).strip()
        # topk yields probabilities in descending order, so the first entry
        # seen for a word is its highest; later duplicates must not
        # overwrite it with a lower value.
        if word and word not in prob_dict:
            prob_dict[word] = prob

    return prob_dict
49
+
50
+ # ---------- Statistical autocomplete ----------
51
+
52
+
53
def statistical_autocomplete(text, unigrams, bigrams, top_k=20):
    """Suggest next words for ``text`` from bigram counts, else unigram counts.

    Args:
        text: Input text; only its last whitespace-separated word is used.
        unigrams: Mapping of word -> count, used as the fallback source.
        bigrams: Mapping of word -> {following word -> count}.
        top_k: Maximum number of predictions to return.

    Returns:
        A list of ``(word, probability)`` pairs, where each probability is
        the count divided by the total count of the retained candidates.
        Spelling variants are merged by ``merge_similar_predictions``.
        Returns an empty list when ``text`` has no words.
    """
    words = text.strip().split()
    if not words:
        return []

    previous = words[-1]

    # Followers of the last word, skipping words shorter than three
    # characters and immediate repetitions of the last word itself.
    scored = []
    if previous in bigrams:
        scored = [
            (word, count)
            for word, count in bigrams[previous].items()
            if len(word) >= 3 and word != previous
        ]

    # No usable follower: fall back to the whole unigram table.
    if not scored:
        scored = [
            (word, count)
            for word, count in unigrams.items()
            if len(word) >= 3
        ]

    denominator = sum(count for _, count in scored)
    ranked = sorted(
        [(word, count / denominator) for word, count in scored],
        key=lambda pair: pair[1],
        reverse=True,
    )
    return merge_similar_predictions(ranked, top_k=top_k)[:top_k]
78
+
79
+ # ---------- Hybrid autocomplete ----------
80
def hybrid_autocomplete(
    prefix,
    unigrams,
    bigrams,
    tokenizer,
    model,
    alpha=0.6,
    k=5,
    stat_top_k=20,
    neural_top_k=50,
):
    """Rank next-word suggestions by mixing n-gram and GPT-2 probabilities.

    Args:
        prefix: Text typed so far.
        unigrams: Mapping of word -> count.
        bigrams: Mapping of word -> {following word -> count}.
        tokenizer: Tokenizer forwarded to ``gpt2_next_token_probs``.
        model: Model forwarded to ``gpt2_next_token_probs``.
        alpha: Weight of the statistical probability; the neural probability
            is weighted by ``1 - alpha``.
        k: Number of suggestions to return.
        stat_top_k: Number of statistical candidates to score (default 20).
        neural_top_k: Number of GPT-2 tokens to look up (default 50).

    Returns:
        Up to ``k`` ``(word, score)`` pairs sorted by descending score.
        Returns an empty list when ``prefix`` has no words, when its last
        word is not a key of ``bigrams``, or when there are no statistical
        candidates.
    """
    words = prefix.strip().split()
    if not words:
        return []

    last_word = words[-1]
    if last_word not in bigrams:
        return []

    # -------- Statistical (Bigram) --------
    stat_candidates = statistical_autocomplete(
        prefix,
        unigrams,
        bigrams,
        top_k=stat_top_k,
    )
    # Only statistical candidates are ever scored, so with none of them the
    # model forward pass below would be wasted work.
    if not stat_candidates:
        return []

    # -------- Neural (GPT-2) — ONCE --------
    gpt2_probs = gpt2_next_token_probs(
        prefix, tokenizer, model, top_k=neural_top_k
    )

    # -------- Hybrid scoring --------
    results = [
        # 1e-8 is a small stand-in for words absent from the GPT-2 top tokens.
        (w, alpha * stat_p + (1 - alpha) * gpt2_probs.get(w, 1e-8))
        for w, stat_p in stat_candidates
    ]

    return sorted(results, key=lambda x: x[1], reverse=True)[:k]
108
+
109
+ import re
110
+ from collections import defaultdict
111
+
112
+ def canonical_form(word):
113
+ word = re.sub("[إأآا]", "ا", word)
114
+ word = re.sub("ى", "ي", word)
115
+ word = re.sub("ؤ", "و", word)
116
+ word = re.sub("ئ", "ي", word)
117
+ word = re.sub("ة", "ه", word)
118
+ word = re.sub(r"[ًٌٍَُِّْ]", "", word)
119
+ return word
120
+
121
+
122
+
123
+ def merge_similar_predictions(preds, top_k=20):
124
+ groups = defaultdict(lambda: {"score": 0.0, "words": []})
125
+
126
+ for w, p in preds:
127
+ key = canonical_form(w)
128
+ groups[key]["score"] += p
129
+ groups[key]["words"].append(w)
130
+
131
+ merged = sorted(
132
+ groups.values(),
133
+ key=lambda x: x["score"],
134
+ reverse=True
135
+ )
136
+
137
+ return [
138
+ (group["words"][0], group["score"])
139
+ for group in merged[:top_k]
140
+ ]
141
+
142
+
143
+
144
+
145
+
146
+
147
+