1-1-3-8 committed on
Commit
c94bb03
·
verified ·
1 Parent(s): 9e399de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -60
app.py CHANGED
@@ -51,10 +51,11 @@ def _precompute_can_open(seq, min_loop=3, allow_gu=True):
51
  break
52
  return can
53
 
 
54
  # --- constrained processor ---
55
  class BalancedParenProcessor(LogitsProcessor):
56
  def __init__(self, lp_id, rp_id, dot_id, total_len, can_open,
57
- dot_bias=0.8, paren_penalty=0.5, window=5):
58
  self.lp_id, self.rp_id, self.dot_id = lp_id, rp_id, dot_id
59
  self.total_len = total_len
60
  self.step = 0
@@ -66,76 +67,77 @@ class BalancedParenProcessor(LogitsProcessor):
66
  self.window=window
67
 
68
  def __call__(self, input_ids, scores):
69
- mask=torch.full_like(scores,float("-inf"))
70
- remaining=self.total_len-self.step
71
- allowed=[]
72
- must_close=(remaining==self.depth and self.depth>0)
73
- pos=self.step
 
 
 
 
74
  if must_close:
75
- allowed=[self.rp_id]
76
  else:
77
- if self.depth>0:
78
  allowed.append(self.rp_id)
79
- if remaining-1>self.depth and pos<len(self.can_open) and self.can_open[pos]:
 
 
 
80
  allowed.append(self.lp_id)
 
81
  allowed.append(self.dot_id)
82
- mask[:,allowed]=0.0
83
- scores=scores+mask
84
- scores[:,self.dot_id]+=self.dot_bias
85
- if len(self.history)>=self.window and all(t in (self.lp_id,self.rp_id) for t in self.history[-self.window:]):
86
- scores[:,self.lp_id]-=self.paren_penalty
87
- scores[:,self.rp_id]-=self.paren_penalty
88
- return scores
89
 
90
def update(self, tok):
    """Record a sampled token: adjust bracket depth, history, and step count.

    Called once per generated token so the processor's state stays in sync
    with the sequence the sampler has actually emitted.
    """
    if tok == self.rp_id:
        # Clamp at zero so a stray ')' can never drive the depth negative.
        self.depth = max(self.depth - 1, 0)
    elif tok == self.lp_id:
        self.depth += 1
    self.history.append(tok)
    self.step += 1
