Feature Extraction
Transformers
Safetensors
flash_transformer
biology
genomics
long-context
custom_code
Instructions to use isyslab/DNAFlash with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use isyslab/DNAFlash with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("feature-extraction", model="isyslab/DNAFlash", trust_remote_code=True)# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("isyslab/DNAFlash", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
Update dnaflash.py
Browse files- dnaflash.py +1 -1
dnaflash.py
CHANGED
|
@@ -474,7 +474,7 @@ class FLASHTransformerForSequenceClassification(FLASHTransformerForPretrained):
|
|
| 474 |
return_dict=return_dict,
|
| 475 |
)
|
| 476 |
hidden_states = outputs["hidden_states"]
|
| 477 |
-
input_mask_expanded =
|
| 478 |
mean_pooled = torch.sum(hidden_states * input_mask_expanded, dim=1) / input_mask_expanded.sum(dim=1) # 计算加权平均
|
| 479 |
logits = self.score(mean_pooled)
|
| 480 |
|
|
|
|
| 474 |
return_dict=return_dict,
|
| 475 |
)
|
| 476 |
hidden_states = outputs["hidden_states"]
|
| 477 |
+
input_mask_expanded = attention_mask.unsqueeze(-1) # 维度匹配
|
| 478 |
mean_pooled = torch.sum(hidden_states * input_mask_expanded, dim=1) / input_mask_expanded.sum(dim=1) # 计算加权平均
|
| 479 |
logits = self.score(mean_pooled)
|
| 480 |
|