Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import shutil | |
| import uuid | |
| import json | |
| import torch | |
| import numpy as np | |
| import io | |
| from collections import Counter | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect | |
| from transformers import AutoTokenizer, AutoModel | |
| from sklearn.cluster import HDBSCAN | |
| from Bio import SeqIO | |
| from huggingface_hub import snapshot_download | |
# Startup banner: emitted at import time, before the model download/load
# below, so the logs show that the process has started.
print("Initializing BioSentinel...")
def _patch_config_for_cpu(config_path):
    """Rewrite config.json so flash attention is off and auto_map points at the local modules."""
    if not os.path.exists(config_path):
        return
    # Explicit UTF-8: the locale default encoding is platform dependent.
    with open(config_path, "r", encoding="utf-8") as f:
        config_data = json.load(f)
    config_data["use_flash_attn"] = False
    config_data["auto_map"] = {
        "AutoConfig": "configuration_bert.BertConfig",
        "AutoModel": "bert_layers.BertModel",
        "AutoModelForMaskedLM": "bert_layers.BertForMaskedLM",
    }
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config_data, f, indent=2)


def _patch_bert_layers_for_cpu(bert_layer_path):
    """Strip the flash_attn_triton imports and flash-attention switch from bert_layers.py."""
    if not os.path.exists(bert_layer_path):
        return
    replacements = (
        (
            "from .flash_attn_triton import flash_attn_qkvpacked_func",
            "flash_attn_qkvpacked_func = None # Patched",
        ),
        (
            "from .flash_attn_triton import flash_attn_func",
            "flash_attn_func = None # Patched",
        ),
        ("self.config.use_flash_attn", "False"),
    )
    with open(bert_layer_path, "r", encoding="utf-8") as f:
        original = f.read()
    patched = original
    for old, new in replacements:
        # str.replace is a no-op when `old` is absent, so re-running this on
        # an already patched file leaves it unchanged.
        patched = patched.replace(old, new)
    # Only touch the file when something actually changed.
    if patched != original:
        with open(bert_layer_path, "w", encoding="utf-8") as f:
            f.write(patched)


