Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -51,10 +51,11 @@ def _precompute_can_open(seq, min_loop=3, allow_gu=True):
|
|
| 51 |
break
|
| 52 |
return can
|
| 53 |
|
|
|
|
| 54 |
# --- constrained processor ---
|
| 55 |
class BalancedParenProcessor(LogitsProcessor):
|
| 56 |
def __init__(self, lp_id, rp_id, dot_id, total_len, can_open,
|
| 57 |
-
dot_bias=0.
|
| 58 |
self.lp_id, self.rp_id, self.dot_id = lp_id, rp_id, dot_id
|
| 59 |
self.total_len = total_len
|
| 60 |
self.step = 0
|
|
@@ -66,76 +67,77 @@ class BalancedParenProcessor(LogitsProcessor):
|
|
| 66 |
self.window=window
|
| 67 |
|
| 68 |
def __call__(self, input_ids, scores):
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
if must_close:
|
| 75 |
-
allowed=[self.rp_id]
|
| 76 |
else:
|
| 77 |
-
if self.depth>0:
|
| 78 |
allowed.append(self.rp_id)
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
| 80 |
allowed.append(self.lp_id)
|
|
|
|
| 81 |
allowed.append(self.dot_id)
|
| 82 |
-
mask[:,allowed]=0.0
|
| 83 |
-
scores=scores+mask
|
| 84 |
-
scores[:,self.dot_id]+=self.dot_bias
|
| 85 |
-
if len(self.history)>=self.window and all(t in (self.lp_id,self.rp_id) for t in self.history[-self.window:]):
|
| 86 |
-
scores[:,self.lp_id]-=self.paren_penalty
|
| 87 |
-
scores[:,self.rp_id]-=self.paren_penalty
|
| 88 |
-
return scores
|
| 89 |
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
sorted_probs[mask]=0
|
| 106 |
-
sorted_probs/=sorted_probs.sum(dim=-1,keepdim=True)
|
| 107 |
-
idx=torch.multinomial(sorted_probs,1)
|
| 108 |
-
return sorted_idx.gather(-1,idx).squeeze(-1)
|
| 109 |
|
| 110 |
# --- generator ---
|
| 111 |
def _generate_db(seq):
|
| 112 |
-
tok,model,device=_load_model_and_tokenizer()
|
| 113 |
-
n=len(seq)
|
| 114 |
-
prompt=f"RNA: {seq}\nDot-bracket (exactly {n} characters using only '(' ')' '.'):\n"
|
| 115 |
-
lp=_char_token_id(tok,"("); rp=_char_token_id(tok,")"); dot=_char_token_id(tok,".")
|
| 116 |
-
can=_precompute_can_open(seq)
|
| 117 |
-
proc=BalancedParenProcessor(lp,rp,dot,n,can)
|
| 118 |
-
procs=LogitsProcessorList([proc])
|
| 119 |
-
inputs=tok(prompt,return_tensors="pt").to(device)
|
| 120 |
-
cur=inputs["input_ids"]
|
| 121 |
-
generated=[]
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
logits=
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
|
|
|
|
|
|
|
|
|
| 137 |
return db
|
| 138 |
|
|
|
|
| 139 |
# --- structural element translation ---
|
| 140 |
def dotbracket_to_structural(dot_str):
|
| 141 |
if not dot_str: return "<start><external_loop><end>"
|
|
|
|
| 51 |
break
|
| 52 |
return can
|
| 53 |
|
| 54 |
+
# --- constrained processor ---
|
| 55 |
# --- constrained processor ---
|
| 56 |
class BalancedParenProcessor(LogitsProcessor):
|
| 57 |
def __init__(self, lp_id, rp_id, dot_id, total_len, can_open,
|
| 58 |
+
dot_bias=0.0, paren_penalty=0.0, window=5):
|
| 59 |
self.lp_id, self.rp_id, self.dot_id = lp_id, rp_id, dot_id
|
| 60 |
self.total_len = total_len
|
| 61 |
self.step = 0
|
|
|
|
| 67 |
self.window=window
|
| 68 |
|
| 69 |
def __call__(self, input_ids, scores):
    """Constrain next-token logits to '(', ')' or '.' only.

    Applies a -inf mask to every vocabulary entry except the structure
    tokens that are legal at the current position, so the generated
    dot-bracket string is guaranteed to be balanced and exactly
    ``total_len`` characters long.
    """
    pos = self.step
    remaining = self.total_len - pos
    neg_inf_mask = torch.full_like(scores, float("-inf"))

    # When every remaining slot is needed to close open brackets, ')' is forced.
    if self.depth > 0 and remaining == self.depth:
        permitted = [self.rp_id]
    else:
        permitted = []
        if self.depth > 0:
            permitted.append(self.rp_id)
        # Opening is allowed only while enough room remains to close later
        # (remaining - 2 >= depth is deliberately a bit lenient, to encourage
        # stems) and this position can legally pair per the precomputed table.
        ok_to_open = (
            remaining - 2 >= self.depth
            and pos < len(self.can_open)
            and self.can_open[pos]
        )
        if ok_to_open:
            permitted.append(self.lp_id)
        permitted.append(self.dot_id)

    neg_inf_mask[:, permitted] = 0.0
    scores = scores + neg_inf_mask

    # Optional bias toward unpaired positions (off by default).
    if self.dot_bias != 0.0:
        scores[:, self.dot_id] += self.dot_bias

    # Optional mild regularizer against long uninterrupted runs of parens.
    if self.paren_penalty and len(self.history) >= self.window and all(
        tok in (self.lp_id, self.rp_id) for tok in self.history[-self.window:]
    ):
        scores[:, self.lp_id] -= self.paren_penalty
        scores[:, self.rp_id] -= self.paren_penalty

    return scores
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
# --- generator ---
|
| 109 |
def _generate_db(seq):
    """Generate a balanced dot-bracket structure for RNA sequence ``seq``.

    Decodes exactly ``len(seq)`` tokens one at a time, filtering each
    step's logits through ``BalancedParenProcessor`` so only legal
    '(' / ')' / '.' tokens can be sampled, then returns the resulting
    string padded/truncated to the exact target length.
    """
    tok, model, device = _load_model_and_tokenizer()
    n = len(seq)
    prompt = f"RNA: {seq}\nDot-bracket (exactly {n} characters using only '(' ')' '.'):\n"

    # Token ids for the three structure characters.
    lp = _char_token_id(tok, "(")
    rp = _char_token_id(tok, ")")
    dot = _char_token_id(tok, ".")

    can = _precompute_can_open(seq, min_loop=3)  # try 2 if you still get few stems
    proc = BalancedParenProcessor(lp, rp, dot, n, can, dot_bias=0.0, paren_penalty=0.0)
    procs = LogitsProcessorList([proc])

    inputs = tok(prompt, return_tensors="pt").to(device)
    cur = inputs["input_ids"]
    generated = []

    # Manual token-by-token decoding loop: exactly n structure tokens.
    with torch.no_grad():
        for _ in range(n):
            logits = model(cur).logits[:, -1, :]
            for processor in procs:
                logits = processor(cur, logits)
            next_id = _top_p_sample(logits, top_p=0.9, temperature=0.8).to(device)
            tokid = next_id.item()
            generated.append(tokid)
            proc.update(tokid)  # advance the processor's step/depth bookkeeping
            cur = torch.cat([cur, next_id.view(1, 1)], dim=1)

    text = tok.decode(generated, skip_special_tokens=True)
    # Keep only structure characters and force the exact target length.
    db = "".join(ch for ch in text if ch in "().")[:n]
    if len(db) != n:
        db = (db + "." * n)[:n]
    return db
|
| 139 |
|
| 140 |
+
|
| 141 |
# --- structural element translation ---
|
| 142 |
def dotbracket_to_structural(dot_str):
|
| 143 |
if not dot_str: return "<start><external_loop><end>"
|