santh-cpu commited on
Commit
ce50f5c
·
verified ·
1 Parent(s): 22825f5

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +199 -0
model.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM, T5EncoderModel, RobertaTokenizer
6
+ from huggingface_hub import hf_hub_download
7
+
8
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
9
+ MAX_LEN = 256
10
+ THRESHOLD = 0.475
11
+ REPO_ID = "santh-cpu/ai_code_detect"
12
+
13
+ class PolyglotMetricEngine(nn.Module):
14
+ def __init__(self, base_model):
15
+ super().__init__()
16
+ self.model = base_model
17
+
18
+ @torch.no_grad()
19
+ def forward(self, input_ids, attention_mask):
20
+ B, L = input_ids.shape
21
+ with torch.autocast(device_type='cuda', dtype=torch.float16):
22
+ logits_raw = self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False).logits
23
+ logits_raw = logits_raw.float()
24
+ shift_logits = logits_raw[:, :-1, :].contiguous()
25
+ shift_labels = input_ids[:, 1:].contiguous()
26
+ shift_mask = attention_mask[:, 1:].float()
27
+ log_probs_all = F.log_softmax(shift_logits, dim=-1)
28
+ probs_all = log_probs_all.exp()
29
+ log_prob = log_probs_all.gather(2, shift_labels.unsqueeze(-1)).squeeze(-1)
30
+ K_max = min(1001, shift_logits.size(-1))
31
+ topk_vals, topk_idx = torch.topk(shift_logits, K_max, dim=-1)
32
+ rank_approx = (log_probs_all.gather(2, topk_idx) > log_prob.unsqueeze(-1)).sum(-1).float() + 1.0
33
+ true_rank_log = torch.log1p(rank_approx)
34
+ top10_mass = log_probs_all.gather(2, topk_idx[:, :, :10]).exp().sum(dim=-1)
35
+ lp_topk = log_probs_all.gather(2, topk_idx[:, :, :2])
36
+ gap_1_2 = (lp_topk[:, :, 0] - lp_topk[:, :, 1]).clamp(-20, 20)
37
+ entropy = -(probs_all * log_probs_all).sum(dim=-1)
38
+ varentropy = (probs_all * (-log_probs_all - entropy.unsqueeze(-1))**2).sum(dim=-1)
39
+ r10_flag = (true_rank_log <= math.log1p(10)).float()
40
+ r100_flag = ((true_rank_log > math.log1p(10)) & (true_rank_log <= math.log1p(100))).float()
41
+ r1k_flag = ((true_rank_log > math.log1p(100)) & (true_rank_log <= math.log1p(1000))).float()
42
+ rtail_flag = (true_rank_log > math.log1p(1000)).float()
43
+ valid_n = shift_mask.sum(dim=1, keepdim=True).clamp(min=1)
44
+ lp_mean = (log_prob * shift_mask).sum(1, keepdim=True) / valid_n
45
+ lp_var = ((log_prob - lp_mean)**2 * shift_mask).sum(1, keepdim=True) / valid_n
46
+ lp_std = lp_var.sqrt().clamp(min=1e-4)
47
+ surprisal_z = ((log_prob - lp_mean) / lp_std) * shift_mask
48
+ entropy_shift = F.pad(entropy[:, :-1], (1, 0), value=0.)
49
+ entropy_delta = (entropy - entropy_shift) * shift_mask
50
+ cum_positions = torch.arange(1, L, device=input_ids.device).unsqueeze(0).float()
51
+ cum_rank = (true_rank_log * shift_mask).cumsum(dim=1) / cum_positions
52
+ is_special = torch.zeros_like(shift_mask)
53
+ m = shift_mask
54
+ token_feats_12 = torch.stack([
55
+ log_prob*m, true_rank_log*m, entropy*m, varentropy*m,
56
+ top10_mass*m, gap_1_2*m, surprisal_z, entropy_delta,
57
+ cum_rank*m, is_special*m, r10_flag*m, r100_flag*m,
58
+ ], dim=-1)
59
+ out_token = torch.zeros(B, MAX_LEN, 12, device=input_ids.device)
60
+ out_token[:, :L-1, :] = token_feats_12
61
+ seq_feats = self._compute_seq_feats(
62
+ log_prob, entropy, varentropy, top10_mass,
63
+ gap_1_2, surprisal_z, r10_flag, r100_flag,
64
+ r1k_flag, rtail_flag, shift_mask, valid_n
65
+ )
66
+ return out_token.detach(), seq_feats.detach()
67
+
68
+ def _compute_seq_feats(self, log_prob, entropy, varentropy, top10_mass,
69
+ gap_1_2, surprisal_z, r10_flag, r100_flag,
70
+ r1k_flag, rtail_flag, mask, valid_n):
71
+ feats = []
72
+ def masked_moments(x):
73
+ n = valid_n.squeeze(1)
74
+ mu = (x * mask).sum(1) / n
75
+ dev = (x - mu.unsqueeze(1)) * mask
76
+ var = (dev**2).sum(1) / n
77
+ std = var.sqrt().clamp(min=1e-6)
78
+ skew = (dev**3).sum(1) / (n * std**3 + 1e-8)
79
+ kurt = (dev**4).sum(1) / (n * var**2 + 1e-8)
80
+ return mu, std, skew.clamp(-10, 10), kurt.clamp(0, 50)
81
+ for feat in [log_prob, entropy, varentropy, top10_mass]:
82
+ feats += list(masked_moments(feat))
83
+ for lag in [1, 5]:
84
+ e_shift = F.pad(entropy[:, lag:], (0, lag)) * mask
85
+ e_norm = entropy - (entropy*mask).sum(1, keepdim=True)/valid_n
86
+ e_shift_norm = e_shift - (e_shift*mask).sum(1, keepdim=True)/valid_n
87
+ num = (e_norm * e_shift_norm * mask).sum(1)
88
+ denom = (((e_norm**2*mask).sum(1)+1e-8).sqrt() * ((e_shift_norm**2*mask).sum(1)+1e-8).sqrt())
89
+ feats.append((num/denom).clamp(-1, 1))
90
+ n = valid_n.squeeze(1)
91
+ for flag in [r10_flag, r100_flag, r1k_flag, rtail_flag]:
92
+ feats.append((flag*mask).sum(1)/n)
93
+ feats += list(masked_moments(gap_1_2)[:2])
94
+ feats += list(masked_moments(surprisal_z)[:2])
95
+ pos = torch.arange(entropy.shape[1], device=entropy.device).float().unsqueeze(0)
96
+ pos_mu = (pos*mask).sum(1)/n
97
+ ent_mu = (entropy*mask).sum(1)/n
98
+ cov = ((pos - pos_mu.unsqueeze(1)) * (entropy - ent_mu.unsqueeze(1)) * mask).sum(1)
99
+ var_pos = ((pos - pos_mu.unsqueeze(1))**2 * mask).sum(1)
100
+ feats.append((cov/(var_pos+1e-8)).clamp(-5, 5))
101
+ abs_surp = surprisal_z.abs()
102
+ mu_b, std_b, _, _ = masked_moments(abs_surp)
103
+ feats.append((std_b/(mu_b+1e-6)).clamp(0, 20))
104
+ feats.append((top10_mass*mask).sum(1)/n)
105
+ feats.append((((top10_mass-(top10_mass*mask).sum(1,keepdim=True)/valid_n)**2*mask).sum(1)/n).sqrt())
106
+ ent_median = entropy.median(dim=1).values.unsqueeze(1)
107
+ feats.append(((entropy < ent_median)*mask).sum(1)/n)
108
+ lp_std = masked_moments(log_prob)[1]
109
+ ent_mu2 = (entropy*mask).sum(1)/n
110
+ feats.append((lp_std/(ent_mu2+1e-4)).clamp(0, 20))
111
+ return torch.nan_to_num(torch.stack(feats, dim=1), nan=0., posinf=20., neginf=-20.)
112
+
113
+ class GatedTemporalMixer(nn.Module):
114
+ def __init__(self, dim, kernel_size=7):
115
+ super().__init__()
116
+ self.conv = nn.Conv1d(dim, dim*2, kernel_size, padding=(kernel_size-1), groups=dim)
117
+ self.norm = nn.LayerNorm(dim)
118
+ def forward(self, x):
119
+ h = self.conv(x.transpose(1,2))[:, :, :x.shape[1]]
120
+ gate, val = h.chunk(2, dim=1)
121
+ return self.norm((torch.sigmoid(gate)*val).transpose(1,2) + x)
122
+
123
+ class PerTokenEncoder(nn.Module):
124
+ def __init__(self):
125
+ super().__init__()
126
+ self.feat_norm = nn.LayerNorm(12)
127
+ self.proj_in = nn.Linear(12, 128)
128
+ self.mixer1 = GatedTemporalMixer(128, 7)
129
+ self.mixer2 = GatedTemporalMixer(128, 15)
130
+ self.ff = nn.Sequential(
131
+ nn.Linear(128, 256), nn.GELU(), nn.Dropout(0.1),
132
+ nn.Linear(256, 128), nn.LayerNorm(128)
133
+ )
134
+ self.attn_q = nn.Linear(128, 1, bias=False)
135
+ self.proj_out = nn.Linear(128, 256)
136
+ def forward(self, x, mask):
137
+ x_proj = F.gelu(self.proj_in(self.feat_norm(x)))
138
+ mixed = self.mixer2(self.mixer1(x_proj))
139
+ hidden = mixed + self.ff(mixed)
140
+ scores = self.attn_q(hidden).squeeze(-1).masked_fill(mask==0, float('-inf'))
141
+ return self.proj_out((hidden * torch.softmax(scores, dim=-1).unsqueeze(-1)).sum(1))
142
+
143
+ class SeqFeatMLP(nn.Module):
144
+ def __init__(self):
145
+ super().__init__()
146
+ self.net = nn.Sequential(
147
+ nn.LayerNorm(32), nn.Linear(32, 128), nn.GELU(),
148
+ nn.Dropout(0.1), nn.Linear(128, 64), nn.LayerNorm(64)
149
+ )
150
+ def forward(self, x): return self.net(x)
151
+
152
+ class PolyglotClassifierV3(nn.Module):
153
+ def __init__(self, base):
154
+ super().__init__()
155
+ self.encoder = base
156
+ self.token_enc = PerTokenEncoder()
157
+ self.seq_mlp = SeqFeatMLP()
158
+ fused = base.config.hidden_size + 256 + 64
159
+ self.classifier = nn.Sequential(
160
+ nn.LayerNorm(fused), nn.Linear(fused, 512),
161
+ nn.GELU(), nn.Dropout(0.2),
162
+ nn.Linear(512, 128), nn.GELU(),
163
+ nn.Dropout(0.1), nn.Linear(128, 1)
164
+ )
165
+ def forward(self, ids, mask, tf, sf):
166
+ hs = self.encoder(input_ids=ids, attention_mask=mask).last_hidden_state
167
+ sem = (hs * mask.unsqueeze(-1)).sum(1) / mask.unsqueeze(-1).sum(1).clamp(min=1e-4)
168
+ return self.classifier(torch.cat([sem, self.token_enc(tf, mask), self.seq_mlp(sf)], dim=-1)).squeeze(-1)
169
+
170
+ gen_tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
171
+ if gen_tokenizer.pad_token is None:
172
+ gen_tokenizer.pad_token = gen_tokenizer.eos_token
173
+ t5_tokenizer = RobertaTokenizer.from_pretrained("Salesforce/codet5-base", extra_special_tokens=None)
174
+
175
+ gen_base = AutoModelForCausalLM.from_pretrained("Salesforce/codegen-350M-mono", torch_dtype=torch.float16).to(DEVICE)
176
+ metric_engine = PolyglotMetricEngine(gen_base).eval()
177
+
178
+ t5_base = T5EncoderModel.from_pretrained("Salesforce/codet5-base")
179
+ detector = PolyglotClassifierV3(t5_base).to(DEVICE)
180
+
181
+ weights_path = hf_hub_download(repo_id=REPO_ID, filename="model_weights.pt")
182
+ detector.load_state_dict(torch.load(weights_path, map_location=DEVICE))
183
+ detector.eval()
184
+
185
+ def predict(code: str, threshold: float = THRESHOLD) -> dict:
186
+ if len(code.strip()) < 200:
187
+ return {"prediction": "Too Short", "ai_probability": None}
188
+
189
+ with torch.no_grad():
190
+ g = gen_tokenizer(code, return_tensors="pt", padding="max_length", truncation=True, max_length=MAX_LEN).to(DEVICE)
191
+ tf, sf = metric_engine(g["input_ids"], g["attention_mask"])
192
+ t = t5_tokenizer(code, return_tensors="pt", padding="max_length", truncation=True, max_length=MAX_LEN).to(DEVICE)
193
+ logits = detector(t["input_ids"], t["attention_mask"], tf.float(), sf.float())
194
+ prob = torch.sigmoid(logits).item()
195
+
196
+ return {
197
+ "prediction": "AI Generated" if prob >= threshold else "Human Written",
198
+ "ai_probability": round(prob * 100, 2)
199
+ }