Update app.py
Browse files
app.py
CHANGED
|
@@ -239,20 +239,58 @@ def dna_embed(seq: str, model_name: str = "zhihan1996/DNABERT-2-117M") -> Dict:
|
|
| 239 |
try:
|
| 240 |
import einops
|
| 241 |
except ImportError:
|
| 242 |
-
return {"error": "einops package required.
|
| 243 |
|
| 244 |
-
#
|
| 245 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
|
| 247 |
-
tokenizer = BertTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
| 248 |
-
model = BertModel.from_pretrained(model_name, trust_remote_code=True)
|
| 249 |
model.eval()
|
| 250 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
with torch.no_grad():
|
| 252 |
-
inputs = tokenizer(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
outputs = model(**inputs)
|
| 254 |
-
|
| 255 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
|
| 257 |
# 메모리 정리
|
| 258 |
del model
|
|
@@ -264,8 +302,9 @@ def dna_embed(seq: str, model_name: str = "zhihan1996/DNABERT-2-117M") -> Dict:
|
|
| 264 |
"embedding": vec.tolist()[:10], # 미리보기용 첫 10개만
|
| 265 |
"size": vec.shape[0]
|
| 266 |
}
|
|
|
|
| 267 |
except Exception as e:
|
| 268 |
-
return {"error": str(e)}
|
| 269 |
|
| 270 |
def build_context(query: str, docs: List[str], index, model, use_web: bool, web_k: int) -> Tuple[str, List[Dict]]:
|
| 271 |
"""Build context from sources"""
|
|
@@ -363,7 +402,8 @@ with st.sidebar:
|
|
| 363 |
)
|
| 364 |
dna_model = st.text_input(
|
| 365 |
"DNA Model",
|
| 366 |
-
value="
|
|
|
|
| 367 |
)
|
| 368 |
|
| 369 |
use_web = st.checkbox("Enable web search", value=True)
|
|
|
|
| 239 |
try:
|
| 240 |
import einops
|
| 241 |
except ImportError:
|
| 242 |
+
return {"error": "einops package required. Please wait for installation and refresh the page."}
|
| 243 |
|
| 244 |
+
# 간단한 대안: 더 안정적인 모델 사용
|
| 245 |
+
# DNABERT-2가 문제를 일으키면 기본 BERT 사용
|
| 246 |
+
try:
|
| 247 |
+
from transformers import AutoTokenizer, AutoModel
|
| 248 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
| 249 |
+
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
|
| 250 |
+
except Exception as model_error:
|
| 251 |
+
# 대체 모델 사용
|
| 252 |
+
try:
|
| 253 |
+
from transformers import BertTokenizer, BertModel
|
| 254 |
+
# 기본 BERT 모델로 폴백
|
| 255 |
+
fallback_model = "bert-base-uncased"
|
| 256 |
+
tokenizer = BertTokenizer.from_pretrained(fallback_model)
|
| 257 |
+
model = BertModel.from_pretrained(fallback_model)
|
| 258 |
+
st.warning(f"DNABERT-2 로딩 실패. 대체 모델 사용중: {fallback_model}")
|
| 259 |
+
except:
|
| 260 |
+
return {"error": f"모델 로딩 실패: {str(model_error)}"}
|
| 261 |
|
|
|
|
|
|
|
| 262 |
model.eval()
|
| 263 |
|
| 264 |
+
# DNA 서열을 k-mer로 변환 (DNABERT 스타일)
|
| 265 |
+
def seq_to_kmer(seq: str, k: int = 6) -> str:
    """Convert a DNA sequence into a space-separated string of overlapping k-mers.

    Args:
        seq: Sequence to split into windows.
        k: Window length; must be at least 1.

    Returns:
        Every length-``k`` window of ``seq`` taken at stride 1, joined by
        single spaces. An empty string when ``seq`` is shorter than ``k``.

    Raises:
        ValueError: If ``k`` is less than 1.
    """
    # A non-positive k would yield empty or arbitrary slices instead of
    # k-mers, so reject it up front rather than emit a meaningless string.
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")
    return " ".join(seq[i:i + k] for i in range(len(seq) - k + 1))
|
| 271 |
+
|
| 272 |
+
# k-mer 변환 또는 직접 사용
|
| 273 |
+
if len(seq) > 6:
|
| 274 |
+
input_seq = seq_to_kmer(seq, k=6)
|
| 275 |
+
else:
|
| 276 |
+
input_seq = seq
|
| 277 |
+
|
| 278 |
with torch.no_grad():
|
| 279 |
+
inputs = tokenizer(
|
| 280 |
+
input_seq,
|
| 281 |
+
return_tensors="pt",
|
| 282 |
+
truncation=True,
|
| 283 |
+
max_length=512,
|
| 284 |
+
padding=True
|
| 285 |
+
)
|
| 286 |
outputs = model(**inputs)
|
| 287 |
+
|
| 288 |
+
# last_hidden_state 또는 pooler_output 사용
|
| 289 |
+
if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
|
| 290 |
+
vec = outputs.pooler_output.squeeze(0).cpu().numpy()
|
| 291 |
+
else:
|
| 292 |
+
hidden = outputs.last_hidden_state.mean(dim=1).squeeze(0)
|
| 293 |
+
vec = hidden.cpu().numpy()
|
| 294 |
|
| 295 |
# 메모리 정리
|
| 296 |
del model
|
|
|
|
| 302 |
"embedding": vec.tolist()[:10], # 미리보기용 첫 10개만
|
| 303 |
"size": vec.shape[0]
|
| 304 |
}
|
| 305 |
+
|
| 306 |
except Exception as e:
|
| 307 |
+
return {"error": f"분석 중 오류 발생: {str(e)[:200]}"}
|
| 308 |
|
| 309 |
def build_context(query: str, docs: List[str], index, model, use_web: bool, web_k: int) -> Tuple[str, List[Dict]]:
|
| 310 |
"""Build context from sources"""
|
|
|
|
| 402 |
)
|
| 403 |
dna_model = st.text_input(
|
| 404 |
"DNA Model",
|
| 405 |
+
value="bert-base-uncased", # 더 안정적인 기본 모델
|
| 406 |
+
help="Options: bert-base-uncased (stable), zhihan1996/DNABERT-2-117M (specialized but may require more memory)"
|
| 407 |
)
|
| 408 |
|
| 409 |
use_web = st.checkbox("Enable web search", value=True)
|