openfree commited on
Commit
4fd4c0d
·
verified ·
1 Parent(s): 3564514

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -10
app.py CHANGED
@@ -239,20 +239,58 @@ def dna_embed(seq: str, model_name: str = "zhihan1996/DNABERT-2-117M") -> Dict:
239
  try:
240
  import einops
241
  except ImportError:
242
- return {"error": "einops package required. Installing..."}
243
 
244
- # DNABERT-2 특별 처리
245
- from transformers import BertModel, BertTokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
 
247
- tokenizer = BertTokenizer.from_pretrained(model_name, trust_remote_code=True)
248
- model = BertModel.from_pretrained(model_name, trust_remote_code=True)
249
  model.eval()
250
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
  with torch.no_grad():
252
- inputs = tokenizer(seq, return_tensors="pt", truncation=True, max_length=512)
 
 
 
 
 
 
253
  outputs = model(**inputs)
254
- hidden = outputs.last_hidden_state.mean(dim=1).squeeze(0)
255
- vec = hidden.cpu().numpy()
 
 
 
 
 
256
 
257
  # 메모리 정리
258
  del model
@@ -264,8 +302,9 @@ def dna_embed(seq: str, model_name: str = "zhihan1996/DNABERT-2-117M") -> Dict:
264
  "embedding": vec.tolist()[:10], # 미리보기용 첫 10개만
265
  "size": vec.shape[0]
266
  }
 
267
  except Exception as e:
268
- return {"error": str(e)}
269
 
270
  def build_context(query: str, docs: List[str], index, model, use_web: bool, web_k: int) -> Tuple[str, List[Dict]]:
271
  """Build context from sources"""
@@ -363,7 +402,8 @@ with st.sidebar:
363
  )
364
  dna_model = st.text_input(
365
  "DNA Model",
366
- value="zhihan1996/DNABERT-2-117M"
 
367
  )
368
 
369
  use_web = st.checkbox("Enable web search", value=True)
 
239
  try:
240
  import einops
241
  except ImportError:
242
+ return {"error": "einops package required. Please wait for installation and refresh the page."}
243
 
244
+ # 간단한 대안: 더 안정적인 모델 사용
245
+ # DNABERT-2가 문제를 일으키면 기본 BERT 사용
246
+ try:
247
+ from transformers import AutoTokenizer, AutoModel
248
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
249
+ model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
250
+ except Exception as model_error:
251
+ # 대체 모델 사용
252
+ try:
253
+ from transformers import BertTokenizer, BertModel
254
+ # 기본 BERT 모델로 폴백
255
+ fallback_model = "bert-base-uncased"
256
+ tokenizer = BertTokenizer.from_pretrained(fallback_model)
257
+ model = BertModel.from_pretrained(fallback_model)
258
+ st.warning(f"DNABERT-2 로딩 실패. 대체 모델 사용중: {fallback_model}")
259
+ except:
260
+ return {"error": f"모델 로딩 실패: {str(model_error)}"}
261
 
 
 
262
  model.eval()
263
 
264
+ # DNA 서열을 k-mer로 변환 (DNABERT 스타일)
265
+ def seq_to_kmer(seq, k=6):
266
+ """DNA 서열을 k-mer로 변환"""
267
+ kmers = []
268
+ for i in range(len(seq) - k + 1):
269
+ kmers.append(seq[i:i+k])
270
+ return ' '.join(kmers)
271
+
272
+ # k-mer 변환 또는 직접 사용
273
+ if len(seq) > 6:
274
+ input_seq = seq_to_kmer(seq, k=6)
275
+ else:
276
+ input_seq = seq
277
+
278
  with torch.no_grad():
279
+ inputs = tokenizer(
280
+ input_seq,
281
+ return_tensors="pt",
282
+ truncation=True,
283
+ max_length=512,
284
+ padding=True
285
+ )
286
  outputs = model(**inputs)
287
+
288
+ # last_hidden_state 또는 pooler_output 사용
289
+ if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
290
+ vec = outputs.pooler_output.squeeze(0).cpu().numpy()
291
+ else:
292
+ hidden = outputs.last_hidden_state.mean(dim=1).squeeze(0)
293
+ vec = hidden.cpu().numpy()
294
 
295
  # 메모리 정리
296
  del model
 
302
  "embedding": vec.tolist()[:10], # 미리보기용 첫 10개만
303
  "size": vec.shape[0]
304
  }
305
+
306
  except Exception as e:
307
+ return {"error": f"분석 중 오류 발생: {str(e)[:200]}"}
308
 
309
  def build_context(query: str, docs: List[str], index, model, use_web: bool, web_k: int) -> Tuple[str, List[Dict]]:
310
  """Build context from sources"""
 
402
  )
403
  dna_model = st.text_input(
404
  "DNA Model",
405
+ value="bert-base-uncased", # 더 안정적인 기본 모델
406
+ help="Options: bert-base-uncased (stable), zhihan1996/DNABERT-2-117M (specialized but may require more memory)"
407
  )
408
 
409
  use_web = st.checkbox("Enable web search", value=True)