97
-
98
def _top_p_sample(logits, top_p=0.9, temperature=0.8):
    """Sample one token id per row via nucleus (top-p) sampling.

    Args:
        logits: (batch, vocab) unnormalized scores; -inf entries are excluded.
        top_p: keep the smallest probability-sorted prefix of tokens whose
            cumulative probability reaches at least ``top_p``.
        temperature: softmax temperature; must be > 0.

    Returns:
        LongTensor of shape (batch,) with the sampled token ids.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    probs = torch.softmax(logits / temperature, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cum = torch.cumsum(sorted_probs, dim=-1)
    # Standard nucleus rule: a token is dropped only if the cumulative mass
    # *before* it already exceeds top_p. The previous `cum > top_p` test also
    # discarded the token that first crosses the threshold, shrinking the
    # nucleus below top_p mass (HF's TopPLogitsWarper shifts the mask by one).
    cutoff = (cum - sorted_probs) > top_p
    sorted_probs = sorted_probs.masked_fill(cutoff, 0.0)
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(sorted_probs, 1)
    return sorted_idx.gather(-1, choice).squeeze(-1)
109
 
110
  # --- generator ---
111
def _generate_db(seq):
    """Generate a balanced dot-bracket string of len(seq) for an RNA sequence.

    Decodes token by token with a logits processor that guarantees '(' / ')'
    balance, then sanitizes the decoded text to exactly len(seq) characters
    drawn from "().".

    Args:
        seq: RNA sequence string.

    Returns:
        A dot-bracket string of exactly ``len(seq)`` characters.
    """
    tok, model, device = _load_model_and_tokenizer()
    n = len(seq)
    prompt = f"RNA: {seq}\nDot-bracket (exactly {n} characters using only '(' ')' '.'):\n"
    lp = _char_token_id(tok, "("); rp = _char_token_id(tok, ")"); dot = _char_token_id(tok, ".")
    can = _precompute_can_open(seq)
    proc = BalancedParenProcessor(lp, rp, dot, n, can)
    procs = LogitsProcessorList([proc])
    inputs = tok(prompt, return_tensors="pt").to(device)
    cur = inputs["input_ids"]
    generated = []
    # no_grad: generation needs no autograd graph; without it an ever-growing
    # graph is retained across all n forward passes and memory balloons.
    with torch.no_grad():
        for _ in range(n):
            out = model(cur)
            logits = out.logits[:, -1, :]
            for p in procs:
                logits = p(cur, logits)
            next_id = _top_p_sample(logits, 0.9, 0.8).to(device)
            tokid = next_id.item()
            generated.append(tokid)
            proc.update(tokid)
            cur = torch.cat([cur, next_id.view(1, 1)], dim=1)
    text = tok.decode(generated, skip_special_tokens=True)
    db = "".join(c for c in text if c in "().")[:n]
    if len(db) != n:
        # Defensive pad/truncate so callers always get exactly n characters.
        db = (db + "." * n)[:n]
    return db
138
 
 
139
  # --- structural element translation ---
140
  def dotbracket_to_structural(dot_str):
141
  if not dot_str: return "<start><external_loop><end>"
 
51
  break
52
  return can
53
 
54
+ # --- constrained processor ---
55
56
  class BalancedParenProcessor(LogitsProcessor):
57
  def __init__(self, lp_id, rp_id, dot_id, total_len, can_open,
58
+ dot_bias=0.0, paren_penalty=0.0, window=5):
59
  self.lp_id, self.rp_id, self.dot_id = lp_id, rp_id, dot_id
60
  self.total_len = total_len
61
  self.step = 0
 
67
  self.window=window
68
 
69
  def __call__(self, input_ids, scores):
70
+ # restrict to only three tokens
71
+ mask = torch.full_like(scores, float("-inf"))
72
+ remaining = self.total_len - self.step
73
+ allowed = []
74
+
75
+ # If we must close to avoid running out of room, force )
76
+ must_close = (remaining == self.depth and self.depth > 0)
77
+ pos = self.step
78
+
79
  if must_close:
80
+ allowed = [self.rp_id]
81
  else:
82
+ if self.depth > 0:
83
  allowed.append(self.rp_id)
84
+
85
+ # allow opening if there will still be room to close later
86
+ # (be a bit less strict than remaining-1 > depth to encourage stems)
87
+ if remaining - 2 >= self.depth and pos < len(self.can_open) and self.can_open[pos]:
88
  allowed.append(self.lp_id)
89
+
90
  allowed.append(self.dot_id)
 
 
 
 
 
 
 
91
 
92
+ mask[:, allowed] = 0.0
93
+ scores = scores + mask
94
+
95
+ # (no dot boost by default)
96
+ if self.dot_bias != 0.0:
97
+ scores[:, self.dot_id] += self.dot_bias
98
+
99
+ # optional mild anti-run regularizer
100
+ if self.paren_penalty and len(self.history) >= self.window and all(
101
+ t in (self.lp_id, self.rp_id) for t in self.history[-self.window:]
102
+ ):
103
+ scores[:, self.lp_id] -= self.paren_penalty
104
+ scores[:, self.rp_id] -= self.paren_penalty
105
+
106
+ return scores
 
 
 
 
107
 
108
  # --- generator ---
109
def _generate_db(seq):
    """Decode a length-matched, balance-guaranteed dot-bracket string for ``seq``."""
    tokenizer, model, device = _load_model_and_tokenizer()
    length = len(seq)
    prompt = f"RNA: {seq}\nDot-bracket (exactly {length} characters using only '(' ')' '.'):\n"

    open_id = _char_token_id(tokenizer, "(")
    close_id = _char_token_id(tokenizer, ")")
    dot_id = _char_token_id(tokenizer, ".")

    can = _precompute_can_open(seq, min_loop=3)  # try 2 if you still get few stems
    processor = BalancedParenProcessor(open_id, close_id, dot_id, length, can,
                                       dot_bias=0.0, paren_penalty=0.0)
    pipeline = LogitsProcessorList([processor])

    context = tokenizer(prompt, return_tensors="pt").to(device)["input_ids"]
    sampled = []

    # Token-by-token constrained nucleus sampling; no autograd graph needed.
    with torch.no_grad():
        for _ in range(length):
            step_logits = model(context).logits[:, -1, :]
            for proc in pipeline:
                step_logits = proc(context, step_logits)
            choice = _top_p_sample(step_logits, top_p=0.9, temperature=0.8).to(device)
            token = choice.item()
            sampled.append(token)
            processor.update(token)
            context = torch.cat([context, choice.view(1, 1)], dim=1)

    decoded = tokenizer.decode(sampled, skip_special_tokens=True)
    db = "".join(ch for ch in decoded if ch in "().")[:length]
    if len(db) != length:
        # Defensive: guarantee exactly `length` characters.
        db = (db + "." * length)[:length]
    return db
139
 
140
+
141
  # --- structural element translation ---
142
  def dotbracket_to_structural(dot_str):
143
  if not dot_str: return "<start><external_loop><end>"