def load_cpu_safe_model():
    """Download DNABERT-2 if needed, patch it to run without flash attention, and load it.

    The checkpoint is stored under ``./dnabert_local``. The download is skipped
    whenever that directory already exists (its contents are not validated).
    The local ``config.json`` and ``bert_layers.py`` are then edited in place
    so the remote code no longer imports ``flash_attn_triton``.

    Returns:
        tuple: ``(tokenizer, model)`` as returned by
        ``AutoTokenizer.from_pretrained`` and ``AutoModel.from_pretrained``.

    Raises:
        Whatever the download, file patching, or ``from_pretrained`` calls
        raise; nothing is caught here.
    """
    checkpoint = "zhihan1996/DNABERT-2-117M"
    local_path = "./dnabert_local"

    if not os.path.exists(local_path):
        print("Downloading model...")
        snapshot_download(repo_id=checkpoint, local_dir=local_path)

    _patch_config_for_cpu(os.path.join(local_path, "config.json"))
    _patch_bert_layers_for_cpu(os.path.join(local_path, "bert_layers.py"))
    print("Configuration patched for CPU.")

    # trust_remote_code executes the (patched) Python files from local_path.
    tokenizer = AutoTokenizer.from_pretrained(local_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(local_path, trust_remote_code=True)
    return tokenizer, model
# Load the model once at import time. Any failure is reported and the process
# exits with status 1, so the app never starts without a working model.
try:
    device = torch.device("cpu")
    tokenizer, model = load_cpu_safe_model()
    model.to(device)
    model.eval()  # inference mode: disables dropout and similar layers
    print("AI Engine Ready.")
except Exception as load_error:
    print(f"Error: {load_error}")
    sys.exit(1)

app = FastAPI()

# Directory for uploaded files; created eagerly so later code can rely on it.
UPLOAD_DIR = "temp_uploads"
os.makedirs(UPLOAD_DIR, exist_ok=True)
def get_embeddings(sequences, batch_size=4, *, strict=False):
    """Return one mean-pooled embedding per sequence as a NumPy array.

    Sequences are tokenized and run through the module-level ``model`` in
    batches; each sequence's embedding is the attention-mask-weighted mean of
    its token hidden states (padding tokens are excluded from the mean).

    Args:
        sequences: List of sequence strings to embed.
        batch_size: Number of sequences per forward pass; must be >= 1.
        strict: When False (default), a batch that raises is logged and
            skipped, so the result may contain FEWER rows than ``sequences``
            and row ``i`` may no longer correspond to ``sequences[i]``.
            When True, the error is logged and re-raised instead.

    Returns:
        A 2-D array with one row per successfully embedded sequence, or an
        empty 1-D array (``np.array([])``) when nothing was embedded.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    all_embeddings = []
    for start in range(0, len(sequences), batch_size):
        batch = sequences[start : start + batch_size]
        try:
            inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = model(**inputs)
            # The model may return a plain tuple or an output object.
            if isinstance(outputs, tuple):
                hidden_states = outputs[0]
            else:
                hidden_states = outputs.last_hidden_state
            # Cast the mask to the hidden-state dtype so the multiply, clamp
            # and divide below are all floating-point operations.
            # Assumes hidden_states is (batch, tokens, hidden) - TODO confirm.
            mask = inputs['attention_mask'].unsqueeze(-1).to(hidden_states.dtype)
            sum_embeddings = (hidden_states * mask).sum(dim=1)
            # Clamp guards against division by zero for an all-padding row.
            token_counts = mask.sum(dim=1).clamp(min=1e-9)
            mean_embeddings = sum_embeddings / token_counts
            all_embeddings.append(mean_embeddings.cpu().numpy())
        except Exception as e:
            print(f"Embedding Error: {e}")
            if strict:
                raise
    return np.concatenate(all_embeddings, axis=0) if all_embeddings else np.array([])
def _parse_sequences(data):
    """Extract sequences from a text payload: FASTQ if it starts with '@', else one per line."""
    plain_lines = [line.strip() for line in data.splitlines() if line.strip()]
    if not data.strip().startswith("@"):
        return plain_lines
    try:
        # lstrip so leading whitespace (already ignored by the check above)
        # does not make the FASTQ parser reject the payload.
        handle = io.StringIO(data.lstrip())
        return [str(record.seq) for record in SeqIO.parse(handle, "fastq")]
    except Exception as exc:
        # Best-effort fallback: treat the payload as one sequence per line.
        # `Exception` (not a bare except) so task cancellation still propagates.
        print(f"FASTQ parse failed, falling back to plain lines: {exc}")
        return plain_lines


def _build_cluster_response(labels, total):
    """Build the "complete" payload from per-sequence cluster labels (-1 means noise)."""
    label_list = [int(label) for label in labels]
    counts = Counter(label_list)
    return {
        "status": "complete",
        "total": total,
        "clusters": len(counts) - (1 if -1 in counts else 0),
        "anomalies": counts.get(-1, 0),
        "data": [{"cluster": label, "seq_idx": idx} for idx, label in enumerate(label_list)],
    }


# NOTE(review): no route decorator (e.g. @app.websocket(...)) is visible on
# this coroutine in this view - confirm it is registered with `app`.
async def websocket_endpoint(websocket: WebSocket):
    """Cluster batches of sequences sent over a WebSocket.

    For every text message received: parse the sequences, embed them with
    ``get_embeddings``, cluster the embeddings with HDBSCAN, and reply with a
    JSON summary. Messages with fewer than two sequences, or whose embeddings
    could not all be computed, get an ``{"error": ...}`` reply and the loop
    continues. The loop ends when the client disconnects.

    Args:
        websocket: The connection to serve; it is accepted here.
    """
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            sequences = _parse_sequences(data)
            if len(sequences) < 2:
                await websocket.send_json({"error": "Send at least 2 sequences"})
                continue
            await websocket.send_json({"status": "processing", "msg": f"Processing {len(sequences)} reads..."})
            # NOTE(review): embedding and clustering run synchronously inside
            # this coroutine, so the event loop is blocked while they run.
            embeddings = get_embeddings(sequences)
            # get_embeddings skips batches that fail; with missing rows the
            # labels would no longer line up with the input indices.
            if len(embeddings) != len(sequences):
                missing = len(sequences) - len(embeddings)
                await websocket.send_json(
                    {"error": f"Embedding failed for {missing} of {len(sequences)} sequences"}
                )
                continue
            clusterer = HDBSCAN(min_cluster_size=2, metric='euclidean')
            labels = clusterer.fit_predict(embeddings)
            await websocket.send_json(_build_cluster_response(labels, len(sequences)))
    except WebSocketDisconnect:
        print("Disconnected")