Upload folder using huggingface_hub
- .gitattributes +13 -0
- Dockerfile +47 -0
- README.md +13 -0
- assistant_female_voice.wav +3 -0
- attention_mask_research.md +186 -0
- compare_generation.py +129 -0
- hotkey.txt +1 -0
- model/__init__.py +0 -0
- model/data.py +76 -0
- model/model_tracker.py +224 -0
- model/model_updater.py +93 -0
- model/storage/__init__.py +0 -0
- model/storage/chain/chain_model_metadata_store.py +177 -0
- model/storage/disk/__init__.py +0 -0
- model/storage/disk/disk_model_store.py +124 -0
- model/storage/disk/utils.py +121 -0
- model/storage/eval_leaderboard.py +185 -0
- model/storage/hugging_face/__init__.py +0 -0
- model/storage/hugging_face/hugging_face_model_store.py +175 -0
- model/storage/local_model_store.py +30 -0
- model/storage/model_metadata_store.py +17 -0
- model/storage/mysql_model_queue.py +847 -0
- model/storage/remote_model_store.py +17 -0
- model/storage/reputation_store.py +244 -0
- requirements.txt +13 -0
- server.py +211 -0
- spk_001.wav +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,16 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assistant_female_voice.wav filter=lfs diff=lfs merge=lfs -text
+models/VoxCPM-0.5B/assets/voxcpm_model.png filter=lfs diff=lfs merge=lfs -text
+models/iic/SenseVoiceSmall/fig/aed_figure.png filter=lfs diff=lfs merge=lfs -text
+models/iic/SenseVoiceSmall/fig/asr_results.png filter=lfs diff=lfs merge=lfs -text
+models/iic/SenseVoiceSmall/fig/inference.png filter=lfs diff=lfs merge=lfs -text
+models/iic/SenseVoiceSmall/fig/sensevoice.png filter=lfs diff=lfs merge=lfs -text
+models/iic/SenseVoiceSmall/fig/ser_figure.png filter=lfs diff=lfs merge=lfs -text
+models/iic/SenseVoiceSmall/fig/ser_table.png filter=lfs diff=lfs merge=lfs -text
+models/iic/speech_zipenhancer_ans_multiloss_16k_base/description/matrix.jpg filter=lfs diff=lfs merge=lfs -text
+models/iic/speech_zipenhancer_ans_multiloss_16k_base/description/matrix_voicebank.jpg filter=lfs diff=lfs merge=lfs -text
+models/iic/speech_zipenhancer_ans_multiloss_16k_base/examples/speech_with_noise1.wav filter=lfs diff=lfs merge=lfs -text
+models/v10/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+spk_001.wav filter=lfs diff=lfs merge=lfs -text
Dockerfile
ADDED
@@ -0,0 +1,47 @@
FROM nvidia/cuda:12.3.2-cudnn9-devel-ubuntu22.04

# Set environment variables
ENV PYTHONUNBUFFERED=1 \
    DEBIAN_FRONTEND=noninteractive \
    CUDA_HOME=/usr/local/cuda \
    PATH=/usr/local/cuda/bin:$PATH \
    LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH \
    NVIDIA_VISIBLE_DEVICES=all \
    NVIDIA_DRIVER_CAPABILITIES=compute,utility \
    HF_HOME=/app/models \
    TRITON_CACHE_DIR=/tmp/triton_cache \
    XDG_CACHE_HOME=/tmp \
    NUMBA_CACHE_DIR=/tmp/numba_cache

# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    python3 \
    python3-pip \
    python3-dev \
    build-essential \
    git \
    ffmpeg \
    libsndfile1 \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Upgrade pip and install build tools
RUN python3 -m pip install --upgrade pip setuptools wheel uv

WORKDIR /app

# Create cache directories for Numba and Triton
RUN mkdir -p /tmp/numba_cache /tmp/triton_cache && \
    chown nobody:nogroup /tmp/numba_cache /tmp/triton_cache && \
    chmod 700 /tmp/numba_cache /tmp/triton_cache

COPY requirements.txt .

# Install Python requirements
RUN python3 -m uv pip install --no-cache-dir -r requirements.txt --prerelease=allow

COPY . .

EXPOSE 8010

# CMD ["python3", "server.py"]
README.md
ADDED
@@ -0,0 +1,13 @@
---
license: mit
tags:
- any-to-any
- omega
- omegalabs
- bittensor
- agi
---

This is an Any-to-Any model checkpoint for the OMEGA Labs x Bittensor Any-to-Any subnet.

Check out the [git repo](https://github.com/omegalabsinc/omegalabs-anytoany-bittensor) and find OMEGA on X: [@omegalabsai](https://x.com/omegalabsai).
assistant_female_voice.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1d712ba6de1d15d52eda96bdc043ce43eb5af4b4ac441b78b6fb0fdaf6683c7a
size 235244
attention_mask_research.md
ADDED
@@ -0,0 +1,186 @@
# Attention Masks and Pad Tokens in Transformer Generation: Research Questions

## Core Problem Statement

When running transformer models (specifically Llama-3.2-1B-Instruct) for text generation, we encounter warnings about missing attention masks and pad tokens, even for single input sequences. This leads to inconsistent generation outputs despite identical inputs.

### Warning Messages Observed
```
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token.
```

## Key Research Questions

### 1. Why do single inputs require attention masks?
**Initial Assumption**: Single sequences without padding shouldn't need attention masks.
**Observed Reality**: Even single inputs show different generation outputs when attention masks are missing.

### 2. What is the relationship between pad tokens and attention masks?
**Question**: How do pad_token_id and attention_mask work together in the generation process?

### 3. Why does pad_token_id = eos_token_id cause issues?
**Specific Issue**: When the padding token equals the end-of-sequence token, what ambiguity does this create?

## Code Analysis

### Current Implementation (Problematic)
```python
def chat_current(system_prompt: str, user_prompt: str) -> str:
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    # Only returns the input_ids tensor
    input_ids = tok.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(lm.device)

    with torch.inference_mode():
        output_ids = lm.generate(
            input_ids,  # Missing: attention_mask, pad_token_id
            max_new_tokens=2048,
            do_sample=True,
            temperature=0.2,
            repetition_penalty=1.1,
            top_k=100,
            top_p=0.95,
        )

    return tok.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
```

### Fixed Implementation
```python
def chat_fixed(system_prompt: str, user_prompt: str) -> str:
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    # Returns a dictionary with input_ids AND attention_mask
    inputs = tok.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True  # KEY CHANGE: get both components
    )

    input_ids = inputs["input_ids"].to(lm.device)
    attention_mask = inputs["attention_mask"].to(lm.device)

    with torch.inference_mode():
        output_ids = lm.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,  # Explicit attention guidance
            pad_token_id=tok.eos_token_id,  # Explicit pad token
            max_new_tokens=2048,
            do_sample=True,
            temperature=0.2,
            repetition_penalty=1.1,
            top_k=100,
            top_p=0.95,
        )

    return tok.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
```

### Model and Tokenizer Setup
```python
model_name = "models/Llama-3.2-1B-Instruct"
tok = AutoTokenizer.from_pretrained(model_name)
# Critical: set a pad token if one is not already defined
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

lm = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
).eval()
```

## Observed Behavioral Differences

### Input Structure Analysis
```python
# A single input contains multiple components:
messages = [
    {"role": "system", "content": "You are a helpful assistant..."},
    {"role": "user", "content": "What is the capital of France?"},
]

# After apply_chat_template, this becomes the token sequence:
# [system_tokens, user_tokens, assistant_start_token]
```

## Technical Hypotheses for Investigation

### Hypothesis 1: Internal Masking Ambiguity
When attention_mask is missing, the model cannot distinguish between:
- Real input tokens that should influence generation
- Structural tokens (system prompts, role markers)
- Token boundaries between different message roles

### Hypothesis 2: EOS Token Dual Purpose Confusion
When `pad_token_id == eos_token_id`, the model faces ambiguity:
```python
# The same token (128001) serves dual purposes:
# 1. End-of-sequence marker
# 2. Padding token for batch processing
# The model cannot infer which purpose applies in context
```
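A natural follow-up (a minimal sketch, not part of the uploaded code) is to remove the ambiguity entirely by registering a dedicated pad token so that `pad_token_id != eos_token_id`; the `<|pad|>` literal is an assumption:

```python
# Hypothetical fix sketch: register a distinct pad token.
num_added = tok.add_special_tokens({"pad_token": "<|pad|>"})
if num_added > 0:
    # The embedding table must grow to cover the new token id.
    lm.resize_token_embeddings(len(tok))
assert tok.pad_token_id != tok.eos_token_id
```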
+
|
| 138 |
+
### Hypothesis 3: Autoregressive Generation Context Boundary Issues
|
| 139 |
+
During generation, model needs to know:
|
| 140 |
+
- Which input tokens provide valid context for next token prediction
|
| 141 |
+
- Where the "prompt" ends and "generation" begins
|
| 142 |
+
- How to weight attention across different input components
|
| 143 |
+
|
| 144 |
+
## Research Objectives
|
| 145 |
+
|
| 146 |
+
### Primary Questions
|
| 147 |
+
1. **Mechanism Analysis**: How exactly does missing attention_mask affect the internal attention computation?
|
| 148 |
+
2. **Consistency Impact**: Why do identical inputs produce different outputs without proper masking?
|
| 149 |
+
3. **Single vs Batch Behavior**: What differences exist between single sequence and batched sequence processing?
|
| 150 |
+
|
| 151 |
+
### Secondary Questions
|
| 152 |
+
1. **Model-Specific Behavior**: Do different transformer architectures handle missing attention masks differently?
|
| 153 |
+
2. **Generation Parameter Interaction**: How do attention mask issues interact with sampling parameters (temperature, top_p, etc.)?
|
| 154 |
+
3. **Performance Impact**: What computational overhead does proper attention masking add?
|
| 155 |
+
|
| 156 |
+
## Key Technical Areas for Deep Research
|
| 157 |
+
|
| 158 |
+
### Attention Mechanism Internals
|
| 159 |
+
- How attention weights are computed with/without explicit masks
|
| 160 |
+
- Impact on multi-head attention distributions
|
| 161 |
+
- Interaction with causal masking in autoregressive models
|
| 162 |
+
|
| 163 |
+
### Tokenizer Behavior
|
| 164 |
+
- How `apply_chat_template` constructs input sequences
|
| 165 |
+
- Default attention mask generation behavior
|
| 166 |
+
- Role of special tokens in attention computation
|
| 167 |
+
|
| 168 |
+
### Generation Process
|
| 169 |
+
- How `model.generate()` handles missing parameters
|
| 170 |
+
- Internal assumptions and fallback behaviors
|
| 171 |
+
- Impact on sampling and beam search algorithms
|
| 172 |
+
|
| 173 |
+
## Expected Research Outcomes
|
| 174 |
+
|
| 175 |
+
Understanding of:
|
| 176 |
+
1. Exact mechanism causing output inconsistency
|
| 177 |
+
2. Best practices for single sequence generation
|
| 178 |
+
3. Relationship between attention masking and generation quality
|
| 179 |
+
4. Guidelines for production transformer deployment
|
| 180 |
+
|
| 181 |
+
## References for Deep Research
|
| 182 |
+
|
| 183 |
+
- Hugging Face Transformers documentation on attention masks
|
| 184 |
+
- Technical blogs on transformer attention mechanisms (2024)
|
| 185 |
+
- Community discussions on pad token vs attention mask differences
|
| 186 |
+
- Official model documentation for Llama architecture attention handling
|
compare_generation.py
ADDED
@@ -0,0 +1,129 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model and tokenizer (same as server.py)
model_name = "models/Llama-3.2-1B-Instruct"
tok = AutoTokenizer.from_pretrained(model_name)
lm = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
).eval()


def chat_current(system_prompt: str, user_prompt: str) -> str:
    """Current implementation (same as server.py) - will show warnings."""
    print("🔴 Running CURRENT implementation (with warnings)...")

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    input_ids = tok.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(lm.device)

    with torch.inference_mode():
        output_ids = lm.generate(
            input_ids,  # No attention_mask, no pad_token_id
            max_new_tokens=2048,
            do_sample=True,
            temperature=0.2,
            repetition_penalty=1.1,
            top_k=100,
            top_p=0.95,
        )

    answer = tok.decode(
        output_ids[0][input_ids.shape[-1]:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return answer.strip()


def chat_fixed(system_prompt: str, user_prompt: str) -> str:
    """Fixed implementation - proper attention mask and pad token."""
    print("🟢 Running FIXED implementation (no warnings)...")

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    # Get both input_ids and attention_mask
    inputs = tok.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True  # Returns dict with input_ids and attention_mask
    )

    # Move to device
    input_ids = inputs["input_ids"].to(lm.device)
    attention_mask = inputs["attention_mask"].to(lm.device)

    with torch.inference_mode():
        output_ids = lm.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,  # Proper attention mask
            pad_token_id=tok.eos_token_id,  # Explicit pad token
            max_new_tokens=2048,
            do_sample=True,
            temperature=0.2,
            repetition_penalty=1.1,
            top_k=100,
            top_p=0.95,
        )

    answer = tok.decode(
        output_ids[0][input_ids.shape[-1]:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return answer.strip()


def compare_generations():
    """Compare both implementations."""
    system_prompt = "You are a helpful assistant who tries to help answer the user's question."
    user_prompt = "Create a report on anxiety in work. How do I manage time and stress effectively?"

    print("=" * 60)
    print("COMPARING GENERATION METHODS")
    print("=" * 60)
    print(f"System: {system_prompt}")
    print(f"User: {user_prompt}")
    print("=" * 60)

    # Test current implementation
    print("\n" + "=" * 60)
    current_output = chat_current(system_prompt, user_prompt)
    print(f"CURRENT OUTPUT:\n{current_output}")

    print("\n" + "=" * 60)
    # Test fixed implementation
    fixed_output = chat_fixed(system_prompt, user_prompt)
    print(f"FIXED OUTPUT:\n{fixed_output}")

    print("\n" + "=" * 60)
    print("COMPARISON:")
    print(f"Outputs are identical: {current_output == fixed_output}")
    print(f"Current length: {len(current_output)} chars")
    print(f"Fixed length: {len(fixed_output)} chars")


if __name__ == "__main__":
    # Set pad token for the fixed version
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    compare_generations()
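One caveat about this comparison script: with `do_sample=True`, the two outputs will differ run to run regardless of masking, so re-seeding the global RNG before each call isolates the attention-mask effect. A minimal sketch (not part of the uploaded file; `system` and `user` are assumed prompt strings):

```python
import torch

torch.manual_seed(0)           # fix the sampling path for run A
out_a = chat_current(system, user)

torch.manual_seed(0)           # reset before run B so only the masking differs
out_b = chat_fixed(system, user)
```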
hotkey.txt
ADDED
@@ -0,0 +1 @@
5DJMUw8aibqe5BkyF33m87hBrWKuxkiuGPFxdVGzva6DSFLK
model/__init__.py
ADDED
File without changes
model/data.py
ADDED
@@ -0,0 +1,76 @@
from typing import Any, ClassVar, Dict, Optional, Type
# from transformers import PreTrainedModel, PreTrainedTokenizerBase
from pydantic import BaseModel, Field, PositiveInt, ConfigDict

# The maximum bytes for metadata on the chain.
MAX_METADATA_BYTES = 128
# The length, in bytes, of a git commit hash.
GIT_COMMIT_LENGTH = 40
# The length, in bytes, of a base64 encoded sha256 hash.
SHA256_BASE_64_LENGTH = 44
# The max length, in characters, of the competition id.
MAX_COMPETITION_ID_LENGTH = 2


class ModelId(BaseModel):
    """Uniquely identifies a trained model."""

    MAX_REPO_ID_LENGTH: ClassVar[int] = (
        MAX_METADATA_BYTES
        - GIT_COMMIT_LENGTH
        - SHA256_BASE_64_LENGTH
        - MAX_COMPETITION_ID_LENGTH
        - 4  # separators
    )

    namespace: str = Field(
        description="Namespace where the model can be found. ex. Hugging Face username/org."
    )
    name: str = Field(description="Name of the model.")

    # Optional fields default to None so partially-specified ids
    # (e.g. from from_compressed_str or in tests) still validate.
    epoch: Optional[str] = Field(
        default=None,
        description="The epoch number to submit as your checkpoint to evaluate, e.g. 10.",
    )

    # When handling a model locally the commit and hash are not necessary.
    # Commit must be filled when trying to download from a remote store.
    commit: Optional[str] = Field(
        default=None, description="Commit of the model. May be empty if not yet committed."
    )
    # Hash is filled automatically when uploading to or downloading from a remote store.
    hash: Optional[str] = Field(default=None, description="Hash of the trained model.")
    # Identifier for the competition.
    competition_id: Optional[str] = Field(default=None, description="The competition id.")

    def to_compressed_str(self) -> str:
        """Returns a compressed string representation."""
        return f"{self.namespace}:{self.name}:{self.epoch}:{self.commit}:{self.hash}:{self.competition_id}"

    @classmethod
    def from_compressed_str(cls, cs: str) -> "ModelId":
        """Returns an instance of this class from a compressed string representation."""
        tokens = cs.split(":")
        return cls(
            namespace=tokens[0],
            name=tokens[1],
            epoch=tokens[2] if tokens[2] != "None" else None,
            commit=tokens[3] if tokens[3] != "None" else None,
            hash=tokens[4] if tokens[4] != "None" else None,
            competition_id=(
                tokens[5] if len(tokens) >= 6 and tokens[5] != "None" else None
            ),
        )


class Model(BaseModel):
    """Represents a pre-trained foundation model."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    id: ModelId = Field(description="Identifier for this model.")
    local_repo_dir: str = Field(description="Local repository with the required files.")


class ModelMetadata(BaseModel):
    id: ModelId = Field(description="Identifier for this trained model.")
    block: PositiveInt = Field(
        description="Block on which this model was claimed on the chain."
    )
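A quick illustration of the compressed-string round trip above (a sketch; the field values are made up). It also shows the format's main limitation: a field value containing ":" would break parsing.

```python
from model.data import ModelId

mid = ModelId(namespace="my-org", name="my-model", epoch="10",
              commit="abc123", hash=None, competition_id="o1")
s = mid.to_compressed_str()
# "my-org:my-model:10:abc123:None:o1" -- None is serialized as the literal
# string "None", which is why from_compressed_str maps "None" back to None.
assert ModelId.from_compressed_str(s) == mid
```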
model/model_tracker.py
ADDED
@@ -0,0 +1,224 @@
import copy
import datetime
import threading
from typing import Dict, List, Optional, Set
import pickle
import bittensor as bt
import hashlib

from model.data import ModelMetadata


class NoopLock:
    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_value, traceback):
        pass


class ModelTracker:
    """Tracks the current model for each miner.

    Thread safe.
    """

    def __init__(
        self,
        thread_safe: bool = True,
    ):
        # Map from miner hotkey to model metadata.
        self.miner_hotkey_to_model_metadata_dict: dict[str, ModelMetadata] = dict()
        # Map from miner hotkey to the last time it was evaluated/loaded/updated.
        self.miner_hotkey_to_last_touched_dict: dict[str, datetime.datetime] = dict()
        # Map from miner hotkey to model hash.
        self.miner_hotkey_to_model_hash_dict: dict[str, str] = dict()

        # List of overwritten models that may be safe to delete if not currently in use.
        self.old_model_metadata: list[tuple[str, ModelMetadata]] = []
        # Set of (hotkey, hash) pairs that are currently in use.
        self.model_metadata_in_use: set[tuple[str, str]] = set()

        # Make this class thread safe because it will be accessed by multiple threads:
        # one for the model-downloading loop and one for the model-validating loop.
        self.lock = threading.RLock() if thread_safe else NoopLock()

    def save_state(self, filepath):
        """Save the current state to the provided filepath."""

        # Open a writable binary file for pickle.
        with self.lock:
            with open(filepath, "wb") as f:
                pickle.dump(self.miner_hotkey_to_model_metadata_dict, f)

    def load_state(self, filepath):
        """Load the state from the provided filepath."""

        # Open a readable binary file for pickle.
        with open(filepath, "rb") as f:
            self.miner_hotkey_to_model_metadata_dict = pickle.load(f)

    def get_miner_hotkey_to_model_metadata_dict(self) -> Dict[str, ModelMetadata]:
        """Returns the mapping from miner hotkey to model metadata."""

        # Return a copy to ensure outside code can't modify the internal state.
        with self.lock:
            return copy.deepcopy(self.miner_hotkey_to_model_metadata_dict)

    def get_model_metadata_for_miner_hotkey(
        self, hotkey: str
    ) -> Optional[ModelMetadata]:
        """Returns the model metadata for a given hotkey, if any."""

        with self.lock:
            if hotkey in self.miner_hotkey_to_model_metadata_dict:
                return self.miner_hotkey_to_model_metadata_dict[hotkey]
            return None

    def take_model_metadata_for_miner_hotkey(self, hotkey: str) -> Optional[ModelMetadata]:
        """Returns the model metadata for a given hotkey, if any, and marks it as in use to prevent race conditions."""

        with self.lock:
            if hotkey in self.miner_hotkey_to_model_metadata_dict:
                metadata = self.miner_hotkey_to_model_metadata_dict[hotkey]
                self.model_metadata_in_use.add((hotkey, metadata.id.hash))
                return metadata
            return None

    def release_all(self):
        with self.lock:
            self.model_metadata_in_use.clear()

    def release_model_metadata_for_miner_hotkey(self, hotkey: str, metadata: ModelMetadata):
        with self.lock:
            pair = (hotkey, metadata.id.hash)
            if pair not in self.model_metadata_in_use:
                bt.logging.error("Model metadata is not in use!")

            if (hotkey, metadata) in self.old_model_metadata:
                bt.logging.trace(f"Releasing old model metadata for hotkey: {hotkey}")

            # discard: tolerate a double release instead of raising KeyError.
            self.model_metadata_in_use.discard(pair)

    def get_miner_hotkey_to_last_touched_dict(self) -> Dict[str, datetime.datetime]:
        """Returns the mapping from miner hotkey to the last time it was touched."""

        # Return a copy to ensure outside code can't modify the internal state.
        with self.lock:
            return copy.deepcopy(self.miner_hotkey_to_last_touched_dict)

    def on_hotkeys_updated(self, incoming_hotkeys: Set[str]):
        """Notifies the tracker which hotkeys are currently being tracked on the metagraph."""

        with self.lock:
            existing_hotkeys = set(self.miner_hotkey_to_model_metadata_dict.keys())
            for hotkey in existing_hotkeys - incoming_hotkeys:
                del self.miner_hotkey_to_model_metadata_dict[hotkey]
                bt.logging.trace(f"Removed outdated hotkey metadata: {hotkey} from ModelTracker")

            existing_hotkeys = set(self.miner_hotkey_to_last_touched_dict.keys())
            for hotkey in existing_hotkeys - incoming_hotkeys:
                del self.miner_hotkey_to_last_touched_dict[hotkey]
                bt.logging.trace(f"Removed outdated hotkey timestamp: {hotkey} from ModelTracker")

    def get_and_clear_old_models(self) -> list[tuple[str, ModelMetadata]]:
        with self.lock:
            to_delete = []
            still_in_use = []
            for hotkey, model in self.old_model_metadata:
                if (hotkey, model.id.hash) in self.model_metadata_in_use:
                    still_in_use.append((hotkey, model))
                else:
                    to_delete.append((hotkey, model))
            self.old_model_metadata = still_in_use

            return to_delete

    def on_miner_model_updated(
        self,
        hotkey: str,
        model_metadata: ModelMetadata,
    ) -> None:
        """Notifies the tracker that a miner has had their associated model updated.

        Args:
            hotkey (str): The miner's hotkey.
            model_metadata (ModelMetadata): The latest model metadata of the miner.
        """
        with self.lock:
            if hotkey in self.miner_hotkey_to_model_metadata_dict:
                old_metadata = self.miner_hotkey_to_model_metadata_dict[hotkey]
                self.old_model_metadata.append((hotkey, old_metadata))

            self.miner_hotkey_to_model_metadata_dict[hotkey] = model_metadata
            self.miner_hotkey_to_last_touched_dict[hotkey] = datetime.datetime.now()

            bt.logging.trace(f"Updated Miner {hotkey}. ModelMetadata={model_metadata}.")

    def touch_miner_model(self, hotkey: str) -> None:
        """Notifies the tracker that a miner has been touched."""

        now = datetime.datetime.now()
        with self.lock:
            self.miner_hotkey_to_last_touched_dict[hotkey] = now

        bt.logging.trace(f"Touched Miner {hotkey}. datetime={now}.")

    def touch_all_miner_models(self) -> None:
        """Touch all miner models."""

        now = datetime.datetime.now()
        with self.lock:
            for hotkey in list(self.miner_hotkey_to_model_metadata_dict.keys()):
                self.miner_hotkey_to_last_touched_dict[hotkey] = now

        bt.logging.trace(f"Touched All Miners. datetime={now}.")

    def update_model_hash(self, hotkey: str, new_model_hash: str) -> bool:
        """
        Update the model_hash for a given hotkey.

        Args:
            hotkey (str): The miner's hotkey.
            new_model_hash (str): The new model hash to be set.

        Returns:
            bool: True if the update was successful.
        """
        with self.lock:
            self.miner_hotkey_to_model_hash_dict[hotkey] = new_model_hash
            return True

    def calculate_file_hash(self, file_path: str) -> str:
        """Calculate the SHA1 hash of a file."""
        sha1 = hashlib.sha1()
        with open(file_path, 'rb') as f:
            while True:
                data = f.read(65536)  # Read in 64 KB chunks
                if not data:
                    break
                sha1.update(data)
        return sha1.hexdigest()

    def is_model_unique(self, hotkey_to_check: str, block_to_check: int, model_checkpoint_path: str) -> tuple[bool, str]:
        """Checks whether a model with the same file hash was already submitted at an earlier block.

        Returns a (is_unique, hash) tuple."""
        # Generate the hash from model_checkpoint_path.
        hash_to_check = self.calculate_file_hash(model_checkpoint_path)

        with self.lock:
            for hotkey, metadata in self.miner_hotkey_to_model_metadata_dict.items():
                if hotkey == hotkey_to_check or hotkey not in self.miner_hotkey_to_model_hash_dict:
                    continue

                if self.miner_hotkey_to_model_hash_dict[hotkey] == hash_to_check and metadata.block < block_to_check:
                    bt.logging.warning(
                        f"*** Model with hash {hash_to_check} on block {block_to_check} is not unique. Already in use by {hotkey} on block {metadata.block} for model {metadata.id.namespace}/{metadata.id.name}. ***"
                    )
                    # Update the model hash for the hotkey.
                    self.update_model_hash(hotkey_to_check, hash_to_check)
                    return False, hash_to_check

            # Update the model hash for the hotkey.
            self.update_model_hash(hotkey_to_check, hash_to_check)
            return True, hash_to_check
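A short usage sketch of the take/release protocol above (hotkey and metadata values are made up): taking marks the (hotkey, hash) pair as in use, so `get_and_clear_old_models` will not hand it out for deletion mid-evaluation.

```python
from model.data import ModelId, ModelMetadata
from model.model_tracker import ModelTracker

meta_in = ModelMetadata(
    id=ModelId(namespace="my-org", name="my-model", epoch="10",
               commit="abc123", hash="h1", competition_id="o1"),
    block=1000,
)
tracker = ModelTracker()
tracker.on_miner_model_updated("5Dxyz", meta_in)  # "5Dxyz" is a made-up hotkey

meta = tracker.take_model_metadata_for_miner_hotkey("5Dxyz")
try:
    ...  # evaluate the model while the pair is protected from deletion
finally:
    if meta is not None:
        tracker.release_model_metadata_for_miner_hotkey("5Dxyz", meta)

stale = tracker.get_and_clear_old_models()  # excludes anything still marked in use
```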
model/model_updater.py
ADDED
@@ -0,0 +1,93 @@
import bittensor as bt
from typing import Optional
from constants import CompetitionParameters, COMPETITION_SCHEDULE
import constants
from model.data import ModelMetadata, Model
from model.model_tracker import ModelTracker
from model.storage.local_model_store import LocalModelStore
from model.storage.model_metadata_store import ModelMetadataStore
from model.storage.remote_model_store import RemoteModelStore


class ModelUpdater:
    """Checks if the currently tracked model for a hotkey matches what the miner committed to the chain."""

    def __init__(
        self,
        metadata_store: ModelMetadataStore,
        remote_store: RemoteModelStore,
        local_store: LocalModelStore,
        model_tracker: ModelTracker,
    ):
        self.metadata_store = metadata_store
        self.remote_store = remote_store
        self.local_store = local_store
        self.model_tracker = model_tracker
        self.min_block: Optional[int] = None

    def set_min_block(self, val: Optional[int]):
        self.min_block = val

    @classmethod
    def get_competition_parameters(cls, id: str) -> Optional[CompetitionParameters]:
        for x in COMPETITION_SCHEDULE:
            if x.competition_id == id:
                return x
        return None

    async def _get_metadata(self, hotkey: str) -> Optional[ModelMetadata]:
        """Get metadata about a model by hotkey."""
        return await self.metadata_store.retrieve_model_metadata(hotkey)

    async def sync_model(self, hotkey: str) -> bool:
        """Updates the local model for a hotkey if out of sync and returns whether it was updated."""
        # Get the metadata for the miner.
        metadata = await self._get_metadata(hotkey)

        if not metadata:
            bt.logging.trace(
                f"No valid metadata found on the chain for hotkey {hotkey}"
            )
            return False

        if self.min_block and metadata.block < self.min_block:
            bt.logging.trace(
                f"Skipping model for {hotkey} since it was submitted at block {metadata.block} which is less than the minimum block {self.min_block}"
            )
            return False

        # Backwards compatibility for models submitted before the competition id was added.
        if metadata.id.competition_id is None:
            metadata.id.competition_id = constants.ORIGINAL_COMPETITION_ID

        parameters = ModelUpdater.get_competition_parameters(metadata.id.competition_id)
        if not parameters:
            bt.logging.trace(
                f"No competition parameters found for {metadata.id.competition_id}"
            )
            return False

        # Check what model id the model tracker currently has for this hotkey.
        tracker_model_metadata = self.model_tracker.get_model_metadata_for_miner_hotkey(
            hotkey
        )
        if metadata == tracker_model_metadata:
            return False
        bt.logging.debug(f"Syncing model for hotkey {hotkey}")
        # Get the local path based on the local store to download to (top-level hotkey path).
        path = self.local_store.get_path(hotkey)

        # bt.logging.warning(f"Downloading model to {path}")
        # # Otherwise we need to download the new model based on the metadata.
        # model = await self.remote_store.download_model(metadata.id, path, parameters)
        # bt.logging.warning(f"Downloaded model to {path}")
        # # Check that the hash of the downloaded content matches.
        # if model.id.hash != metadata.id.hash:
        #     raise ValueError(
        #         f"Sync for hotkey {hotkey} failed. Hash of content downloaded from hugging face does not match chain metadata. {metadata}"
        #     )

        # Update the tracker.
        self.model_tracker.on_miner_model_updated(hotkey, metadata)
        bt.logging.info(f"Model for hotkey {hotkey} updated to {metadata}")
        return True
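A sketch of how a validator loop might drive `sync_model` (the `updater` and `hotkeys` values are assumed to be constructed elsewhere, e.g. from the metagraph):

```python
import asyncio

from model.model_updater import ModelUpdater


async def sync_all(updater: ModelUpdater, hotkeys: list[str]) -> None:
    """Sync each hotkey; sync_model returns True only when the tracker was updated."""
    for hk in hotkeys:
        try:
            if await updater.sync_model(hk):
                print(f"refreshed model metadata for {hk}")
        except Exception as e:
            print(f"sync failed for {hk}: {e}")

# asyncio.run(sync_all(updater, metagraph.hotkeys))  # both assumed to exist
```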
model/storage/__init__.py
ADDED
File without changes
model/storage/chain/chain_model_metadata_store.py
ADDED
@@ -0,0 +1,177 @@
import asyncio
import functools
import bittensor as bt
import os
from model.data import ModelId, ModelMetadata
import constants
from model.storage.model_metadata_store import ModelMetadataStore
from typing import Optional

from utilities import utils


class ChainModelMetadataStore(ModelMetadataStore):
    """Chain-based implementation for storing and retrieving metadata about a model."""

    def __init__(
        self,
        subtensor: bt.subtensor,
        subnet_uid: int,
        wallet: Optional[bt.wallet] = None,
    ):
        self.subtensor = subtensor
        self.wallet = (
            wallet  # Wallet is only needed to write to the chain, not to read.
        )
        self.subnet_uid = subnet_uid

        # This is a hacky way to prime the get_metadata function.
        SN21_OWNER_KEY = "5GsHpHeCGhxstoEEZTR64VUashnDP4n7ir7LbNdRfXpkMU7R"
        metadata = bt.extrinsics_subpackage.serving.get_metadata(self.subtensor, self.subnet_uid, SN21_OWNER_KEY)
        bt.logging.debug(f"primed get_metadata call successfully: metadata={metadata} (ok to be None)")

    async def store_model_metadata(self, hotkey: str, model_id: ModelId):
        """Stores model metadata on this subnet for a specific wallet."""
        if self.wallet is None:
            raise ValueError("No wallet available to write to the chain.")

        # Wrap calls to the subtensor in a subprocess with a timeout to handle potential hangs.
        # partial = functools.partial(
        #     self.subtensor.commit,
        #     self.wallet,
        #     self.subnet_uid,
        #     model_id.to_compressed_str(),
        # )
        # utils.run_in_subprocess(partial, 60)
        commit = self.subtensor.commit(self.wallet, self.subnet_uid, model_id.to_compressed_str())
        print(f"success commit: {commit}")
        if not commit:
            raise ValueError("Failed to commit model metadata to the chain.")
        return commit

    async def retrieve_model_metadata(self, hotkey: str) -> Optional[ModelMetadata]:
        """Retrieves model metadata on this subnet for a specific hotkey."""

        # Wrap calls to the subtensor in a subprocess with a timeout to handle potential hangs.
        partial = functools.partial(
            bt.extrinsics_subpackage.serving.get_metadata, self.subtensor, self.subnet_uid, hotkey
        )

        metadata = utils.run_in_subprocess(partial, 60)

        if not metadata:
            return None

        commitment = metadata["info"]["fields"][0][0]

        hex_data_tuple = commitment[list(commitment.keys())[0]][0]

        chain_str = ''.join(chr(num) for num in hex_data_tuple)

        model_id = None

        try:
            model_id = ModelId.from_compressed_str(chain_str)
        except Exception:
            # If the metadata format is not correct on the chain then we return None.
            bt.logging.trace(
                f"Failed to parse the metadata on the chain for hotkey {hotkey}."
            )
            return None

        model_metadata = ModelMetadata(id=model_id, block=metadata["block"])
        return model_metadata


# Can only commit data every ~20 minutes.
async def test_store_model_metadata():
    """Verifies that the ChainModelMetadataStore can store data on the chain."""
    model_id = ModelId(
        namespace="TestPath", name="TestModel", hash="TestHash1", commit="1.0"
    )

    # Use a different subnet that does not leverage chain storage to avoid conflicts.
    # TODO switch to a mocked version when it supports commits.
    subtensor = bt.subtensor()

    # Uses .env configured wallet/hotkey/uid for the test.
    coldkey = os.getenv("TEST_COLDKEY")
    hotkey = os.getenv("TEST_HOTKEY")
    net_uid = int(os.getenv("TEST_SUBNET_UID"))

    wallet = bt.wallet(name=coldkey, hotkey=hotkey)

    metadata_store = ChainModelMetadataStore(
        subtensor=subtensor, wallet=wallet, subnet_uid=net_uid
    )

    # Store the metadata on chain.
    await metadata_store.store_model_metadata(hotkey=hotkey, model_id=model_id)

    print(f"Finished storing {model_id} on the chain.")


async def test_retrieve_model_metadata():
    """Verifies that the ChainModelMetadataStore can retrieve data from the chain."""
    expected_model_id = ModelId(
        namespace="TestPath", name="TestModel", hash="TestHash1", commit="1.0"
    )

    # Use a different subnet that does not leverage chain storage to avoid conflicts.
    # TODO switch to a mocked version when it supports commits.
    subtensor = bt.subtensor()

    # Uses .env configured hotkey/uid for the test.
    net_uid = int(os.getenv("TEST_SUBNET_UID"))
    hotkey_address = os.getenv("TEST_HOTKEY_ADDRESS")

    # Do not require a wallet for retrieving data.
    metadata_store = ChainModelMetadataStore(
        subtensor=subtensor, wallet=None, subnet_uid=net_uid
    )

    # Retrieve the metadata from the chain.
    model_metadata = await metadata_store.retrieve_model_metadata(hotkey_address)

    print(f"Expecting matching model id: {expected_model_id == model_metadata.id}")


# Can only commit data every ~20 minutes.
async def test_roundtrip_model_metadata():
    """Verifies that the ChainModelMetadataStore can roundtrip data on the chain."""
    model_id = ModelId(
        namespace="TestPath", name="TestModel", hash="TestHash1", commit="1.0"
    )

    # Use a different subnet that does not leverage chain storage to avoid conflicts.
    # TODO switch to a mocked version when it supports commits.
    subtensor = bt.subtensor()

    # Uses .env configured wallet/hotkey/uid for the test.
    coldkey = os.getenv("TEST_COLDKEY")
    hotkey = os.getenv("TEST_HOTKEY")
    net_uid = int(os.getenv("TEST_SUBNET_UID"))

    wallet = bt.wallet(name=coldkey, hotkey=hotkey)

    metadata_store = ChainModelMetadataStore(
        subtensor=subtensor, wallet=wallet, subnet_uid=net_uid
    )

    # Store the metadata on chain.
    await metadata_store.store_model_metadata(hotkey=hotkey, model_id=model_id)

    # May need to use the underlying publish_metadata function with wait_for_inclusion: True to pass here.
    # Otherwise it defaults to False and we only wait for finalization, not necessarily inclusion.

    # Retrieve the metadata from the chain.
    model_metadata = await metadata_store.retrieve_model_metadata(hotkey)

    print(f"Expecting matching metadata: {model_id == model_metadata.id}")


if __name__ == "__main__":
    # Can only commit data every ~20 minutes.
    # asyncio.run(test_roundtrip_model_metadata())
    # asyncio.run(test_store_model_metadata())
    asyncio.run(test_retrieve_model_metadata())
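To illustrate the decoding step in `retrieve_model_metadata` (a standalone sketch with made-up data; the chain's exact wire format may differ):

```python
# The commitment stores the compressed string as a sequence of byte values;
# this round trip mirrors the join/chr decoding used above.
committed = "my-org:my-model:10:abc123:None:o1"    # hypothetical payload
hex_data_tuple = tuple(ord(c) for c in committed)  # stand-in for the on-chain tuple

chain_str = ''.join(chr(num) for num in hex_data_tuple)
assert chain_str == committed
```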
model/storage/disk/__init__.py
ADDED
File without changes
model/storage/disk/disk_model_store.py
ADDED
@@ -0,0 +1,124 @@
import bittensor as bt
import datetime
import os
from typing import Dict
from constants import CompetitionParameters
from model.data import Model, ModelId
from model.storage.disk import utils
from model.storage.local_model_store import LocalModelStore
from transformers import AutoModelForCausalLM, AutoTokenizer
from pathlib import Path


class DiskModelStore(LocalModelStore):
    """Local storage based implementation for storing and retrieving a model on disk."""

    def __init__(self, base_dir: str):
        self.base_dir = base_dir
        os.makedirs(utils.get_local_miners_dir(base_dir), exist_ok=True)

    def get_path(self, hotkey: str) -> str:
        """Returns the path to where this store would locate this hotkey."""
        return utils.get_local_miner_dir(self.base_dir, hotkey)

    def store_model(self, hotkey: str, model: Model, hf_model: AutoModelForCausalLM, hf_tokenizer: AutoTokenizer) -> ModelId:
        """Stores a trained model locally."""
        # Get the path where the model should be stored.
        model_dir = os.path.join(self.get_path(hotkey), model.id.name)
        hf_model.save_pretrained(model_dir)
        hf_tokenizer.save_pretrained(model_dir)
        model.local_repo_dir = model_dir

        return model.id

    def retrieve_model(
        self, hotkey: str, model_id: ModelId, model_parameters: CompetitionParameters
    ) -> Model:
        """Retrieves a trained model locally."""

        # Get the path where the model should be stored.
        model_dir = os.path.join(self.get_path(hotkey), model_id.name)
        return Model(id=model_id, local_repo_dir=model_dir)

    def delete_unreferenced_models(
        self,
        valid_models_by_hotkey: Dict[str, ModelId],
        model_touched_by_hotkey: Dict[str, datetime.datetime],
        grace_period_seconds: int,
    ):
        """Check across all of local storage and delete unreferenced models out of grace period."""
        # TODO: THIS METHOD IS NOT UP TO DATE YET
        raise NotImplementedError("This method is not implemented yet.")
        # Expected directory structure is as follows:
        # self.base_dir/models/hotkey/models--namespace--name/snapshots/commit/config.json + other files.

        # Create a set of valid model paths up to where we expect to see the actual files.
        valid_model_paths = set()
        for hotkey, model_id in valid_models_by_hotkey.items():
            valid_model_paths.add(
                utils.get_local_model_snapshot_dir(self.base_dir, hotkey, model_id)
            )

        # For each hotkey path on disk, use listdir to go one level deep.
        miners_dir = Path(utils.get_local_miners_dir(self.base_dir))
        hotkey_subfolder_names = [d.name for d in miners_dir.iterdir() if d.is_dir()]

        for hotkey in hotkey_subfolder_names:
            # Reconstruct the path from the hotkey.
            hotkey_path = utils.get_local_miner_dir(self.base_dir, hotkey)

            # If it is not in valid_hotkeys and out of grace period, remove it.
            if hotkey not in valid_models_by_hotkey:
                deleted_hotkey = utils.remove_dir_out_of_grace(
                    hotkey_path, grace_period_seconds
                )
                if deleted_hotkey:
                    bt.logging.trace(
                        f"Removed directory for unreferenced hotkey: {hotkey}."
                    )
            else:
                # Check all the models--namespace--name subfolder paths.
                hotkey_dir = Path(hotkey_path)
                model_subfolder_paths = [
                    str(d) for d in hotkey_dir.iterdir() if d.is_dir()
                ]

                # Check all the snapshots subfolder paths.
                for model_path in model_subfolder_paths:
                    model_dir = Path(model_path)
                    snapshot_subfolder_paths = [
                        str(d) for d in model_dir.iterdir() if d.is_dir()
                    ]

                    # Check all the commit paths.
                    for snapshot_path in snapshot_subfolder_paths:
                        snapshot_dir = Path(snapshot_path)
                        commit_subfolder_paths = [
                            str(d) for d in snapshot_dir.iterdir() if d.is_dir()
                        ]

                        # Reached the end. Check all the actual commit subfolders for the files.
                        for commit_path in commit_subfolder_paths:
                            if commit_path not in valid_model_paths:
                                deleted_model = utils.remove_dir_out_of_grace(
                                    commit_path, grace_period_seconds
                                )
                                if deleted_model:
                                    bt.logging.trace(
                                        f"Removing directory for unreferenced model at: {commit_path}."
                                    )
                            else:
                                last_touched = model_touched_by_hotkey.get(hotkey)
                                if last_touched is not None:
                                    deleted_model = (
                                        utils.remove_dir_out_of_grace_by_datetime(
                                            commit_path,
                                            grace_period_seconds,
                                            last_touched,
                                        )
                                    )
                                    if deleted_model:
                                        bt.logging.trace(
                                            f"Removing directory for stale model at: {commit_path}."
                                        )
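A small usage sketch of the disk store (base dir and hotkey are made up; `model`, `hf_model`, `hf_tokenizer`, and `competition_parameters` are assumed to exist):

```python
from model.storage.disk.disk_model_store import DiskModelStore

store = DiskModelStore(base_dir="/tmp/model-store")  # hypothetical base dir

# Where this store keeps files for a given miner hotkey:
print(store.get_path("5Dxyz"))  # -> /tmp/model-store/models/5Dxyz

# store_model saves weights and tokenizer under <path>/<model.id.name>:
# model_id = store.store_model("5Dxyz", model, hf_model, hf_tokenizer)
# local = store.retrieve_model("5Dxyz", model_id, competition_parameters)
```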
model/storage/disk/utils.py
ADDED
@@ -0,0 +1,121 @@
import base64
import datetime
import hashlib
import os
import shutil
from model.data import ModelId


def get_local_miners_dir(base_dir: str) -> str:
    return os.path.join(base_dir, "models")


def get_local_miner_dir(base_dir: str, hotkey: str) -> str:
    return os.path.join(get_local_miners_dir(base_dir), hotkey)


# Hugging face stores models under models--namespace--name/snapshots/commit when downloading.
def get_local_model_dir(base_dir: str, hotkey: str, model_id: ModelId) -> str:
    return os.path.join(
        get_local_miner_dir(base_dir, hotkey),
        "models" + "--" + model_id.namespace + "--" + model_id.name,
    )


def get_local_model_snapshot_dir(base_dir: str, hotkey: str, model_id: ModelId) -> str:
    return os.path.join(
        get_local_model_dir(base_dir, hotkey, model_id),
        "snapshots",
        model_id.commit,
    )


def get_hf_download_path(local_path: str, model_id: ModelId) -> str:
    return os.path.join(
        local_path,
        "models" + "--" + model_id.namespace + "--" + model_id.name,
        "snapshots",
        model_id.commit,
    )


def get_newest_datetime_under_path(path: str) -> datetime.datetime:
    newest_filetime = 0.0

    # Check to see if any file at any level was modified more recently than the current newest.
    # (The original compared with < against a sys.maxsize sentinel, which would have tracked
    # the oldest file instead of the newest; the comparison is corrected here.)
    for cur_path, dirnames, filenames in os.walk(path):
        for filename in filenames:
            file_path = os.path.join(cur_path, filename)
            try:
                mod_time = os.stat(file_path).st_mtime
                if mod_time > newest_filetime:
                    newest_filetime = mod_time
            except OSError:
                pass

    if newest_filetime == 0.0:
        return datetime.datetime.max

    return datetime.datetime.fromtimestamp(newest_filetime)


def remove_dir_out_of_grace_by_datetime(path: str, grace_period_seconds: int, last_modified: datetime.datetime) -> bool:
    """Removes a dir if the last modified time is out of grace period secs. Returns if it was deleted."""
    grace = datetime.timedelta(seconds=grace_period_seconds)

    if last_modified < datetime.datetime.now() - grace:
        shutil.rmtree(path=path, ignore_errors=True)
        return True

    return False


def remove_dir_out_of_grace(path: str, grace_period_seconds: int) -> bool:
    """Removes a dir if the last modified time is out of grace period secs. Returns if it was deleted."""
    last_modified = get_newest_datetime_under_path(path)
    return remove_dir_out_of_grace_by_datetime(path, grace_period_seconds, last_modified)


def realize_symlinks_in_directory(path: str) -> int:
    """Realizes all symlinks in the given directory, moving the linked file into place. Returns count realized."""
    realized_symlinks = 0

    for cur_path, dirnames, filenames in os.walk(path):
        for filename in filenames:
            file_path = os.path.abspath(os.path.join(cur_path, filename))
            # Get path resolving symlinks if encountered
            real_path = os.path.realpath(file_path)
            # If different then move
            if file_path != real_path:
                realized_symlinks += 1
                shutil.move(real_path, file_path)

    return realized_symlinks


def get_hash_of_file(path: str) -> str:
    blocksize = 64 * 1024
    file_hash = hashlib.sha256()
    with open(path, "rb") as fp:
        while True:
            data = fp.read(blocksize)
            if not data:
                break
            file_hash.update(data)
    return base64.b64encode(file_hash.digest()).decode("utf-8")


def get_hash_of_directory(path: str) -> str:
    dir_hash = hashlib.sha256()

    # Recursively walk everything under the directory for files.
    for cur_path, dirnames, filenames in os.walk(path):
        # Ensure we walk future directories in a consistent order.
        dirnames.sort()
        # Ensure we walk files in a consistent order.
        for filename in sorted(filenames):
            file_path = os.path.join(cur_path, filename)
            file_hash = get_hash_of_file(file_path)
            dir_hash.update(file_hash.encode())

    return base64.b64encode(dir_hash.digest()).decode("utf-8")
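
A small, self-contained sketch of the utilities above: write a directory, hash it, and exercise the grace-period predicate. Only the standard library plus this module are assumed:

import os
import tempfile
from model.storage.disk import utils

with tempfile.TemporaryDirectory() as d:
    with open(os.path.join(d, "config.json"), "w") as f:
        f.write('{"dummy": true}')
    # Deterministic: files and subdirectories are walked in sorted order above.
    print("dir hash:", utils.get_hash_of_directory(d))
    # A freshly written file is within any positive grace period, so nothing is removed.
    print("deleted:", utils.remove_dir_out_of_grace(d, grace_period_seconds=60))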
model/storage/eval_leaderboard.py
ADDED
@@ -0,0 +1,185 @@
from sqlalchemy import create_engine, Column, Integer, Float, String, DateTime, ForeignKey
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, sessionmaker, relationship
from sqlalchemy.exc import OperationalError
from contextlib import contextmanager
from datetime import datetime
import bittensor as bt
from typing import Optional, Dict, List
import time
from vali_api.config import DBHOST, DBNAME, DBUSER, DBPASS

Base = declarative_base()

# Global variables for engine and Session
_engine: Optional[object] = None
Session: Optional[sessionmaker] = None

def init_database():
    """Initialize the database connection and create tables."""
    global _engine, Session

    if _engine is not None:
        bt.logging.warning("Database already initialized")
        return

    try:
        connection_string = f'mysql://{DBUSER}:{DBPASS}@{DBHOST}/{DBNAME}'
        _engine = create_engine(connection_string)
        Session = sessionmaker(bind=_engine)
        Base.metadata.create_all(_engine)
        bt.logging.info("Database initialized successfully")
    except Exception as e:
        bt.logging.error(f"Failed to initialize database: {str(e)}")
        raise

def get_session() -> Session:
    """Get a database session."""
    if Session is None:
        raise RuntimeError("Database not initialized. Call init_database() first.")
    return Session()

class EvaluationModel(Base):
    __tablename__ = 'sn21_evals_test'

    eval_id = Column(Integer, primary_key=True)
    miner_hotkey = Column(String(255))
    miner_uid = Column(Integer)
    model_name = Column(String(255))
    model_type = Column(String(255))
    eval_date = Column(DateTime)
    competition_id = Column(String(10))

    results = relationship("EvaluationResult", back_populates="evaluation")

class EvaluationResult(Base):
    __tablename__ = 'sn21_eval_results_test'

    eval_result_id = Column(Integer, primary_key=True)
    eval_id = Column(Integer, ForeignKey('sn21_evals_test.eval_id'))
    task = Column(String(255))
    result_name = Column(String(255))
    result = Column(Float)
    competition_id = Column(String(10))
    deleted_at = Column(DateTime)

    evaluation = relationship("EvaluationModel", back_populates="results")

class EvalLeaderboardManager:
    def __init__(self, max_retries=3, retry_delay=1):
        if Session is None:
            raise RuntimeError("Database not initialized. Call init_database() first.")

        self.session = get_session()
        self.max_retries = max_retries
        self.retry_delay = retry_delay

    @contextmanager
    def session_scope(self):
        """Provide a transactional scope around a series of operations."""
        session = get_session()
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    def execute_with_retry(self, operation, *args, **kwargs):
        """Execute database operation with retry logic."""
        for attempt in range(self.max_retries):
            try:
                return operation(*args, **kwargs)
            except OperationalError:
                if attempt < self.max_retries - 1:
                    bt.logging.warning(f"Database error. Attempt {attempt + 1}/{self.max_retries}. Retrying...")
                    time.sleep(self.retry_delay)
                else:
                    raise
            except Exception as e:
                bt.logging.error(f"Error executing operation: {str(e)}")
                raise

    def get_metrics_timeseries(self) -> Dict[str, List[Dict]]:
        """
        Get time series data for all turn metrics.
        Returns data in format {metric_name: [{date: "", models: [{modelType: "", modelName: "", score: 123}, ...]}, ...]}
        Groups models by model_type and includes model names in the data.
        """
        def _get_timeseries():
            with self.session_scope() as session:
                evals = session.query(EvaluationModel).order_by(EvaluationModel.eval_date).all()

                metrics = set()
                # Get all non-deleted results that aren't the excluded metric
                sample_results = session.query(EvaluationResult).filter(
                    EvaluationResult.result_name != "exact_match_stderr,flexible-extract",
                    EvaluationResult.deleted_at.is_(None)  # SQLAlchemy's proper way to check for NULL
                ).all()
                for result in sample_results:
                    metrics.add(result.result_name)
                timeseries_data = {metric: {} for metric in metrics}

                for eval_ in evals:
                    date_str = eval_.eval_date.strftime('%Y-%m-%d')
                    model_type = eval_.model_type
                    model_name = eval_.model_name
                    competition_id = eval_.competition_id

                    # Only process non-deleted results
                    for result in [r for r in eval_.results if r.deleted_at is None]:
                        metric_name = result.result_name
                        if metric_name in metrics:
                            if date_str not in timeseries_data[metric_name]:
                                timeseries_data[metric_name][date_str] = {
                                    'date': date_str,
                                    'models': []
                                }

                            # Check if we already have an entry for this model_type
                            existing_model = next(
                                (m for m in timeseries_data[metric_name][date_str]['models']
                                 if m['modelType'] == model_type),
                                None
                            )

                            if existing_model:
                                # Update existing entry
                                existing_model['score'] = result.result
                                if model_name not in existing_model['modelName']:
                                    existing_model['modelName'] = model_name
                            else:
                                # Add new entry
                                timeseries_data[metric_name][date_str]['models'].append({
                                    'modelType': model_type,
                                    'modelName': model_name,
                                    'score': result.result,
                                    'competition_id': competition_id,
                                    'task': result.task
                                })

                return {
                    metric: [
                        data_point for data_point in sorted(data.values(), key=lambda x: x['date'])
                        if data_point['models']  # Filter out entries with empty models list
                    ]
                    for metric, data in timeseries_data.items()
                }

        try:
            return self.execute_with_retry(_get_timeseries)
        except Exception as e:
            bt.logging.error(f"Failed to get metrics timeseries: {str(e)}")
            return {}

    def close(self):
        """Safely close the session."""
        try:
            self.session.close()
        except Exception:
            pass
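
A usage sketch for the leaderboard reader above; init_database() must run once per process before a manager is constructed, and the connection settings come from vali_api.config:

from model.storage.eval_leaderboard import init_database, EvalLeaderboardManager

init_database()
manager = EvalLeaderboardManager(max_retries=3, retry_delay=1)
try:
    timeseries = manager.get_metrics_timeseries()
    for metric, points in timeseries.items():
        print(metric, "->", len(points), "dated data points")
finally:
    manager.close()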
model/storage/hugging_face/__init__.py
ADDED
File without changes
model/storage/hugging_face/hugging_face_model_store.py
ADDED
@@ -0,0 +1,175 @@
import os
from omegaconf import OmegaConf
from huggingface_hub import HfApi, file_exists
from model.data import Model, ModelId
from model.storage.disk import utils
from constants import CompetitionParameters, MAX_HUGGING_FACE_BYTES

from model.storage.remote_model_store import RemoteModelStore
from collections import defaultdict


MODEL_FILE_PT = "meta_model_{epoch}.pt"
ADAPTER_FILE_PT = "adapter_{epoch}.pt"
CONFIG_FILE = "training_config.yml"
CONFIG_FILE_MOSHI = "config.yaml"
README_FILE = "README.md"
HOTKEY_FILE = "hotkey.txt"

### MOSHI ###
LM_FILE_PT_MOSHI = "model.safetensors"
MIMI_FILE_PT_MOSHI = "tokenizer-e351c8d8-checkpoint125.safetensors"
TOKENIZER_FILE_MOSHI = "tokenizer_spm_32k_3.model"


def check_config(ckpt_dir):
    config_file = os.path.join(ckpt_dir, CONFIG_FILE)
    cfg = OmegaConf.load(config_file)
    if cfg.model.use_clip:
        raise ValueError("Cannot upload checkpoints with CLIP embeddings")


def get_required_files(epoch: int, model_type: str):
    if model_type == "o1":
        return [
            MODEL_FILE_PT.format(epoch=epoch),
            CONFIG_FILE,
        ]
    elif model_type == "v1":
        return [
            LM_FILE_PT_MOSHI,
            MIMI_FILE_PT_MOSHI,
            TOKENIZER_FILE_MOSHI,
            CONFIG_FILE_MOSHI,
        ]


def export_readme(ckpt_dir: str):
    readme_file = os.path.join(ckpt_dir, README_FILE)
    with open(readme_file, "w") as f:
        f.write(
            f"""---
license: mit
tags:
- any-to-any
- omega
- omegalabs
- bittensor
- agi
---

This is an Any-to-Any model checkpoint for the OMEGA Labs x Bittensor Any-to-Any subnet.

Check out the [git repo](https://github.com/omegalabsinc/omegalabs-anytoany-bittensor) and find OMEGA on X: [@omegalabsai](https://x.com/omegalabsai).
"""
        )


def export_hotkey(ckpt_dir: str, hotkey: str):
    hotkey_file = os.path.join(ckpt_dir, HOTKEY_FILE)
    with open(hotkey_file, "w") as f:
        f.write(hotkey)


class HuggingFaceModelStore(RemoteModelStore):
    """Hugging Face based implementation for storing and retrieving a model."""

    @classmethod
    def assert_access_token_exists(cls) -> str:
        """Asserts that the access token exists."""
        if not os.getenv("HF_ACCESS_TOKEN"):
            raise ValueError("No Hugging Face access token found to write to the hub.")
        return os.getenv("HF_ACCESS_TOKEN")

    async def upload_model(
        self,
        model: Model,
        competition_parameters: CompetitionParameters,
        hotkey: str,
    ) -> ModelId:
        """Uploads a trained model to Hugging Face."""
        token = HuggingFaceModelStore.assert_access_token_exists()
        api = HfApi(token=token)
        export_readme(model.local_repo_dir)
        export_hotkey(model.local_repo_dir, hotkey)
        hf_repo_id = model.id.namespace + "/" + model.id.name
        api.create_repo(
            repo_id=hf_repo_id,
            exist_ok=True,
            private=True,
        )
        # The HfApi instance already carries the token, so no extra auth argument is needed.
        commit_info = api.upload_folder(repo_id=hf_repo_id, folder_path=model.local_repo_dir)

        print(f"Successfully uploaded model repository '{model.local_repo_dir}' to {hf_repo_id}")

        model_id_with_commit = ModelId(
            namespace=model.id.namespace,
            name=model.id.name,
            epoch=model.id.epoch,
            hash=model.id.hash,
            commit=commit_info.oid,
            competition_id=model.id.competition_id,
        )

        return model_id_with_commit
        # # TODO consider skipping the redownload if a hash is already provided.
        # # To get the hash we need to redownload it at a local tmp directory after which it can be deleted.
        # with tempfile.TemporaryDirectory() as temp_dir:
        #     model_with_hash = await self.download_model(
        #         model_id_with_commit, temp_dir, competition_parameters
        #     )
        #     # Return a ModelId with both the correct commit and hash.
        #     return model_with_hash.id

    async def download_model(
        self,
        model_id: ModelId,
        local_path: str,
        model_parameters: CompetitionParameters,
    ) -> Model:
        """Retrieves a trained model from Hugging Face."""
        if not model_id.commit:
            raise ValueError("No Hugging Face commit id found to read from the hub.")

        repo_id = model_id.namespace + "/" + model_id.name

        # Check ModelInfo for the size of model.safetensors file before downloading.
        try:
            token = HuggingFaceModelStore.assert_access_token_exists()
        except ValueError:
            token = None
        api = HfApi(token=token)
        model_info = api.model_info(
            repo_id=repo_id, revision=model_id.commit, timeout=10, files_metadata=True
        )
        size = sum(repo_file.size for repo_file in model_info.siblings)
        if size > MAX_HUGGING_FACE_BYTES:
            raise ValueError(
                f"Hugging Face repo over maximum size limit. Size {size}. Limit {MAX_HUGGING_FACE_BYTES}."
            )

        api.hf_hub_download(
            repo_id=repo_id,
            revision=model_id.commit,
            filename="checkpoint.safetensors",
            cache_dir=local_path,
        )

        # Get the directory the model was stored to.
        model_dir = utils.get_hf_download_path(local_path, model_id)

        # Realize all symlinks in that directory since the Transformers library does not support avoiding symlinks.
        utils.realize_symlinks_in_directory(model_dir)

        # Compute the hash of the downloaded model.
        model_hash = utils.get_hash_of_directory(model_dir)
        model_id_with_hash = ModelId(
            namespace=model_id.namespace,
            name=model_id.name,
            commit=model_id.commit,
            hash=model_hash,
            competition_id=model_id.competition_id,
        )

        return Model(id=model_id_with_hash, ckpt=model_dir)
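
A hedged call sketch for the store above. download_model is a coroutine, so it needs an event loop; the ModelId here is assumed to arrive from the metadata store with namespace, name, and commit populated (field names match ModelIdEncoder later in this commit):

import asyncio
from model.storage.hugging_face.hugging_face_model_store import HuggingFaceModelStore

async def fetch(model_id):
    store = HuggingFaceModelStore()
    # model_parameters is accepted but unused by the download path shown above.
    model = await store.download_model(model_id, local_path="model-store", model_parameters=None)
    print(model.id.hash, model.ckpt)

# asyncio.run(fetch(model_id)) with a ModelId retrieved from chain metadata.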
model/storage/local_model_store.py
ADDED
@@ -0,0 +1,30 @@
import abc
from typing import Dict
from model.data import Model, ModelId
from constants import CompetitionParameters


class LocalModelStore(abc.ABC):
    """An abstract base class for storing and retrieving a pre-trained model locally."""

    @abc.abstractmethod
    def store_model(self, hotkey: str, model: Model) -> ModelId:
        """Stores a trained model in the appropriate location based on implementation."""
        pass

    @abc.abstractmethod
    def get_path(self, hotkey: str) -> str:
        """Returns the path to the appropriate location based on implementation."""
        pass

    @abc.abstractmethod
    def retrieve_model(self, hotkey: str, model_id: ModelId, parameters: CompetitionParameters) -> Model:
        """Retrieves a trained model from the appropriate location based on implementation."""
        pass

    @abc.abstractmethod
    def delete_unreferenced_models(
        self, valid_models_by_hotkey: Dict[str, ModelId], grace_period_seconds: int
    ):
        """Check across all of local storage and delete unreferenced models out of grace period."""
        pass
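
To make the abstract surface concrete, a throwaway in-memory implementation is sketched below (illustration only; DiskModelStore in this commit is the real implementation):

from typing import Dict
from constants import CompetitionParameters
from model.data import Model, ModelId
from model.storage.local_model_store import LocalModelStore

class InMemoryModelStore(LocalModelStore):
    def __init__(self):
        self._models: Dict[str, Model] = {}

    def store_model(self, hotkey: str, model: Model) -> ModelId:
        self._models[hotkey] = model
        return model.id

    def get_path(self, hotkey: str) -> str:
        return f"memory://{hotkey}"  # no real filesystem path

    def retrieve_model(self, hotkey: str, model_id: ModelId, parameters: CompetitionParameters) -> Model:
        return self._models[hotkey]

    def delete_unreferenced_models(self, valid_models_by_hotkey: Dict[str, ModelId], grace_period_seconds: int):
        # No timestamps are kept in memory, so the grace period is ignored here.
        for hotkey in list(self._models):
            if hotkey not in valid_models_by_hotkey:
                del self._models[hotkey]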
model/storage/model_metadata_store.py
ADDED
@@ -0,0 +1,17 @@
import abc
from typing import Optional
from model.data import ModelId, ModelMetadata


class ModelMetadataStore(abc.ABC):
    """An abstract base class for storing and retrieving model metadata."""

    @abc.abstractmethod
    async def store_model_metadata(self, hotkey: str, model_id: ModelId):
        """Stores model metadata on this subnet for a specific miner."""
        pass

    @abc.abstractmethod
    async def retrieve_model_metadata(self, hotkey: str) -> Optional[ModelMetadata]:
        """Retrieves model metadata + block information on this subnet for a specific miner, if present."""
        pass
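
The chain-backed implementation of this interface ships earlier in this commit (model/storage/chain/chain_model_metadata_store.py). A neutral call sketch that works against any implementation of the ABC above:

import asyncio

async def lookup(store, hotkey: str):
    # store: any ModelMetadataStore implementation
    metadata = await store.retrieve_model_metadata(hotkey)
    print(metadata if metadata is not None else f"no metadata for {hotkey}")

# asyncio.run(lookup(store, "some-hotkey")) with a concrete store instance.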
model/storage/mysql_model_queue.py
ADDED
@@ -0,0 +1,847 @@
from sqlalchemy import create_engine, Column, Integer, Float, String, DateTime, Boolean, JSON, func, desc, exists, ForeignKey, or_, and_, case, text
from sqlalchemy.exc import OperationalError, SQLAlchemyError
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.dialects.mysql import JSON as MySQLJSON
from sqlalchemy.orm import Session, sessionmaker, aliased, relationship
from contextlib import contextmanager
from collections import defaultdict

import time
import json

from datetime import datetime, timedelta, timezone
import bittensor as bt
from typing import Optional

from model.data import ModelId
from vali_api.config import DBHOST, DBNAME, DBUSER, DBPASS, IS_PROD

Base = declarative_base()

# Global variables for engine and Session
_engine: Optional[object] = None
Session: Optional[sessionmaker] = None

def init_database():
    """
    Initialize the database connection and create tables.
    Must be called before using any database operations.
    """
    global _engine, Session

    if _engine is not None:
        bt.logging.warning("Database already initialized")
        return

    # Try different MySQL drivers in order of preference
    drivers_to_try = [
        ('mysql', 'mysqlclient (MySQLdb)'),
        ('mysql+pymysql', 'PyMySQL')
    ]

    for driver, driver_name in drivers_to_try:
        try:
            connection_string = f'{driver}://{DBUSER}:{DBPASS}@{DBHOST}/{DBNAME}'
            bt.logging.info(f"Attempting database connection with {driver_name}")

            _engine = create_engine(connection_string)
            Session = sessionmaker(bind=_engine)

            # Test the connection
            with Session() as session:
                session.execute(text('SELECT 1'))

            # Create all tables
            Base.metadata.create_all(_engine)
            bt.logging.info(f"Database initialized successfully with {driver_name}")
            return

        except ImportError as e:
            bt.logging.warning(f"Driver {driver_name} not available: {e}")
            continue
        except Exception as e:
            bt.logging.error(f"Failed to connect with {driver_name}: {e}")
            if driver == drivers_to_try[-1][0]:  # Last driver in list
                raise
            continue

    raise RuntimeError("Failed to initialize database with any available MySQL driver")

def get_session() -> Session:
    """
    Get a database session. Raises exception if database not initialized.
    """
    if Session is None:
        raise RuntimeError("Database not initialized. Call init_database() first.")
    return Session()

def get_table_name(base_name: str) -> str:
    """Helper function to get the correct table name with suffix if not in production."""
    return f"{base_name}{'_test' if not IS_PROD else ''}"

class ModelQueue(Base):
    __tablename__ = get_table_name('sn21_model_queue')

    hotkey = Column(String(255), primary_key=True)
    uid = Column(String(255), primary_key=True, index=True)
    block = Column(Integer, index=True)
    competition_id = Column(String(255), index=True)
    model_metadata = Column(JSON)
    is_new = Column(Boolean, default=True)
    is_being_scored = Column(Boolean, default=False)
    is_being_scored_by = Column(String(255), default=None)
    scoring_updated_at = Column(DateTime, default=None)
    updated_at = Column(DateTime, default=datetime.utcnow)

    # Relationship to use dynamic table name (lambda function)
    scores = relationship(
        "ScoreHistory",
        back_populates="model",
        foreign_keys="[ScoreHistory.hotkey, ScoreHistory.uid]",
        primaryjoin=lambda: and_(
            ModelQueue.hotkey == ScoreHistory.hotkey,
            ModelQueue.uid == ScoreHistory.uid
        )
    )

    def __repr__(self):
        return f"<ModelQueue(hotkey='{self.hotkey}', uid='{self.uid}', competition_id='{self.competition_id}', is_new={self.is_new})>"

class ScoreHistory(Base):
    __tablename__ = get_table_name('sn21_score_history')

    id = Column(Integer, primary_key=True)
    hotkey = Column(String(255), ForeignKey(f"{get_table_name('sn21_model_queue')}.hotkey", ondelete='SET NULL'), index=True, nullable=True)
    uid = Column(String(255), ForeignKey(f"{get_table_name('sn21_model_queue')}.uid", ondelete='SET NULL'), index=True, nullable=True)
    competition_id = Column(String(255), index=True)
    model_metadata = Column(JSON)
    score = Column(Float)
    scored_at = Column(DateTime, default=datetime.utcnow)
    block = Column(Integer)
    model_hash = Column(String(255))
    scorer_hotkey = Column(String(255), index=True)
    is_archived = Column(Boolean, default=False)
    metric_scores = Column(MySQLJSON, nullable=True)
    wandb_run_id = Column(String(255), nullable=True)
    wandb_run_url = Column(String(512), nullable=True)

    # Relationship to ModelQueue using dynamic table name (lambda function)
    model = relationship(
        "ModelQueue",
        back_populates="scores",
        foreign_keys=[hotkey, uid],
        primaryjoin=lambda: and_(
            ModelQueue.hotkey == ScoreHistory.hotkey,
            ModelQueue.uid == ScoreHistory.uid
        )
    )

    def __repr__(self):
        return f"<ScoreHistory(hotkey='{self.hotkey}', uid='{self.uid}', score={self.score}, scored_at={self.scored_at}, model_metadata={self.model_metadata}, is_archived={self.is_archived})>"

class ModelIdEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, ModelId):
            return {
                'namespace': obj.namespace,
                'name': obj.name,
                'epoch': obj.epoch,
                'commit': obj.commit,
                'hash': obj.hash,
                'competition_id': obj.competition_id
            }
        return super().default(obj)

class ModelQueueManager:
    def __init__(self, max_scores_per_model=5, rescore_interval_hours=24, max_retries=3, retry_delay=1):
        if Session is None:
            raise RuntimeError("Database not initialized. Call init_database() first.")

        self.session = get_session()
        self.max_scores_per_model = max_scores_per_model
        self.rescore_interval = timedelta(hours=rescore_interval_hours)
        self.max_retries = max_retries
        self.retry_delay = retry_delay

    @contextmanager
    def session_scope(self):
        """Provide a transactional scope around a series of operations."""
        session = get_session()
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    def reset_session(self):
        """Reset the session in case of connection issues."""
        try:
            self.session.close()
        except Exception:
            pass
        try:
            self.session = get_session()
        except RuntimeError as e:
            bt.logging.error(f"Failed to reset session: {str(e)}")
            raise

    def execute_with_retry(self, operation, *args, **kwargs):
        """Execute an operation with retry logic."""
        for attempt in range(self.max_retries):
            try:
                return operation(*args, **kwargs)
            except OperationalError as e:
                if "Lost connection" in str(e) and attempt < self.max_retries - 1:
                    bt.logging.warning(f"Lost connection to MySQL. Attempt {attempt + 1}/{self.max_retries}. Retrying...")
                    self.reset_session()
                    time.sleep(self.retry_delay)
                else:
                    raise
            except SQLAlchemyError:
                if attempt < self.max_retries - 1:
                    bt.logging.warning(f"Database error. Attempt {attempt + 1}/{self.max_retries}. Retrying...")
                    self.reset_session()
                    time.sleep(self.retry_delay)
                else:
                    raise

    def store_updated_model(self, uid, hotkey, model_metadata, updated):
        """
        Store or update model metadata with retry logic.

        Args:
            uid (str): Model UID
            hotkey (str): Model hotkey
            model_metadata: Model metadata object
            updated (bool): Whether this is an update

        Returns:
            bool: Success status
        """
        def _store_model():
            with self.session_scope() as session:
                try:
                    # Query existing model with lock
                    existing_model = session.query(ModelQueue).filter_by(
                        hotkey=hotkey,
                        uid=uid
                    ).with_for_update().first()

                    # Serialize metadata
                    serialized_metadata = json.dumps(model_metadata.__dict__, cls=ModelIdEncoder)

                    if existing_model:
                        if existing_model.model_metadata != serialized_metadata or existing_model.block != model_metadata.block:
                            bt.logging.debug(f"Updating existing model metadata for UID={uid}, Hotkey={hotkey}. Old metadata: {existing_model.model_metadata}, New metadata: {serialized_metadata}")
                            existing_model.model_metadata = serialized_metadata
                            existing_model.is_new = True
                            existing_model.block = model_metadata.block
                            existing_model.updated_at = datetime.utcnow()
                    else:
                        # Create new model entry
                        new_model = ModelQueue(
                            hotkey=hotkey,
                            uid=uid,
                            competition_id=model_metadata.id.competition_id,
                            model_metadata=serialized_metadata,
                            is_new=True,
                            block=model_metadata.block
                        )
                        session.add(new_model)
                        bt.logging.debug(f"Stored new model for UID={uid}, Hotkey={hotkey} in database. Is new = {updated}")

                    return True

                except Exception as e:
                    bt.logging.error(f"Error in _store_model: {str(e)}")
                    bt.logging.error(f"Model metadata: {model_metadata}")
                    raise

        try:
            return self.execute_with_retry(_store_model)
        except Exception as e:
            bt.logging.error(f"Failed to store model after {self.max_retries} attempts: {str(e)}")
            return False

    def get_next_model_to_score(self, competition_id: str):
        """
        Get next model to score with retry logic.

        The updated prioritization logic ensures:
        1. New models (highest priority)
        2. Models never scored with non-zero scores
        3a. High-scoring models not scored for over a week
        3b. Models not scored for more than 7 days (safety net for winning models)
        4. Models eligible by standard criteria (not scored in 5 days or < 5 scores)
        5. Everything else (lowest priority)

        Zero-scored models that are frequently scored are downgraded in priority
        to prevent them from consuming too many resources.
        """
        def _get_next_model():
            with self.session_scope() as session:
                try:
                    now = datetime.utcnow()

                    # ---- START: Query to find overall highest score ----
                    overall_max_score_value = session.query(func.max(ScoreHistory.score)).filter(
                        ScoreHistory.competition_id == competition_id,
                        ScoreHistory.is_archived == False,
                        ScoreHistory.score > 0  # Consider only positive scores as relevant for "highest"
                    ).scalar()

                    if overall_max_score_value is not None:
                        bt.logging.info(f"Overall highest positive score in competition '{competition_id}' is: {overall_max_score_value:.4f}")
                    else:
                        bt.logging.info(f"No positive scores found for competition '{competition_id}' to determine an overall highest score.")
                    # ---- END: Query ----

                    # Get latest score timestamp and count for each model
                    score_subquery = session.query(
                        ScoreHistory.hotkey,
                        ScoreHistory.uid,
                        func.count(ScoreHistory.id).label('score_count'),
                        func.max(ScoreHistory.scored_at).label('latest_scored_at'),  # Get the latest score timestamp
                        func.max(ScoreHistory.score).label('max_score')  # Get the maximum score
                    ).filter(
                        ScoreHistory.is_archived == False,
                        ScoreHistory.competition_id == competition_id,
                        ScoreHistory.score > 0  # Only consider non-zero scores
                    ).group_by(
                        ScoreHistory.hotkey,
                        ScoreHistory.uid
                    ).subquery()

                    # Also track all scores (including zeros) for high-frequency zero score detection
                    all_scores_subquery = session.query(
                        ScoreHistory.hotkey,
                        ScoreHistory.uid,
                        func.count(ScoreHistory.id).label('all_score_count'),
                        func.max(ScoreHistory.scored_at).label('latest_all_scored_at'),
                        func.sum(case((ScoreHistory.score > 0, 1), else_=0)).label('non_zero_count')
                    ).filter(
                        ScoreHistory.is_archived == False,
                        ScoreHistory.competition_id == competition_id
                    ).group_by(
                        ScoreHistory.hotkey,
                        ScoreHistory.uid
                    ).subquery()

                    five_days_ago = now - timedelta(days=5)
                    weekly_rescore_threshold_time = now - timedelta(days=7)  # Define a 7-day threshold

                    # Check if we have new models before proceeding
                    have_new_models = session.query(ModelQueue).filter(
                        ModelQueue.is_being_scored == False,
                        ModelQueue.competition_id == competition_id,
                        ModelQueue.is_new == True
                    ).first() is not None

                    # Check if we have never-scored models
                    never_scored_count = session.query(func.count(ModelQueue.uid)).filter(
                        ModelQueue.is_being_scored == False,
                        ModelQueue.competition_id == competition_id,
                        ~exists().where(
                            and_(
                                ScoreHistory.hotkey == ModelQueue.hotkey,
                                ScoreHistory.uid == ModelQueue.uid,
                                ScoreHistory.score > 0
                            )
                        )
                    ).scalar()

                    # If no new models and no never-scored models, prioritize high scoring models not scored recently
                    if not have_new_models:
                        # ---- START: Modified logic for dynamic high-score threshold ----
                        if overall_max_score_value is not None and overall_max_score_value > 0:  # Ensure we have a valid max score
                            dynamic_high_score_threshold = overall_max_score_value * 0.97
                            bt.logging.info(f"Using dynamic high-score threshold for competition '{competition_id}': >= {dynamic_high_score_threshold:.4f} (based on overall max of {overall_max_score_value:.4f})")

                            # First try to get a high-scoring model not scored in over a week
                            top_model = session.query(ModelQueue).join(
                                score_subquery,
                                and_(
                                    ModelQueue.hotkey == score_subquery.c.hotkey,
                                    ModelQueue.uid == score_subquery.c.uid
                                )
                            ).filter(
                                ModelQueue.is_being_scored == False,
                                ModelQueue.competition_id == competition_id,
                                score_subquery.c.latest_scored_at < weekly_rescore_threshold_time,
                                score_subquery.c.max_score >= dynamic_high_score_threshold  # Use dynamic threshold
                            ).order_by(
                                score_subquery.c.max_score.desc()  # Highest score first
                            ).with_for_update().first()

                            if top_model:
                                # Create a dictionary with the model's attributes
                                model_data = {
                                    'hotkey': top_model.hotkey,
                                    'uid': top_model.uid,
                                    'block': top_model.block,
                                    'competition_id': top_model.competition_id,
                                    'model_metadata': top_model.model_metadata,
                                    'is_new': top_model.is_new,
                                    'is_being_scored': top_model.is_being_scored,
                                    'is_being_scored_by': top_model.is_being_scored_by,
                                    'scoring_updated_at': top_model.scoring_updated_at,
                                    'updated_at': top_model.updated_at
                                }
                                bt.logging.debug(f"Found high-scoring model (dynamic threshold) to score: hotkey={model_data['hotkey']}, uid={model_data['uid']}")
                                return model_data
                        else:
                            bt.logging.info(f"Skipping dynamic high-score prioritization for competition '{competition_id}' as no overall positive max score is available or it's zero.")
                        # ---- END: Modified logic ----

                    # Otherwise, use the standard prioritization logic with the zero-score detection
                    next_model = session.query(ModelQueue).outerjoin(
                        score_subquery,
                        and_(
                            ModelQueue.hotkey == score_subquery.c.hotkey,
                            ModelQueue.uid == score_subquery.c.uid
                        )
                    ).outerjoin(
                        all_scores_subquery,
                        and_(
                            ModelQueue.hotkey == all_scores_subquery.c.hotkey,
                            ModelQueue.uid == all_scores_subquery.c.uid
                        )
                    ).filter(
                        ModelQueue.is_being_scored == False,
                        ModelQueue.competition_id == competition_id
                    ).order_by(
                        desc(ModelQueue.is_new),  # 1. Prioritize new models
                        (score_subquery.c.score_count == None).desc(),  # 2. Prioritize models never scored (non-zero)
                        case(  # 3. Prioritize models not scored for more than 7 days (safety net)
                            (and_(score_subquery.c.latest_scored_at != None, score_subquery.c.latest_scored_at < weekly_rescore_threshold_time), 0),
                            else_=1
                        ),
                        # 4. Decrease priority for models with all zero scores and frequent scoring
                        case(
                            (and_(
                                all_scores_subquery.c.all_score_count > 10,  # Has many scores
                                all_scores_subquery.c.non_zero_count == 0,  # All scores are zero
                                all_scores_subquery.c.latest_all_scored_at > five_days_ago  # Scored recently
                            ), 1),
                            else_=0
                        ),
                        case(  # 5. Prioritize models eligible by standard criteria
                            (or_(
                                score_subquery.c.latest_scored_at == None,
                                score_subquery.c.latest_scored_at <= five_days_ago,
                                score_subquery.c.score_count < 5
                            ), 0),
                            else_=1
                        ),
                        func.rand()  # 6. Random tie-breaker
                    ).with_for_update().first()

                    if next_model:
                        # Create a dictionary with the model's attributes
                        model_data = {
                            'hotkey': next_model.hotkey,
                            'uid': next_model.uid,
                            'block': next_model.block,
                            'competition_id': next_model.competition_id,
                            'model_metadata': next_model.model_metadata,
                            'is_new': next_model.is_new,
                            'is_being_scored': next_model.is_being_scored,
                            'is_being_scored_by': next_model.is_being_scored_by,
                            'scoring_updated_at': next_model.scoring_updated_at,
                            'updated_at': next_model.updated_at
                        }
                        bt.logging.debug(f"Found next model to score: hotkey={model_data['hotkey']}, uid={model_data['uid']}")
                        return model_data
                    else:
                        bt.logging.debug("No models available for scoring")
                        return None

                except Exception as e:
                    bt.logging.error(f"Error in _get_next_model: {str(e)}")
                    raise

        try:
            return self.execute_with_retry(_get_next_model)
        except Exception as e:
            bt.logging.error(f"Failed to get next model after {self.max_retries} attempts: {str(e)}")
            return None

    def mark_model_as_being_scored(self, model_hotkey, model_uid, scorer_hotkey):
        """Mark model as being scored with retry logic."""
        def _mark_model():
            with self.session_scope() as session:
                model = session.query(ModelQueue).filter_by(
                    hotkey=model_hotkey,
                    uid=model_uid
                ).with_for_update().first()

                if model and not model.is_being_scored:
                    model.is_being_scored = True
                    model.is_being_scored_by = scorer_hotkey
                    model.scoring_updated_at = datetime.utcnow()
                    return True
                return False

        try:
            return self.execute_with_retry(_mark_model)
        except Exception as e:
            bt.logging.error(f"Failed to mark model as being scored after {self.max_retries} attempts: {str(e)}")
            return False

    def submit_score(self, model_hotkey, model_uid, scorer_hotkey, model_hash, score, metric_scores):
        """Submit score with retry logic. Mark the model in queue as scored. Remove from queue."""
        def _submit_score():
            with self.session_scope() as session:
                try:
                    model = session.query(ModelQueue).filter_by(
                        hotkey=model_hotkey,
                        uid=model_uid
                    ).with_for_update().first()

                    if not model:
                        bt.logging.error(f"No model found for hotkey {model_hotkey} and uid {model_uid}")
                        return False

                    """
                    # temporarily allow scoring from any hotkey
                    new_score = ScoreHistory(
                        hotkey=model_hotkey,
                        uid=model_uid,
                        competition_id=model.competition_id,
                        score=score,
                        block=model.block,
                        model_hash=model_hash,
                        scorer_hotkey=scorer_hotkey,
                        model_metadata=model.model_metadata
                    )
                    session.add(new_score)
                    model.is_new = False
                    model.is_being_scored = False
                    model.is_being_scored_by = None
                    model.scoring_updated_at = None
                    model.updated_at = datetime.now(timezone.utc)
                    bt.logging.info(f"Successfully submitted score for model {model_hotkey} by {scorer_hotkey}")
                    return True
                    """

                    if model.is_being_scored and model.is_being_scored_by == scorer_hotkey:
                        # Extract wandb fields from metric_scores if present
                        wandb_run_id = None
                        wandb_run_url = None
                        if metric_scores and isinstance(metric_scores, dict):
                            wandb_run_id = metric_scores.get('wandb_run_id')
                            wandb_run_url = metric_scores.get('wandb_run_url')

                        new_score = ScoreHistory(
                            hotkey=model_hotkey,
                            uid=model_uid,
                            competition_id=model.competition_id,
                            score=score,
                            block=model.block,
                            model_hash=model_hash,
                            scorer_hotkey=scorer_hotkey,
                            model_metadata=model.model_metadata,
                            metric_scores=metric_scores,
                            wandb_run_id=wandb_run_id,
                            wandb_run_url=wandb_run_url
                        )
                        session.add(new_score)
                        model.is_new = False
                        model.is_being_scored = False
                        model.is_being_scored_by = None
                        model.scoring_updated_at = None
                        model.updated_at = datetime.now(timezone.utc)
                        bt.logging.info(f"Successfully submitted score for model {model_hotkey} by {scorer_hotkey}")
                        return True
                    else:
                        bt.logging.error(f"Failed to submit score for model {model_hotkey} by {scorer_hotkey}. "
                                         f"Model: {model}, is_being_scored: {model.is_being_scored}, "
                                         f"is_being_scored_by: {model.is_being_scored_by}")
                        return False

                except Exception as e:
                    bt.logging.error(f"Error in _submit_score: {str(e)}")
                    raise

        try:
            return self.execute_with_retry(_submit_score)
        except Exception as e:
            bt.logging.error(f"Failed to submit score after {self.max_retries} attempts: {str(e)}")
            return False

    def reset_stale_scoring_tasks(self, max_scoring_time_minutes=15):
        """Reset stale scoring tasks with retry logic."""
        def _reset_stale_tasks():
            with self.session_scope() as session:
                try:
                    stale_time = datetime.utcnow() - timedelta(minutes=max_scoring_time_minutes)
                    stale_models = session.query(ModelQueue).filter(
                        ModelQueue.is_being_scored == True,
                        ModelQueue.scoring_updated_at < stale_time
                    ).with_for_update().all()

                    reset_count = 0
                    for model in stale_models:
                        model.is_being_scored = False
                        model.is_being_scored_by = None
                        model.scoring_updated_at = None
                        reset_count += 1
                        bt.logging.info(f"Reset scoring task for stale model: hotkey={model.hotkey}, uid={model.uid}")

                    return reset_count

                except Exception as e:
                    bt.logging.error(f"Error in _reset_stale_tasks: {str(e)}")
                    raise

        try:
            return self.execute_with_retry(_reset_stale_tasks)
        except Exception as e:
            bt.logging.error(f"Failed to reset stale tasks after {self.max_retries} attempts: {str(e)}")
            return 0

    def get_recent_model_scores(self, scores_per_model):
        """
        Get recent scores for all models.

        Args:
            scores_per_model (int): Number of recent scores to fetch per model

        Returns:
            dict: Dictionary of model scores grouped by UID
        """
        def _get_recent_scores():
            with self.session_scope() as session:
                try:
                    # First, create a subquery that ranks scores by timestamp for each model
                    ranked_scores = (
                        session.query(
                            ScoreHistory,
                            func.row_number().over(
                                partition_by=(ScoreHistory.hotkey, ScoreHistory.uid),
                                order_by=desc(ScoreHistory.scored_at)
                            ).label('score_rank')
                        )
                        .filter(ScoreHistory.is_archived == False)
                        .filter(ScoreHistory.score != 0)
                        .subquery()
                    )

                    # Get the most recent scores for each model
                    recent_scores = session.query(ranked_scores).filter(
                        ranked_scores.c.score_rank <= scores_per_model
                    ).subquery('recent_scores')

                    # Join with ModelQueue to get additional model information
                    results = session.query(
                        ModelQueue.uid,
                        ModelQueue.hotkey,
                        ModelQueue.competition_id,
                        ModelQueue.model_metadata,
                        recent_scores.c.score,
                        recent_scores.c.scored_at,
                        recent_scores.c.block,
                        recent_scores.c.model_hash,
                        recent_scores.c.scorer_hotkey,
                        recent_scores.c.score_rank
                    ).outerjoin(
                        recent_scores,
                        and_(
                            ModelQueue.hotkey == recent_scores.c.hotkey,
                            ModelQueue.uid == recent_scores.c.uid,
                        )
                    ).order_by(
                        ModelQueue.uid,
                        ModelQueue.hotkey,
                        recent_scores.c.scored_at.desc()
                    ).all()

                    scores_by_uid = defaultdict(lambda: defaultdict(list))

                    for result in results:
                        if result.score is not None:
                            # Create a unique key for each hotkey+uid combination
                            model_key = f"{result.hotkey}_{result.uid}"

                            scores_by_uid[result.uid][model_key].append({
                                'hotkey': result.hotkey,
                                'competition_id': result.competition_id,
                                'model_metadata': result.model_metadata,
                                'score': result.score,
                                'scored_at': result.scored_at.isoformat() if result.scored_at else None,
                                'block': result.block,
                                'model_hash': result.model_hash,
                                'scorer_hotkey': result.scorer_hotkey,
                                'rank': result.score_rank
                            })
                        else:
                            # Handle models with no scores
                            model_key = f"{result.hotkey}_{result.uid}"
                            if not scores_by_uid[result.uid][model_key]:  # Only add if no scores exist
                                scores_by_uid[result.uid][model_key].append({
                                    'hotkey': result.hotkey,
                                    'competition_id': None,
                                    'model_metadata': result.model_metadata,
                                    'score': None,
                                    'scored_at': None,
                                    'block': None,
                                    'model_hash': None,
                                    'scorer_hotkey': None,
                                    'rank': None
                                })

                    # Convert defaultdict to regular dict for return
                    return {
                        uid: dict(models)
                        for uid, models in scores_by_uid.items()
                    }

                except Exception as e:
                    bt.logging.error(f"Error in _get_recent_scores: {str(e)}")
                    raise

        try:
            return self.execute_with_retry(_get_recent_scores)
        except Exception as e:
            bt.logging.error(f"Failed to get recent scores after {self.max_retries} attempts: {str(e)}")
            return {}

    def get_all_model_scores(self):
        """Get all model scores with retry logic."""
|
| 713 |
+
def _get_all_scores():
|
| 714 |
+
with self.session_scope() as session:
|
| 715 |
+
try:
|
| 716 |
+
# First, get the latest score timestamps
|
| 717 |
+
latest_scores = session.query(
|
| 718 |
+
ScoreHistory.hotkey,
|
| 719 |
+
ScoreHistory.uid,
|
| 720 |
+
func.max(ScoreHistory.scored_at).label('latest_score_time')
|
| 721 |
+
).filter(
|
| 722 |
+
ScoreHistory.is_archived == False
|
| 723 |
+
).group_by(
|
| 724 |
+
ScoreHistory.hotkey,
|
| 725 |
+
ScoreHistory.uid
|
| 726 |
+
).subquery('latest_scores')
|
| 727 |
+
|
| 728 |
+
# Get score details
|
| 729 |
+
latest_score_details = session.query(
|
| 730 |
+
ScoreHistory
|
| 731 |
+
).join(
|
| 732 |
+
latest_scores,
|
| 733 |
+
and_(
|
| 734 |
+
ScoreHistory.hotkey == latest_scores.c.hotkey,
|
| 735 |
+
ScoreHistory.uid == latest_scores.c.uid,
|
| 736 |
+
ScoreHistory.scored_at == latest_scores.c.latest_score_time
|
| 737 |
+
)
|
| 738 |
+
).subquery('latest_score_details')
|
| 739 |
+
|
| 740 |
+
# Get final results
|
| 741 |
+
results = session.query(
|
| 742 |
+
ModelQueue.uid,
|
| 743 |
+
ModelQueue.hotkey,
|
| 744 |
+
ModelQueue.competition_id,
|
| 745 |
+
latest_score_details.c.score,
|
| 746 |
+
latest_score_details.c.scored_at,
|
| 747 |
+
latest_score_details.c.block,
|
| 748 |
+
latest_score_details.c.model_hash,
|
| 749 |
+
latest_score_details.c.scorer_hotkey
|
| 750 |
+
).outerjoin(
|
| 751 |
+
latest_score_details,
|
| 752 |
+
and_(
|
| 753 |
+
ModelQueue.hotkey == latest_score_details.c.hotkey,
|
| 754 |
+
ModelQueue.uid == latest_score_details.c.uid
|
| 755 |
+
)
|
| 756 |
+
).all()
|
| 757 |
+
|
| 758 |
+
scores_by_uid = defaultdict(list)
|
| 759 |
+
for result in results:
|
| 760 |
+
if result.score is not None:
|
| 761 |
+
scores_by_uid[result.uid].append({
|
| 762 |
+
'hotkey': result.hotkey,
|
| 763 |
+
'competition_id': result.competition_id,
|
| 764 |
+
'score': result.score,
|
| 765 |
+
'scored_at': result.scored_at.isoformat() if result.scored_at else None,
|
| 766 |
+
'block': result.block,
|
| 767 |
+
'model_hash': result.model_hash,
|
| 768 |
+
})
|
| 769 |
+
else:
|
| 770 |
+
scores_by_uid[result.uid].append({
|
| 771 |
+
'hotkey': result.hotkey,
|
| 772 |
+
'competition_id': result.competition_id,
|
| 773 |
+
'score': None,
|
| 774 |
+
'scored_at': None,
|
| 775 |
+
'block': None,
|
| 776 |
+
'model_hash': None,
|
| 777 |
+
})
|
| 778 |
+
|
| 779 |
+
return dict(scores_by_uid)
|
| 780 |
+
|
| 781 |
+
except Exception as e:
|
| 782 |
+
bt.logging.error(f"Error in _get_all_scores: {str(e)}")
|
| 783 |
+
raise
|
| 784 |
+
|
| 785 |
+
try:
|
| 786 |
+
return self.execute_with_retry(_get_all_scores)
|
| 787 |
+
except Exception as e:
|
| 788 |
+
bt.logging.error(f"Failed to get all scores after {self.max_retries} attempts: {str(e)}")
|
| 789 |
+
return {}
|
| 790 |
+
|
| 791 |
+
def archive_scores_for_deregistered_models(self, registered_hotkey_uid_pairs):
|
| 792 |
+
"""Archive deregistered models with retry logic."""
|
| 793 |
+
def _archive_scores():
|
| 794 |
+
with self.session_scope() as session:
|
| 795 |
+
try:
|
| 796 |
+
all_models = session.query(
|
| 797 |
+
ModelQueue.hotkey,
|
| 798 |
+
ModelQueue.uid
|
| 799 |
+
).with_for_update().all()
|
| 800 |
+
|
| 801 |
+
deregistered_models = set(
|
| 802 |
+
(model.hotkey, model.uid) for model in all_models
|
| 803 |
+
) - set(registered_hotkey_uid_pairs)
|
| 804 |
+
|
| 805 |
+
for hotkey, uid in deregistered_models:
|
| 806 |
+
# Mark scores as archived
|
| 807 |
+
archive_result = session.query(ScoreHistory).filter_by(
|
| 808 |
+
hotkey=hotkey,
|
| 809 |
+
uid=uid,
|
| 810 |
+
is_archived=False
|
| 811 |
+
).update(
|
| 812 |
+
{"is_archived": True},
|
| 813 |
+
synchronize_session=False
|
| 814 |
+
)
|
| 815 |
+
|
| 816 |
+
# Remove from ModelQueue
|
| 817 |
+
delete_result = session.query(ModelQueue).filter_by(
|
| 818 |
+
hotkey=hotkey,
|
| 819 |
+
uid=uid
|
| 820 |
+
).delete(synchronize_session=False)
|
| 821 |
+
|
| 822 |
+
bt.logging.debug(
|
| 823 |
+
f"Processed deregistered model - Hotkey: {hotkey}, "
|
| 824 |
+
f"UID: {uid}, Archived scores: {archive_result}, "
|
| 825 |
+
f"Removed from queue: {delete_result}"
|
| 826 |
+
)
|
| 827 |
+
|
| 828 |
+
return len(deregistered_models)
|
| 829 |
+
|
| 830 |
+
except Exception as e:
|
| 831 |
+
bt.logging.error(f"Error in _archive_scores: {str(e)}")
|
| 832 |
+
raise
|
| 833 |
+
|
| 834 |
+
try:
|
| 835 |
+
result = self.execute_with_retry(_archive_scores)
|
| 836 |
+
print(f"Archived scores and removed {result} deregistered models from the queue.")
|
| 837 |
+
return result
|
| 838 |
+
except Exception as e:
|
| 839 |
+
bt.logging.error(f"Failed to archive scores after {self.max_retries} attempts: {str(e)}")
|
| 840 |
+
return 0
|
| 841 |
+
|
| 842 |
+
def close(self):
|
| 843 |
+
"""Safely close the session."""
|
| 844 |
+
try:
|
| 845 |
+
self.session.close()
|
| 846 |
+
except:
|
| 847 |
+
pass
|
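The get_recent_model_scores query above does the per-model truncation entirely in SQL: func.row_number() partitioned by (hotkey, uid) and ordered by scored_at descending assigns each score a recency rank, and filtering the subquery on score_rank <= scores_per_model keeps only the N newest rows per model in a single round trip. A minimal, self-contained sketch of the same window-function pattern against an in-memory SQLite database (the Score table and miner names are illustrative, not the subnet's real schema):

from datetime import datetime, timedelta

from sqlalchemy import Column, DateTime, Float, Integer, String, create_engine, desc, func
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()

class Score(Base):
    __tablename__ = "scores"  # illustrative table, not the real sn21 schema
    id = Column(Integer, primary_key=True)
    hotkey = Column(String(64))
    score = Column(Float)
    scored_at = Column(DateTime)

engine = create_engine("sqlite://")  # in-memory database, just for the demo
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()

# Two fake miners with five scores each, one hour apart.
now = datetime.utcnow()
for hk in ("miner_a", "miner_b"):
    for i in range(5):
        session.add(Score(hotkey=hk, score=float(i), scored_at=now - timedelta(hours=i)))
session.commit()

# Rank every score within its hotkey partition, newest first...
ranked = session.query(
    Score.hotkey,
    Score.score,
    func.row_number().over(
        partition_by=Score.hotkey,
        order_by=desc(Score.scored_at),
    ).label("score_rank"),
).subquery()

# ...then keep only the two most recent scores per miner.
for row in session.query(ranked).filter(ranked.c.score_rank <= 2):
    print(row.hotkey, row.score, row.score_rank)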
model/storage/remote_model_store.py
ADDED
@@ -0,0 +1,17 @@
+import abc
+from model.data import Model, ModelId
+from constants import CompetitionParameters
+from typing import Optional
+
+class RemoteModelStore(abc.ABC):
+    """An abstract base class for storing and retrieving a pre-trained model."""
+
+    @abc.abstractmethod
+    async def upload_model(self, model: Model, parameters: CompetitionParameters) -> ModelId:
+        """Uploads a trained model in the appropriate location based on implementation."""
+        pass
+
+    @abc.abstractmethod
+    async def download_model(self, model_id: ModelId, local_path: str, parameters: CompetitionParameters) -> Model:
+        """Retrieves a trained model from the appropriate location and stores at the given path."""
+        pass
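Concrete stores subclass this ABC and implement both coroutines. A minimal sketch of the shape of such an implementation, using a hypothetical shared-directory backend; Model.local_repo_dir and ModelId.name are assumed attributes for illustration, not part of the documented interface:

import shutil
from pathlib import Path

from constants import CompetitionParameters
from model.data import Model, ModelId
from model.storage.remote_model_store import RemoteModelStore


class DirectoryModelStore(RemoteModelStore):
    """Hypothetical store that 'uploads' models into a shared directory."""

    def __init__(self, root: str):
        self.root = Path(root)

    async def upload_model(self, model: Model, parameters: CompetitionParameters) -> ModelId:
        # model.id / ModelId.name are assumed fields, shown only to sketch the flow.
        shutil.copytree(model.local_repo_dir, self.root / model.id.name, dirs_exist_ok=True)
        return model.id

    async def download_model(self, model_id: ModelId, local_path: str, parameters: CompetitionParameters) -> Model:
        shutil.copytree(self.root / model_id.name, local_path, dirs_exist_ok=True)
        return Model(id=model_id, local_repo_dir=local_path)  # assumed constructor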
model/storage/reputation_store.py
ADDED
@@ -0,0 +1,244 @@
+from sqlalchemy import create_engine, Column, Integer, Float, String, DateTime, Boolean, JSON, func, desc, exists, ForeignKey, or_, and_, case
+from sqlalchemy.exc import OperationalError, SQLAlchemyError
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import Session, sessionmaker, aliased, relationship
+from contextlib import contextmanager
+from collections import defaultdict
+
+import time
+from datetime import datetime, timedelta, timezone
+import bittensor as bt
+from typing import Optional
+
+from vali_api.config import DBHOST, DBNAME, DBUSER, DBPASS, IS_PROD
+
+
+Base = declarative_base()
+
+# Global variables for engine and Session
+_engine: Optional[object] = None
+Session: Optional[sessionmaker] = None
+
+def init_database():
+    """
+    Initialize the database connection and create tables.
+    Must be called before using any database operations.
+    """
+    global _engine, Session
+
+    if _engine is not None:
+        bt.logging.warning("Database already initialized")
+        return
+
+    try:
+        connection_string = f'mysql://{DBUSER}:{DBPASS}@{DBHOST}/{DBNAME}'
+        _engine = create_engine(connection_string)
+        Session = sessionmaker(bind=_engine)
+
+        # Create all tables
+        Base.metadata.create_all(_engine)
+        bt.logging.info("Database initialized successfully")
+
+    except Exception as e:
+        bt.logging.error(f"Failed to initialize database: {str(e)}")
+        raise
+
+def get_session() -> Session:
+    """
+    Get a database session. Raises exception if database not initialized.
+    """
+    if Session is None:
+        raise RuntimeError("Database not initialized. Call init_database() first.")
+    return Session()
+
+def get_table_name(base_name: str) -> str:
+    """Helper function to get the correct table name with suffix if not in production."""
+    return f"{base_name}{'_test' if not IS_PROD else ''}"
+
+class BaselineScore(Base):
+    __tablename__ = get_table_name('sn21_baseline_scores')
+    id = Column(Integer, primary_key=True)
+    competition_id = Column(String(255), nullable=False, index=True)
+    score = Column(Float, nullable=False)
+    created_at = Column(DateTime, default=datetime.utcnow, index=True)
+
+    def __repr__(self):
+        return f"<BaselineScore(competition_id='{self.competition_id}', score={self.score}, created_at='{self.created_at}')>"
+
+class MinerReputation(Base):
+    __tablename__ = get_table_name('sn21_miner_reputations')
+    hotkey = Column(String(255), primary_key=True)
+    reputation = Column(Float, default=0.5, nullable=False)
+    last_updated = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
+
+    def __repr__(self):
+        return f"<MinerReputation(hotkey='{self.hotkey}', reputation={self.reputation})>"
+
+class ReputationHistory(Base):
+    __tablename__ = get_table_name('sn21_reputation_history')
+    id = Column(Integer, primary_key=True)
+    hotkey = Column(String(255), nullable=False, index=True)
+    timestamp = Column(DateTime, nullable=False, index=True)
+    reputation = Column(Float, nullable=False)
+
+    def __repr__(self):
+        return f"<ReputationHistory(hotkey='{self.hotkey}', timestamp='{self.timestamp}', reputation={self.reputation})>"
+
+class ReputationStore:
+    def __init__(self, max_retries=3, retry_delay=1):
+        # Ensure DB is initialized and get the sessionmaker
+        if Session is None:
+            raise RuntimeError("Database not initialized. Call init_database() first.")
+        self.max_retries = max_retries
+        self.retry_delay = retry_delay
+        self.session = get_session()
+
+    @contextmanager
+    def session_scope(self):
+        """Provide a transactional scope around a series of operations."""
+        session = get_session()
+        try:
+            yield session
+            session.commit()
+        except Exception as e:
+            session.rollback()
+            raise
+        finally:
+            session.close()
+
+    def reset_session(self):
+        """Reset the session in case of connection issues."""
+        try:
+            self.session.close()
+        except:
+            pass
+        try:
+            self.session = get_session()
+        except Exception as e:
+            bt.logging.error(f"Failed to reset session: {str(e)}")
+            raise
+
+    def execute_with_retry(self, operation, *args, **kwargs):
+        """Execute an operation with retry logic."""
+        for attempt in range(self.max_retries):
+            try:
+                return operation(*args, **kwargs)
+            except OperationalError as e:
+                if "Lost connection" in str(e) and attempt < self.max_retries - 1:
+                    bt.logging.warning(f"Lost connection to MySQL. Attempt {attempt + 1}/{self.max_retries}. Retrying...")
+                    self.reset_session()
+                    time.sleep(self.retry_delay)
+                else:
+                    raise
+            except SQLAlchemyError as e:
+                if attempt < self.max_retries - 1:
+                    bt.logging.warning(f"Database error. Attempt {attempt + 1}/{self.max_retries}. Retrying...")
+                    self.reset_session()
+                    time.sleep(self.retry_delay)
+                else:
+                    raise
+
+    def get_latest_baseline_score(self, competition_id):
+        def _get():
+            with self.session_scope() as session:
+                latest_baseline = (
+                    session.query(BaselineScore)
+                    .filter(BaselineScore.competition_id == competition_id)
+                    .order_by(BaselineScore.created_at.desc())
+                    .first()
+                )
+                if latest_baseline is None:
+                    return None
+                else:
+                    return {"competition_id": latest_baseline.competition_id, "score": latest_baseline.score, "created_at": latest_baseline.created_at}
+
+        return self.execute_with_retry(_get)
+
+    def get_all_reputations(self):
+        def _get():
+            with self.session_scope() as session:
+                records = session.query(MinerReputation).all()
+                return {
+                    record.hotkey: {
+                        "reputation": record.reputation,
+                        "last_updated": record.last_updated.isoformat() if record.last_updated else None
+                    }
+                    for record in records
+                }
+        return self.execute_with_retry(_get)
+
+    def get_reputation(self, hotkey):
+        def _get():
+            with self.session_scope() as session:
+                record = session.query(MinerReputation).filter(MinerReputation.hotkey == hotkey).first()
+                if not record:
+                    return None
+                return {
+                    "hotkey": record.hotkey,
+                    "reputation": record.reputation,
+                    "last_updated": record.last_updated.isoformat() if record.last_updated else None
+                }
+        return self.execute_with_retry(_get)
+
+def main():
+    """
+    Main function to demonstrate the ReputationStore's three get methods.
+    """
+    try:
+        # Initialize the database
+        print("Initializing database...")
+        init_database()
+        print("Database initialized successfully!")
+
+        # Create ReputationStore instance
+        print("\nCreating ReputationStore instance...")
+        reputation_store = ReputationStore()
+        print("ReputationStore created successfully!")
+
+        # Test 1: Get latest baseline score
+        print("\n=== Testing get_latest_baseline_score ===")
+        competition_id = "v1"
+        baseline_score = reputation_store.get_latest_baseline_score(competition_id)
+        if baseline_score:
+            print(f"Latest baseline score for competition '{competition_id}':")
+            print(f"  Score: {baseline_score['score']}")
+            print(f"  Created at: {baseline_score['created_at']}")
+        else:
+            print(f"No baseline score found for competition '{competition_id}'")
+
+        # Test 2: Get all reputations
+        print("\n=== Testing get_all_reputations ===")
+        all_reputations = reputation_store.get_all_reputations()
+        if all_reputations:
+            print(f"Found {len(all_reputations)} miner reputations:")
+            for hotkey, data in all_reputations.items():
+                print(f"  Hotkey: {hotkey}")
+                print(f"  Reputation: {data['reputation']}")
+                print(f"  Last Updated: {data['last_updated']}")
+                break
+        else:
+            print("No miner reputations found in database")
+
+        # Test 3: Get specific reputation
+        print("\n=== Testing get_reputation ===")
+        test_hotkey = "test_hotkey_123"
+        reputation = reputation_store.get_reputation(test_hotkey)
+        if reputation:
+            print(f"Reputation for hotkey '{test_hotkey}':")
+            print(f"  Hotkey: {reputation['hotkey']}")
+            print(f"  Reputation: {reputation['reputation']}")
+            print(f"  Last Updated: {reputation['last_updated']}")
+        else:
+            print(f"No reputation found for hotkey '{test_hotkey}'")
+
+        print("\n=== All tests completed successfully! ===")
+
+    except Exception as e:
+        print(f"Error occurred: {str(e)}")
+        import traceback
+        traceback.print_exc()
+
+if __name__ == "__main__":
+    main()
+
+
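Since the engine and sessionmaker live in module-level globals, init_database() must run before any ReputationStore is constructed, or __init__ raises RuntimeError. A minimal usage sketch of the intended call order (this assumes vali_api.config resolves with valid MySQL credentials in your environment; the hotkey string is a placeholder):

from model.storage.reputation_store import ReputationStore, init_database

init_database()                                  # builds the engine and creates any missing tables
store = ReputationStore(max_retries=5, retry_delay=2)

# Baseline scores are per-competition; None means no baseline has been stored yet.
baseline = store.get_latest_baseline_score("v1")
if baseline is not None:
    print(f"v1 baseline: {baseline['score']} (created {baseline['created_at']})")

# Per-miner lookups return plain dicts, or None for an unknown hotkey.
rep = store.get_reputation("some_hotkey")
print(rep["reputation"] if rep else "no reputation recorded")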
requirements.txt
ADDED
@@ -0,0 +1,13 @@
+transformers==4.48.3
+pydantic==2.11.4
+numpy==2.2.5
+torch==2.6.0
+torchaudio==2.6.0
+torchvision==0.21.0
+fastapi==0.115.12
+uvicorn==0.34.2
+librosa==0.11.0
+openai-whisper==20240930
+soundfile==0.13.1
+accelerate==0.26.0
+voxcpm
server.py
ADDED
@@ -0,0 +1,211 @@
+import os
+from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel, Field
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import traceback
+import whisper
+import librosa
+import numpy as np
+import torch
+import uvicorn
+import base64
+import io
+from voxcpm import VoxCPM
+
+
+asr_model = whisper.load_model("models/wpt/wpt.pt")
+model_name = "models/Llama-3.2-1B-Instruct"
+tok = AutoTokenizer.from_pretrained(model_name)
+lm = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    torch_dtype=torch.bfloat16,
+    device_map="cuda",
+).eval()
+
+tts = VoxCPM.from_pretrained(
+    "models/VoxCPM-0.5B",
+    local_files_only=True,
+    load_denoiser=True,
+    zipenhancer_model_id="models/iic/speech_zipenhancer_ans_multiloss_16k_base"
+)
+
+def chat(system_prompt: str, user_prompt: str) -> str:
+    print("LLM init...")
+    messages = [
+        {"role": "system", "content": system_prompt},
+        {"role": "user", "content": user_prompt},
+    ]
+    inputs = tok.apply_chat_template(
+        messages,
+        add_generation_prompt=True,
+        return_tensors="pt",
+        return_dict=True
+    )
+    input_ids = inputs["input_ids"].to(lm.device)
+    attention_mask = inputs["attention_mask"].to(lm.device)
+
+    with torch.inference_mode():
+        output_ids = lm.generate(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            pad_token_id=tok.eos_token_id,
+            max_new_tokens=2048,
+            do_sample=True,
+            temperature=0.2,
+            repetition_penalty=1.1,
+            top_k=100,
+            top_p=0.95,
+        )
+
+    answer = tok.decode(
+        output_ids[0][input_ids.shape[-1]:],
+        skip_special_tokens=True,
+        clean_up_tokenization_spaces=True,
+    )
+    print("LLM answer done.")
+    return answer.strip()
+
+def gt(audio: np.ndarray, sr: int):
+    print("Starting ASR transcription...")
+    ss = audio.squeeze().astype(np.float32)
+    if sr != 16_000:
+        # Resample the squeezed float32 signal to Whisper's expected 16 kHz.
+        ss = librosa.resample(ss, orig_sr=sr, target_sr=16_000)
+
+    result = asr_model.transcribe(ss, fp16=False, language=None)
+    transcribed_text = result["text"].strip()
+    print(f"ASR done. Transcribed: '{transcribed_text}'")
+    return transcribed_text
+
+
+def sample(rr: str) -> str:
+    if rr.strip() == "":
+        rr = "Hello "
+
+    inputs = tok(rr, return_tensors="pt").to(lm.device)
+
+    with torch.inference_mode():
+        out_ids = lm.generate(
+            **inputs,
+            max_new_tokens=2048,
+            do_sample=True,
+            temperature=0.2,
+            repetition_penalty=1.1,
+            top_k=100,
+            top_p=0.95,
+        )
+
+    return tok.decode(
+        out_ids[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True
+    )
+
+
+INITIALIZATION_STATUS = {"model_loaded": True, "error": None}
+
+
+class GenerateRequest(BaseModel):
+    audio_data: str = Field(..., description="Base64-encoded np.save payload of the input waveform")
+    sample_rate: int = Field(..., description="Sample rate of the input audio in Hz")
+
+
+class GenerateResponse(BaseModel):
+    audio_data: str = Field(..., description="Base64-encoded np.save payload of the synthesized waveform")
+
+
+app = FastAPI(title="V1", version="0.1")
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+def b64(b64: str) -> np.ndarray:
+    raw = base64.b64decode(b64)
+    return np.load(io.BytesIO(raw), allow_pickle=False)
+
+
+def ab64(arr: np.ndarray, sr: int) -> str:
+    buf = io.BytesIO()
+    resampled = librosa.resample(arr, orig_sr=16000, target_sr=sr)
+    np.save(buf, resampled.astype(np.float32))
+    return base64.b64encode(buf.getvalue()).decode()
+
+
+@app.get("/api/v1/health")
+def health_check():
+    return {
+        "status": "healthy",
+        "model_loaded": INITIALIZATION_STATUS["model_loaded"],
+        "error": INITIALIZATION_STATUS["error"],
+    }
+
+
+@app.post("/api/v1/v2v", response_model=GenerateResponse)
+def generate_audio(req: GenerateRequest):
+    print("=== V2V Request Started ===")
+    audio_np = b64(req.audio_data)
+    if audio_np.ndim == 1:
+        audio_np = audio_np.reshape(1, -1)
+    print(f"Audio shape: {audio_np.shape}, Sample rate: {req.sample_rate}")
+
+    system_prompt = (
+        "You are a helpful assistant who tries to help answer the user's question. "
+        "This is part of a voice assistant system, don't generate anything other than pure text."
+    )
+
+    try:
+        text = gt(audio_np, req.sample_rate)
+        response_text = chat(system_prompt, user_prompt=text)
+        print(f"LLM response len chars: '{len(response_text)}'")
+        print(f"LLM response: '{response_text}'")
+
+        import time
+        start_time = time.perf_counter()
+        audio_out = tts.generate(
+            text=response_text,
+            prompt_wav_path=None,
+            prompt_text=None,
+            cfg_value=2.0,
+            inference_timesteps=10,
+            normalize=True,
+            denoise=True,
+            retry_badcase=True,
+            retry_badcase_max_times=3,
+            retry_badcase_ratio_threshold=6.0,
+        )
+        print("TTS generation complete.")
+        end_time = time.perf_counter()
+        print(f"TTS generation took {end_time - start_time:.2f} seconds.")
+        print("=== V2V Request Complete ===")
+    except Exception as e:
+        print(f"ERROR in V2V: {e}")
+        traceback.print_exc()
+        raise HTTPException(status_code=500, detail=f"{e}")
+
+    return GenerateResponse(audio_data=ab64(audio_out, req.sample_rate))
+
+
+@app.post("/api/v1/v2t")
+def generate_text(req: GenerateRequest):
+    audio_np = b64(req.audio_data)
+    if audio_np.ndim == 1:
+        audio_np = audio_np.reshape(1, -1)
+
+    try:
+        text = gt(audio_np, req.sample_rate)
+        print(f"Transcribed text: {text}")
+        system_prompt = "You are a helpful assistant who tries to help answer the user's question."
+        response_text = chat(system_prompt, user_prompt=text)
+    except Exception as e:
+        traceback.print_exc()
+        raise HTTPException(status_code=500, detail=f"{e}")
+
+    return {"text": response_text}
+
+
+if __name__ == "__main__":
+    uvicorn.run("server:app", host="0.0.0.0", port=8000, reload=False)
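On the wire, audio_data is a base64-encoded np.save payload, mirroring the b64/ab64 helpers above. A minimal client sketch under that assumption (host and port follow the uvicorn.run line; requests is an assumed extra dependency, and the one-second silence input is just a stand-in):

import base64
import io

import numpy as np
import requests

SR = 16_000
audio = np.zeros(SR, dtype=np.float32)  # one second of silence as a stand-in input

# Serialize exactly the way the server-side b64() expects: np.save bytes, base64-encoded.
buf = io.BytesIO()
np.save(buf, audio)
payload = {
    "audio_data": base64.b64encode(buf.getvalue()).decode(),
    "sample_rate": SR,
}

resp = requests.post("http://localhost:8000/api/v1/v2v", json=payload, timeout=300)
resp.raise_for_status()

# The response mirrors the request encoding, resampled to the requested rate.
out = np.load(io.BytesIO(base64.b64decode(resp.json()["audio_data"])), allow_pickle=False)
print(f"Received {out.shape[0] / SR:.2f}s of synthesized speech")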
spk_001.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:79de3a5775f8880c0bf3e950b103f03b257db630224fab265a309d82753b1aa5
+size 480044