Shinichie committed (verified)
Commit 6e17fd0 · 1 parent: f54b13d

Upload folder using huggingface_hub
.gitattributes CHANGED
@@ -33,3 +33,16 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assistant_female_voice.wav filter=lfs diff=lfs merge=lfs -text
+ models/VoxCPM-0.5B/assets/voxcpm_model.png filter=lfs diff=lfs merge=lfs -text
+ models/iic/SenseVoiceSmall/fig/aed_figure.png filter=lfs diff=lfs merge=lfs -text
+ models/iic/SenseVoiceSmall/fig/asr_results.png filter=lfs diff=lfs merge=lfs -text
+ models/iic/SenseVoiceSmall/fig/inference.png filter=lfs diff=lfs merge=lfs -text
+ models/iic/SenseVoiceSmall/fig/sensevoice.png filter=lfs diff=lfs merge=lfs -text
+ models/iic/SenseVoiceSmall/fig/ser_figure.png filter=lfs diff=lfs merge=lfs -text
+ models/iic/SenseVoiceSmall/fig/ser_table.png filter=lfs diff=lfs merge=lfs -text
+ models/iic/speech_zipenhancer_ans_multiloss_16k_base/description/matrix.jpg filter=lfs diff=lfs merge=lfs -text
+ models/iic/speech_zipenhancer_ans_multiloss_16k_base/description/matrix_voicebank.jpg filter=lfs diff=lfs merge=lfs -text
+ models/iic/speech_zipenhancer_ans_multiloss_16k_base/examples/speech_with_noise1.wav filter=lfs diff=lfs merge=lfs -text
+ models/v10/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ spk_001.wav filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,47 @@
+ FROM nvidia/cuda:12.3.2-cudnn9-devel-ubuntu22.04
+
+ # Set environment variables
+ ENV PYTHONUNBUFFERED=1 \
+     DEBIAN_FRONTEND=noninteractive \
+     CUDA_HOME=/usr/local/cuda \
+     PATH=/usr/local/cuda/bin:$PATH \
+     LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH \
+     NVIDIA_VISIBLE_DEVICES=all \
+     NVIDIA_DRIVER_CAPABILITIES=compute,utility \
+     HF_HOME=/app/models \
+     TRITON_CACHE_DIR=/tmp/triton_cache \
+     XDG_CACHE_HOME=/tmp \
+     NUMBA_CACHE_DIR=/tmp/numba_cache
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     python3 \
+     python3-pip \
+     python3-dev \
+     build-essential \
+     git \
+     ffmpeg \
+     libsndfile1 \
+     curl \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Upgrade pip and install build tools
+ RUN python3 -m pip install --upgrade pip setuptools wheel uv
+
+ WORKDIR /app
+
+ # Create Numba and Triton cache directories
+ RUN mkdir -p /tmp/numba_cache /tmp/triton_cache && \
+     chown nobody:nogroup /tmp/numba_cache /tmp/triton_cache && \
+     chmod 700 /tmp/numba_cache /tmp/triton_cache
+
+ COPY requirements.txt .
+
+ # Install other requirements
+ RUN python3 -m uv pip install --no-cache-dir -r requirements.txt --prerelease=allow
+
+ COPY . .
+
+ EXPOSE 8010
+
+ # CMD ["python3", "server.py"]
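The `ENV` block above redirects several JIT caches (Triton, Numba, and the XDG default) into `/tmp` so they remain writable when the container runs as a non-root user. A minimal sketch of how a process inside the container would resolve those variables; the values mirror the Dockerfile, and `resolve_cache_dir` is a hypothetical helper, not part of the image:

```python
import os

def resolve_cache_dir(env_var: str, default: str) -> str:
    """Return the cache directory a JIT library would use: the env var if set, else its built-in default."""
    return os.environ.get(env_var, default)

# Mirror the Dockerfile's ENV block for illustration.
os.environ["TRITON_CACHE_DIR"] = "/tmp/triton_cache"
os.environ["NUMBA_CACHE_DIR"] = "/tmp/numba_cache"

print(resolve_cache_dir("TRITON_CACHE_DIR", "~/.triton/cache"))  # /tmp/triton_cache
print(resolve_cache_dir("NUMBA_CACHE_DIR", "~/.cache/numba"))    # /tmp/numba_cache
```

Without the override, both libraries would fall back to a home-directory cache, which does not exist for `nobody:nogroup`.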
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ license: mit
+ tags:
+ - any-to-any
+ - omega
+ - omegalabs
+ - bittensor
+ - agi
+ ---
+
+ This is an Any-to-Any model checkpoint for the OMEGA Labs x Bittensor Any-to-Any subnet.
+
+ Check out the [git repo](https://github.com/omegalabsinc/omegalabs-anytoany-bittensor) and find OMEGA on X: [@omegalabsai](https://x.com/omegalabsai).
assistant_female_voice.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1d712ba6de1d15d52eda96bdc043ce43eb5af4b4ac441b78b6fb0fdaf6683c7a
+ size 235244
attention_mask_research.md ADDED
@@ -0,0 +1,186 @@
+ # Attention Masks and Pad Tokens in Transformer Generation: Research Questions
+
+ ## Core Problem Statement
+
+ When running transformer models (specifically Llama-3.2-1B-Instruct) for text generation, we encounter warnings about missing attention masks and pad tokens, even for single input sequences. This leads to inconsistent generation outputs despite identical inputs.
+
+ ### Warning Messages Observed
+ ```
+ The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
+ Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
+ The attention mask is not set and cannot be inferred from input because pad token is same as eos token.
+ ```
+
+ ## Key Research Questions
+
+ ### 1. Why do single inputs require attention masks?
+ **Initial Assumption**: Single sequences without padding shouldn't need attention masks.
+ **Observed Reality**: Even single inputs show different generation outputs when attention masks are missing.
+
+ ### 2. What is the relationship between pad tokens and attention masks?
+ **Question**: How do pad_token_id and attention_mask work together in the generation process?
+
+ ### 3. Why does pad_token_id = eos_token_id cause issues?
+ **Specific Issue**: When the padding token equals the end-of-sequence token, what ambiguity does this create?
+
+ ## Code Analysis
+
+ ### Current Implementation (Problematic)
+ ```python
+ def chat_current(system_prompt: str, user_prompt: str) -> str:
+     messages = [
+         {"role": "system", "content": system_prompt},
+         {"role": "user", "content": user_prompt},
+     ]
+
+     # Only returns input_ids tensor
+     input_ids = tok.apply_chat_template(
+         messages,
+         add_generation_prompt=True,
+         return_tensors="pt"
+     ).to(lm.device)
+
+     with torch.inference_mode():
+         output_ids = lm.generate(
+             input_ids,  # Missing: attention_mask, pad_token_id
+             max_new_tokens=2048,
+             do_sample=True,
+             temperature=0.2,
+             repetition_penalty=1.1,
+             top_k=100,
+             top_p=0.95,
+         )
+
+     return tok.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
+ ```
+
+ ### Fixed Implementation
+ ```python
+ def chat_fixed(system_prompt: str, user_prompt: str) -> str:
+     messages = [
+         {"role": "system", "content": system_prompt},
+         {"role": "user", "content": user_prompt},
+     ]
+
+     # Returns dictionary with input_ids AND attention_mask
+     inputs = tok.apply_chat_template(
+         messages,
+         add_generation_prompt=True,
+         return_tensors="pt",
+         return_dict=True  # KEY CHANGE: Get both components
+     )
+
+     input_ids = inputs["input_ids"].to(lm.device)
+     attention_mask = inputs["attention_mask"].to(lm.device)
+
+     with torch.inference_mode():
+         output_ids = lm.generate(
+             input_ids=input_ids,
+             attention_mask=attention_mask,  # Explicit attention guidance
+             pad_token_id=tok.eos_token_id,  # Explicit pad token
+             max_new_tokens=2048,
+             do_sample=True,
+             temperature=0.2,
+             repetition_penalty=1.1,
+             top_k=100,
+             top_p=0.95,
+         )
+
+     return tok.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
+ ```
+
+ ### Model and Tokenizer Setup
+ ```python
+ model_name = "models/Llama-3.2-1B-Instruct"
+ tok = AutoTokenizer.from_pretrained(model_name)
+ # Critical: Set pad token if not available
+ if tok.pad_token is None:
+     tok.pad_token = tok.eos_token
+
+ lm = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     torch_dtype=torch.bfloat16,
+     device_map="cuda",
+ ).eval()
+ ```
+
+ ## Observed Behavioral Differences
+
+ ### Input Structure Analysis
+ ```python
+ # Single input contains multiple components:
+ messages = [
+     {"role": "system", "content": "You are a helpful assistant..."},
+     {"role": "user", "content": "What is the capital of France?"},
+ ]
+
+ # After apply_chat_template, becomes token sequence:
+ # [system_tokens, user_tokens, assistant_start_token]
+ ```
+
+ ## Technical Hypotheses for Investigation
+
+ ### Hypothesis 1: Internal Masking Ambiguity
+ When attention_mask is missing, the model cannot distinguish between:
+ - Real input tokens that should influence generation
+ - Structural tokens (system prompts, role markers)
+ - Token boundaries between different message roles
+
+ ### Hypothesis 2: EOS Token Dual-Purpose Confusion
+ When `pad_token_id == eos_token_id`, the model faces ambiguity:
+ ```python
+ # Same token (128001) serves dual purposes:
+ # 1. End-of-sequence marker
+ # 2. Padding token for batch processing
+ # Model cannot infer which purpose applies in context
+ ```
+
+ ### Hypothesis 3: Autoregressive Generation Context Boundary Issues
+ During generation, the model needs to know:
+ - Which input tokens provide valid context for next-token prediction
+ - Where the "prompt" ends and "generation" begins
+ - How to weight attention across different input components
+
+ ## Research Objectives
+
+ ### Primary Questions
+ 1. **Mechanism Analysis**: How exactly does a missing attention_mask affect the internal attention computation?
+ 2. **Consistency Impact**: Why do identical inputs produce different outputs without proper masking?
+ 3. **Single vs Batch Behavior**: What differences exist between single-sequence and batched-sequence processing?
+
+ ### Secondary Questions
+ 1. **Model-Specific Behavior**: Do different transformer architectures handle missing attention masks differently?
+ 2. **Generation Parameter Interaction**: How do attention mask issues interact with sampling parameters (temperature, top_p, etc.)?
+ 3. **Performance Impact**: What computational overhead does proper attention masking add?
+
+ ## Key Technical Areas for Deep Research
+
+ ### Attention Mechanism Internals
+ - How attention weights are computed with/without explicit masks
+ - Impact on multi-head attention distributions
+ - Interaction with causal masking in autoregressive models
+
+ ### Tokenizer Behavior
+ - How `apply_chat_template` constructs input sequences
+ - Default attention mask generation behavior
+ - Role of special tokens in attention computation
+
+ ### Generation Process
+ - How `model.generate()` handles missing parameters
+ - Internal assumptions and fallback behaviors
+ - Impact on sampling and beam search algorithms
+
+ ## Expected Research Outcomes
+
+ Understanding of:
+ 1. The exact mechanism causing output inconsistency
+ 2. Best practices for single-sequence generation
+ 3. The relationship between attention masking and generation quality
+ 4. Guidelines for production transformer deployment
+
+ ## References for Deep Research
+
+ - Hugging Face Transformers documentation on attention masks
+ - Technical blogs on transformer attention mechanisms (2024)
+ - Community discussions on pad token vs attention mask differences
+ - Official model documentation for Llama architecture attention handling
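The eos/pad ambiguity described in Hypothesis 2 above can be made concrete without loading a model: when `pad_token_id == eos_token_id`, any mask inference of the form "treat every pad-token position as padding" wrongly masks a genuine trailing EOS, which is why the library refuses to infer a mask and warns instead. A small pure-Python sketch (token id 128001 as in the warning; `infer_mask` is illustrative, not a transformers API):

```python
EOS_ID = 128001
PAD_ID = EOS_ID  # the problematic configuration: the pad token reuses the eos token

def infer_mask(ids: list[int], pad_id: int) -> list[int]:
    """Naive mask inference: mark every pad_id position as padding (0)."""
    return [0 if t == pad_id else 1 for t in ids]

ids = [9906, 11, 1917, EOS_ID]  # a real sequence that legitimately ends with EOS
print(infer_mask(ids, PAD_ID))  # [1, 1, 1, 0]: the genuine EOS is wrongly masked

# What apply_chat_template(..., return_dict=True) returns for one unpadded sequence:
explicit_mask = [1] * len(ids)
print(explicit_mask)  # [1, 1, 1, 1]
```

Passing the explicit all-ones mask removes the guesswork, which is exactly what the fixed implementation does.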
compare_generation.py ADDED
@@ -0,0 +1,129 @@
+
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ # Load model and tokenizer (same as server.py)
+ model_name = "models/Llama-3.2-1B-Instruct"
+ tok = AutoTokenizer.from_pretrained(model_name)
+ lm = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     torch_dtype=torch.bfloat16,
+     device_map="cuda",
+ ).eval()
+
+ def chat_current(system_prompt: str, user_prompt: str) -> str:
+     """
+     Current implementation (same as server.py) - will show warnings
+     """
+     print("🔴 Running CURRENT implementation (with warnings)...")
+
+     messages = [
+         {"role": "system", "content": system_prompt},
+         {"role": "user", "content": user_prompt},
+     ]
+
+     input_ids = tok.apply_chat_template(
+         messages,
+         add_generation_prompt=True,
+         return_tensors="pt"
+     ).to(lm.device)
+
+     with torch.inference_mode():
+         output_ids = lm.generate(
+             input_ids,  # No attention_mask, no pad_token_id
+             max_new_tokens=2048,
+             do_sample=True,
+             temperature=0.2,
+             repetition_penalty=1.1,
+             top_k=100,
+             top_p=0.95,
+         )
+
+     answer = tok.decode(
+         output_ids[0][input_ids.shape[-1]:],
+         skip_special_tokens=True,
+         clean_up_tokenization_spaces=True,
+     )
+     return answer.strip()
+
+
+ def chat_fixed(system_prompt: str, user_prompt: str) -> str:
+     """
+     Fixed implementation - proper attention mask and pad token
+     """
+     print("🟢 Running FIXED implementation (no warnings)...")
+
+     messages = [
+         {"role": "system", "content": system_prompt},
+         {"role": "user", "content": user_prompt},
+     ]
+
+     # Get both input_ids and attention_mask
+     inputs = tok.apply_chat_template(
+         messages,
+         add_generation_prompt=True,
+         return_tensors="pt",
+         return_dict=True  # Returns dict with input_ids and attention_mask
+     )
+
+     # Move to device
+     input_ids = inputs["input_ids"].to(lm.device)
+     attention_mask = inputs["attention_mask"].to(lm.device)
+
+     with torch.inference_mode():
+         output_ids = lm.generate(
+             input_ids=input_ids,
+             attention_mask=attention_mask,  # Proper attention mask
+             pad_token_id=tok.eos_token_id,  # Explicit pad token
+             max_new_tokens=2048,
+             do_sample=True,
+             temperature=0.2,
+             repetition_penalty=1.1,
+             top_k=100,
+             top_p=0.95,
+         )
+
+     answer = tok.decode(
+         output_ids[0][input_ids.shape[-1]:],
+         skip_special_tokens=True,
+         clean_up_tokenization_spaces=True,
+     )
+     return answer.strip()
+
+
+ def compare_generations():
+     """Compare both implementations"""
+     system_prompt = "You are a helpful assistant who tries to help answer the user's question."
+     user_prompt = "Create a report on anxiety in work. How do I manage time and stress effectively?"
+
+     print("=" * 60)
+     print("COMPARING GENERATION METHODS")
+     print("=" * 60)
+     print(f"System: {system_prompt}")
+     print(f"User: {user_prompt}")
+     print("=" * 60)
+
+     # Test current implementation
+     print("\n" + "=" * 60)
+     current_output = chat_current(system_prompt, user_prompt)
+     print(f"CURRENT OUTPUT:\n{current_output}")
+
+     print("\n" + "=" * 60)
+     # Test fixed implementation
+     fixed_output = chat_fixed(system_prompt, user_prompt)
+     print(f"FIXED OUTPUT:\n{fixed_output}")
+
+     print("\n" + "=" * 60)
+     print("COMPARISON:")
+     print(f"Outputs are identical: {current_output == fixed_output}")
+     print(f"Current length: {len(current_output)} chars")
+     print(f"Fixed length: {len(fixed_output)} chars")
+
+
+ if __name__ == "__main__":
+     # Set pad token for the fixed version
+     if tok.pad_token is None:
+         tok.pad_token = tok.eos_token
+
+     compare_generations()
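One caveat for the comparison script above: with `do_sample=True`, even the fixed call produces different text from run to run unless the RNG is seeded, so `current_output == fixed_output` is expected to be `False` in general regardless of masking. A sketch of the seeding idea, using Python's `random` module as a stand-in for `torch.manual_seed` (the function name is hypothetical):

```python
import random

def sample_tokens(seed: int, n: int = 5) -> list[int]:
    """Stand-in for sampled generation: draw n 'token ids' from a seeded RNG."""
    rng = random.Random(seed)
    return [rng.randrange(0, 128000) for _ in range(n)]

a = sample_tokens(42)
b = sample_tokens(42)  # same seed -> identical draws
c = sample_tokens(43)  # different seed -> almost surely different draws
print(a == b)  # True
```

To isolate the effect of the attention mask alone, one would seed identically before each call (or set `do_sample=False`) so that sampling noise is not confounded with masking differences.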
hotkey.txt ADDED
@@ -0,0 +1 @@
+ 5DJMUw8aibqe5BkyF33m87hBrWKuxkiuGPFxdVGzva6DSFLK
model/__init__.py ADDED
File without changes
model/data.py ADDED
@@ -0,0 +1,76 @@
+ from typing import Any, ClassVar, Dict, Optional, Type
+ # from transformers import PreTrainedModel, PreTrainedTokenizerBase
+ from pydantic import BaseModel, Field, PositiveInt, ConfigDict
+
+ # The maximum bytes for metadata on the chain.
+ MAX_METADATA_BYTES = 128
+ # The length, in bytes, of a git commit hash.
+ GIT_COMMIT_LENGTH = 40
+ # The length, in bytes, of a base64-encoded sha256 hash.
+ SHA256_BASE_64_LENGTH = 44
+ # The max length, in characters, of the competition id.
+ MAX_COMPETITION_ID_LENGTH = 2
+
+
+ class ModelId(BaseModel):
+     """Uniquely identifies a trained model."""
+
+     MAX_REPO_ID_LENGTH: ClassVar[int] = (
+         MAX_METADATA_BYTES
+         - GIT_COMMIT_LENGTH
+         - SHA256_BASE_64_LENGTH
+         - MAX_COMPETITION_ID_LENGTH
+         - 4  # separators
+     )
+
+     namespace: str = Field(
+         description="Namespace where the model can be found. ex. Hugging Face username/org."
+     )
+     name: str = Field(description="Name of the model.")
+
+     epoch: str = Field(description="The epoch number to submit as your checkpoint to evaluate, e.g. 10")
+
+     # When handling a model locally the commit and hash are not necessary.
+     # Commit must be filled when trying to download from a remote store.
+     commit: Optional[str] = Field(
+         description="Commit of the model. May be empty if not yet committed."
+     )
+     # Hash is filled automatically when uploading to or downloading from a remote store.
+     hash: Optional[str] = Field(description="Hash of the trained model.")
+     # Identifier for the competition.
+     competition_id: Optional[str] = Field(description="The competition id")
+
+     def to_compressed_str(self) -> str:
+         """Returns a compressed string representation."""
+         return f"{self.namespace}:{self.name}:{self.epoch}:{self.commit}:{self.hash}:{self.competition_id}"
+
+     @classmethod
+     def from_compressed_str(cls, cs: str) -> "ModelId":
+         """Returns an instance of this class from a compressed string representation."""
+         tokens = cs.split(":")
+         return cls(
+             namespace=tokens[0],
+             name=tokens[1],
+             epoch=tokens[2] if tokens[2] != "None" else None,
+             commit=tokens[3] if tokens[3] != "None" else None,
+             hash=tokens[4] if tokens[4] != "None" else None,
+             competition_id=(
+                 tokens[5] if len(tokens) >= 6 and tokens[5] != "None" else None
+             ),
+         )
+
+
+ class Model(BaseModel):
+     """Represents a pre-trained foundation model."""
+
+     model_config = ConfigDict(arbitrary_types_allowed=True)
+
+     id: ModelId = Field(description="Identifier for this model.")
+     local_repo_dir: str = Field(description="Local repository with the required files.")
+
+
+ class ModelMetadata(BaseModel):
+     id: ModelId = Field(description="Identifier for this trained model.")
+     block: PositiveInt = Field(
+         description="Block on which this model was claimed on the chain."
+     )
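The compressed format used by `ModelId` above is positional and colon-delimited, with missing fields serialized as the literal string `"None"`; a round trip therefore only works when no field itself contains a `:`. A dependency-free sketch of that round trip (a plain tuple standing in for the pydantic model, so the helper names here are illustrative):

```python
def to_compressed(namespace, name, epoch, commit, hash_, competition_id) -> str:
    # Mirrors ModelId.to_compressed_str: None fields become the string "None".
    return f"{namespace}:{name}:{epoch}:{commit}:{hash_}:{competition_id}"

def from_compressed(cs: str):
    # Mirrors ModelId.from_compressed_str: "None" tokens map back to None.
    tokens = cs.split(":")
    opt = lambda t: None if t == "None" else t
    return (tokens[0], tokens[1], opt(tokens[2]), opt(tokens[3]), opt(tokens[4]),
            opt(tokens[5]) if len(tokens) >= 6 else None)

cs = to_compressed("myorg", "mymodel", "10", None, None, "m1")
print(cs)                   # myorg:mymodel:10:None:None:m1
print(from_compressed(cs))  # ('myorg', 'mymodel', '10', None, None, 'm1')
```

This also shows why `MAX_REPO_ID_LENGTH` subtracts 4 separators plus the fixed-width commit, hash, and competition-id fields from the 128-byte chain budget: everything must fit in one such string.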
model/model_tracker.py ADDED
@@ -0,0 +1,224 @@
+ import copy
+ import datetime
+ import threading
+ from typing import Dict, List, Optional, Set
+ import pickle
+ import bittensor as bt
+ import hashlib
+
+ from model.data import ModelMetadata
+
+
+ class NoopLock:
+     def __enter__(self):
+         pass
+
+     def __exit__(self, exc_type, exc_value, traceback):
+         pass
+
+
+ class ModelTracker:
+     """Tracks the current model for each miner.
+
+     Thread safe.
+     """
+
+     def __init__(
+         self,
+         thread_safe: bool = True,
+     ):
+         # Dict from miner hotkey to model metadata.
+         self.miner_hotkey_to_model_metadata_dict: dict[str, ModelMetadata] = dict()
+         # Dict from miner hotkey to the last time it was evaluated/loaded/updated.
+         self.miner_hotkey_to_last_touched_dict: dict[str, datetime.datetime] = dict()
+         # Dict from miner hotkey to model hash.
+         self.miner_hotkey_to_model_hash_dict: dict[str, str] = dict()
+
+         # List of overwritten models that may be safe to delete if not currently in use.
+         self.old_model_metadata: list[tuple[str, ModelMetadata]] = []
+         # Set of model metadata entries that are currently in use.
+         self.model_metadata_in_use: set[tuple[str, str]] = set()
+
+         # Make this class thread safe because it will be accessed by multiple threads:
+         # one for the downloading-new-models loop and one for the validating-models loop.
+         self.lock = threading.RLock() if thread_safe else NoopLock()
+
+     def save_state(self, filepath):
+         """Save the current state to the provided filepath."""
+
+         # Open a writable binary file for pickle.
+         with self.lock:
+             with open(filepath, "wb") as f:
+                 pickle.dump(self.miner_hotkey_to_model_metadata_dict, f)
+
+     def load_state(self, filepath):
+         """Load the state from the provided filepath."""
+
+         # Open a readable binary file for pickle.
+         with open(filepath, "rb") as f:
+             self.miner_hotkey_to_model_metadata_dict = pickle.load(f)
+
+     def get_miner_hotkey_to_model_metadata_dict(self) -> Dict[str, ModelMetadata]:
+         """Returns the mapping from miner hotkey to model metadata."""
+
+         # Return a copy to ensure outside code can't modify the tracked state.
+         with self.lock:
+             return copy.deepcopy(self.miner_hotkey_to_model_metadata_dict)
+
+     def get_model_metadata_for_miner_hotkey(
+         self, hotkey: str
+     ) -> Optional[ModelMetadata]:
+         """Returns the model metadata for a given hotkey, if any."""
+
+         with self.lock:
+             if hotkey in self.miner_hotkey_to_model_metadata_dict:
+                 return self.miner_hotkey_to_model_metadata_dict[hotkey]
+             return None
+
+     def take_model_metadata_for_miner_hotkey(self, hotkey: str) -> Optional[ModelMetadata]:
+         """Returns the model metadata for a given hotkey, if any. Also marks it as in use to prevent race conditions."""
+
+         with self.lock:
+             if hotkey in self.miner_hotkey_to_model_metadata_dict:
+                 metadata = self.miner_hotkey_to_model_metadata_dict[hotkey]
+                 self.model_metadata_in_use.add((hotkey, metadata.id.hash))
+                 return metadata
+             return None
+
+     def release_all(self):
+         with self.lock:
+             self.model_metadata_in_use.clear()
+
+     def release_model_metadata_for_miner_hotkey(self, hotkey: str, metadata: ModelMetadata):
+         with self.lock:
+             pair = (hotkey, metadata.id.hash)
+             if pair not in self.model_metadata_in_use:
+                 bt.logging.error("Model metadata is not in use!")
+
+             if (hotkey, metadata) in self.old_model_metadata:
+                 bt.logging.trace(f"Releasing old model metadata for hotkey: {hotkey}")
+
+             self.model_metadata_in_use.remove(pair)
+
+     def get_miner_hotkey_to_last_touched_dict(self) -> Dict[str, datetime.datetime]:
+         """Returns the mapping from miner hotkey to the last time it was touched."""
+
+         # Return a copy to ensure outside code can't modify the tracked state.
+         with self.lock:
+             return copy.deepcopy(self.miner_hotkey_to_last_touched_dict)
+
+     def on_hotkeys_updated(self, incoming_hotkeys: Set[str]):
+         """Notifies the tracker which hotkeys are currently being tracked on the metagraph."""
+
+         with self.lock:
+             existing_hotkeys = set(self.miner_hotkey_to_model_metadata_dict.keys())
+             for hotkey in existing_hotkeys - incoming_hotkeys:
+                 del self.miner_hotkey_to_model_metadata_dict[hotkey]
+                 bt.logging.trace(f"Removed outdated hotkey metadata: {hotkey} from ModelTracker")
+
+             existing_hotkeys = set(self.miner_hotkey_to_last_touched_dict.keys())
+             for hotkey in existing_hotkeys - incoming_hotkeys:
+                 del self.miner_hotkey_to_last_touched_dict[hotkey]
+                 bt.logging.trace(f"Removed outdated hotkey timestamp: {hotkey} from ModelTracker")
+
+     def get_and_clear_old_models(self) -> list[tuple[str, ModelMetadata]]:
+         with self.lock:
+             to_delete = []
+             still_in_use = []
+             for hotkey, model in self.old_model_metadata:
+                 if (hotkey, model.id.hash) in self.model_metadata_in_use:
+                     still_in_use.append((hotkey, model))
+                 else:
+                     to_delete.append((hotkey, model))
+             self.old_model_metadata = still_in_use
+
+             return to_delete
+
+     def on_miner_model_updated(
+         self,
+         hotkey: str,
+         model_metadata: ModelMetadata,
+     ) -> None:
+         """Notifies the tracker that a miner has had their associated model updated.
+
+         Args:
+             hotkey (str): The miner's hotkey.
+             model_metadata (ModelMetadata): The latest model metadata of the miner.
+         """
+         with self.lock:
+             if hotkey in self.miner_hotkey_to_model_metadata_dict:
+                 old_metadata = self.miner_hotkey_to_model_metadata_dict[hotkey]
+                 self.old_model_metadata.append((hotkey, old_metadata))
+
+             self.miner_hotkey_to_model_metadata_dict[hotkey] = model_metadata
+             self.miner_hotkey_to_last_touched_dict[hotkey] = datetime.datetime.now()
+
+             bt.logging.trace(f"Updated Miner {hotkey}. ModelMetadata={model_metadata}.")
+
+     def touch_miner_model(self, hotkey: str) -> None:
+         """Notifies the tracker that a miner has been touched."""
+
+         now = datetime.datetime.now()
+         with self.lock:
+             self.miner_hotkey_to_last_touched_dict[hotkey] = now
+
+             bt.logging.trace(f"Touched Miner {hotkey}. datetime={now}.")
+
+     def touch_all_miner_models(self) -> None:
+         """Touch all miner models."""
+
+         now = datetime.datetime.now()
+         with self.lock:
+             for hotkey in list(self.miner_hotkey_to_model_metadata_dict.keys()):
+                 self.miner_hotkey_to_last_touched_dict[hotkey] = now
+
+             bt.logging.trace(f"Touched All Miners. datetime={now}.")
+
+     def update_model_hash(self, hotkey: str, new_model_hash: str) -> bool:
+         """
+         Update the model hash for a given hotkey.
+
+         Args:
+             hotkey (str): The miner's hotkey.
+             new_model_hash (str): The new model hash to be set.
+
+         Returns:
+             bool: Always True; the hash is set unconditionally.
+         """
+         with self.lock:
+             self.miner_hotkey_to_model_hash_dict[hotkey] = new_model_hash
+             return True
+
+     def calculate_file_hash(self, file_path: str) -> str:
+         """Calculate the SHA1 hash of a file."""
+         sha1 = hashlib.sha1()
+         with open(file_path, 'rb') as f:
+             while True:
+                 data = f.read(65536)  # Read in 64 KB chunks
+                 if not data:
+                     break
+                 sha1.update(data)
+         return sha1.hexdigest()
+
+     def is_model_unique(self, hotkey_to_check: str, block_to_check: int, model_checkpoint_path: str) -> tuple[bool, str]:
+         """Check if a model with the given checkpoint's hash is already in use. Returns (is_unique, model_hash)."""
+         # Generate the hash from model_checkpoint_path.
+         hash_to_check = self.calculate_file_hash(model_checkpoint_path)
+
+         with self.lock:
+             for hotkey, metadata in self.miner_hotkey_to_model_metadata_dict.items():
+                 if hotkey == hotkey_to_check or hotkey not in self.miner_hotkey_to_model_hash_dict:
+                     continue
+
+                 if self.miner_hotkey_to_model_hash_dict[hotkey] == hash_to_check and metadata.block < block_to_check:
+                     bt.logging.warning(
+                         f"*** Model with hash {hash_to_check} on block {block_to_check} is not unique. Already in use by {hotkey} on block {metadata.block} for model {metadata.id.namespace}/{metadata.id.name}. ***"
+                     )
+                     # Update the model hash for the hotkey.
+                     self.update_model_hash(hotkey_to_check, hash_to_check)
+                     return False, hash_to_check
+
+             # Update the model hash for the hotkey.
+             self.update_model_hash(hotkey_to_check, hash_to_check)
+             return True, hash_to_check
+
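`calculate_file_hash` in the tracker above streams the checkpoint in 64 KB chunks so a multi-gigabyte file never has to fit in memory, and the result is identical to hashing the whole file at once. A self-contained check of that equivalence, using a small temporary file as a stand-in for a real checkpoint (stdlib only; `chunked_sha1` mirrors the method without the class):

```python
import hashlib
import tempfile

def chunked_sha1(path: str, chunk_size: int = 65536) -> str:
    """SHA1 of a file, streamed chunk by chunk (mirrors ModelTracker.calculate_file_hash)."""
    sha1 = hashlib.sha1()
    with open(path, "rb") as f:
        while True:
            data = f.read(chunk_size)
            if not data:
                break
            sha1.update(data)
    return sha1.hexdigest()

payload = b"checkpoint-bytes" * 10000  # ~160 KB, spans multiple chunks
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(payload)
    path = tmp.name

print(chunked_sha1(path) == hashlib.sha1(payload).hexdigest())  # True
```

Note that SHA1 is fine for the duplicate-detection use here, but it is not collision-resistant, so it should not be relied on where an adversary controls the file contents.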
model/model_updater.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import bittensor as bt
2
+ from typing import Optional
3
+ from constants import CompetitionParameters, COMPETITION_SCHEDULE
4
+ import constants
5
+ from model.data import ModelMetadata, Model
6
+ from model.model_tracker import ModelTracker
7
+ from model.storage.local_model_store import LocalModelStore
8
+ from model.storage.model_metadata_store import ModelMetadataStore
9
+ from model.storage.remote_model_store import RemoteModelStore
10
+
11
+
12
+ class ModelUpdater:
13
+ """Checks if the currently tracked model for a hotkey matches what the miner committed to the chain."""
14
+
15
+ def __init__(
16
+ self,
17
+ metadata_store: ModelMetadataStore,
18
+ remote_store: RemoteModelStore,
19
+ local_store: LocalModelStore,
20
+ model_tracker: ModelTracker,
21
+ ):
22
+ self.metadata_store = metadata_store
23
+ self.remote_store = remote_store
24
+ self.local_store = local_store
25
+ self.model_tracker = model_tracker
26
+ self.min_block: Optional[int] = None
27
+
28
+ def set_min_block(self, val: Optional[int]):
29
+ self.min_block = val
30
+
31
+ @classmethod
32
+ def get_competition_parameters(cls, id: str) -> Optional[CompetitionParameters]:
33
+ for x in COMPETITION_SCHEDULE:
34
+ if x.competition_id == id:
35
+ return x
36
+ return None
37
+
38
+ async def _get_metadata(self, hotkey: str) -> Optional[ModelMetadata]:
39
+ """Get metadata about a model by hotkey"""
40
+ return await self.metadata_store.retrieve_model_metadata(hotkey)
41
+
42
+     async def sync_model(self, hotkey: str) -> bool:
+         """Updates the local model for a hotkey if out of sync and returns whether it was updated."""
+         # Get the metadata for the miner.
+         metadata = await self._get_metadata(hotkey)
+
+         if not metadata:
+             bt.logging.trace(
+                 f"No valid metadata found on the chain for hotkey {hotkey}"
+             )
+             return False
+
+         if self.min_block and metadata.block < self.min_block:
+             bt.logging.trace(
+                 f"Skipping model for {hotkey} since it was submitted at block {metadata.block}, which is below the minimum block {self.min_block}"
+             )
+             return False
+
+         # Backwards compatibility for models submitted before the competition id was added.
+         if metadata.id.competition_id is None:
+             metadata.id.competition_id = constants.ORIGINAL_COMPETITION_ID
+
+         parameters = ModelUpdater.get_competition_parameters(metadata.id.competition_id)
+         if not parameters:
+             bt.logging.trace(
+                 f"No competition parameters found for {metadata.id.competition_id}"
+             )
+             return False
+
+         # Check which model id the model tracker currently has for this hotkey.
+         tracker_model_metadata = self.model_tracker.get_model_metadata_for_miner_hotkey(
+             hotkey
+         )
+         if metadata == tracker_model_metadata:
+             return False
+
+         bt.logging.debug(f"Syncing model for hotkey {hotkey}")
+         # Get the local path to download to, based on the local store (top-level hotkey path).
+         path = self.local_store.get_path(hotkey)
+
+         # bt.logging.warning(f"Downloading model to {path}")
+         # # Otherwise we need to download the new model based on the metadata.
+         # model = await self.remote_store.download_model(metadata.id, path, parameters)
+         # bt.logging.warning(f"Downloaded model to {path}")
+         # # Check that the hash of the downloaded content matches.
+         # if model.id.hash != metadata.id.hash:
+         #     raise ValueError(
+         #         f"Sync for hotkey {hotkey} failed. Hash of content downloaded from hugging face does not match chain metadata. {metadata}"
+         #     )
+
+         # Update the tracker.
+         self.model_tracker.on_miner_model_updated(hotkey, metadata)
+         bt.logging.info(f"Model for hotkey {hotkey} updated to {metadata}")
+         return True
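
The method above only updates the tracker when the chain metadata passes the block gate and differs from what the tracker already holds. A minimal self-contained sketch of that "sync only if changed" pattern, using hypothetical stand-in types (`FakeMetadata`, `FakeTracker` are illustrative, not the repo's real classes):

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass(frozen=True)
class FakeMetadata:
    """Hypothetical stand-in for the chain metadata compared in sync_model."""
    model_id: str
    block: int


class FakeTracker:
    """Hypothetical in-memory tracker mirroring the model_tracker's role."""

    def __init__(self) -> None:
        self._by_hotkey: Dict[str, FakeMetadata] = {}

    def get(self, hotkey: str) -> Optional[FakeMetadata]:
        return self._by_hotkey.get(hotkey)

    def update(self, hotkey: str, metadata: FakeMetadata) -> None:
        self._by_hotkey[hotkey] = metadata


def sync_if_changed(
    tracker: FakeTracker,
    hotkey: str,
    chain_metadata: Optional[FakeMetadata],
    min_block: int = 0,
) -> bool:
    """Returns True only when new metadata replaces what the tracker held."""
    if chain_metadata is None:
        return False
    if min_block and chain_metadata.block < min_block:
        return False
    if tracker.get(hotkey) == chain_metadata:
        # Already in sync; nothing to do.
        return False
    tracker.update(hotkey, chain_metadata)
    return True
```

Because the metadata type is a frozen dataclass, equality is by value, so a repeated sync with identical chain state is a no-op.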
model/storage/__init__.py ADDED
File without changes
model/storage/chain/chain_model_metadata_store.py ADDED
@@ -0,0 +1,177 @@
+ import asyncio
+ import functools
+ import bittensor as bt
+ import os
+ from model.data import ModelId, ModelMetadata
+ import constants
+ from model.storage.model_metadata_store import ModelMetadataStore
+ from typing import Optional
+
+ from utilities import utils
+
+
+ class ChainModelMetadataStore(ModelMetadataStore):
+     """Chain-based implementation for storing and retrieving metadata about a model."""
+
+     def __init__(
+         self,
+         subtensor: bt.subtensor,
+         subnet_uid: int,
+         wallet: Optional[bt.wallet] = None,
+     ):
+         self.subtensor = subtensor
+         self.wallet = (
+             wallet  # Wallet is only needed to write to the chain, not to read.
+         )
+         self.subnet_uid = subnet_uid
+
+         # This is a hacky way to prime the get_metadata function.
+         SN21_OWNER_KEY = "5GsHpHeCGhxstoEEZTR64VUashnDP4n7ir7LbNdRfXpkMU7R"
+         metadata = bt.extrinsics_subpackage.serving.get_metadata(self.subtensor, self.subnet_uid, SN21_OWNER_KEY)
+         bt.logging.debug(f"Primed get_metadata call successfully: metadata={metadata} (ok to be None)")
+
+     async def store_model_metadata(self, hotkey: str, model_id: ModelId):
+         """Stores model metadata on this subnet for a specific wallet."""
+         if self.wallet is None:
+             raise ValueError("No wallet available to write to the chain.")
+
+         # Wrap calls to the subtensor in a subprocess with a timeout to handle potential hangs.
+         # partial = functools.partial(
+         #     self.subtensor.commit,
+         #     self.wallet,
+         #     self.subnet_uid,
+         #     model_id.to_compressed_str(),
+         # )
+         # utils.run_in_subprocess(partial, 60)
+         commit = self.subtensor.commit(self.wallet, self.subnet_uid, model_id.to_compressed_str())
+         print(f"success commit: {commit}")
+         if not commit:
+             raise ValueError("Failed to commit model metadata to the chain.")
+         return commit
+
+     async def retrieve_model_metadata(self, hotkey: str) -> Optional[ModelMetadata]:
+         """Retrieves model metadata on this subnet for a specific hotkey."""
+
+         # Wrap calls to the subtensor in a subprocess with a timeout to handle potential hangs.
+         partial = functools.partial(
+             bt.extrinsics_subpackage.serving.get_metadata, self.subtensor, self.subnet_uid, hotkey
+         )
+
+         metadata = utils.run_in_subprocess(partial, 60)
+
+         if not metadata:
+             return None
+
+         commitment = metadata["info"]["fields"][0][0]
+         # The commitment payload is keyed by its type; the value holds a tuple of character codes.
+         hex_data_tuple = commitment[list(commitment.keys())[0]][0]
+         chain_str = ''.join(chr(num) for num in hex_data_tuple)
+
+         model_id = None
+         try:
+             model_id = ModelId.from_compressed_str(chain_str)
+         except Exception:
+             # If the metadata format on the chain is not correct, return None.
+             bt.logging.trace(
+                 f"Failed to parse the metadata on the chain for hotkey {hotkey}."
+             )
+             return None
+
+         model_metadata = ModelMetadata(id=model_id, block=metadata["block"])
+         return model_metadata
+
+
+ # Can only commit data every ~20 minutes.
+ async def test_store_model_metadata():
+     """Verifies that the ChainModelMetadataStore can store data on the chain."""
+     model_id = ModelId(
+         namespace="TestPath", name="TestModel", hash="TestHash1", commit="1.0"
+     )
+
+     # Use a different subnet that does not leverage chain storage to avoid conflicts.
+     # TODO: switch to a mocked version when it supports commits.
+     subtensor = bt.subtensor()
+
+     # Uses the .env-configured wallet/hotkey/uid for the test.
+     coldkey = os.getenv("TEST_COLDKEY")
+     hotkey = os.getenv("TEST_HOTKEY")
+     net_uid = int(os.getenv("TEST_SUBNET_UID"))
+
+     wallet = bt.wallet(name=coldkey, hotkey=hotkey)
+
+     metadata_store = ChainModelMetadataStore(
+         subtensor=subtensor, wallet=wallet, subnet_uid=net_uid
+     )
+
+     # Store the metadata on chain.
+     await metadata_store.store_model_metadata(hotkey=hotkey, model_id=model_id)
+
+     print(f"Finished storing {model_id} on the chain.")
+
+
+ async def test_retrieve_model_metadata():
+     """Verifies that the ChainModelMetadataStore can retrieve data from the chain."""
+     expected_model_id = ModelId(
+         namespace="TestPath", name="TestModel", hash="TestHash1", commit="1.0"
+     )
+
+     # Use a different subnet that does not leverage chain storage to avoid conflicts.
+     # TODO: switch to a mocked version when it supports commits.
+     subtensor = bt.subtensor()
+
+     # Uses the .env-configured hotkey/uid for the test.
+     net_uid = int(os.getenv("TEST_SUBNET_UID"))
+     hotkey_address = os.getenv("TEST_HOTKEY_ADDRESS")
+
+     # Does not require a wallet for retrieving data.
+     metadata_store = ChainModelMetadataStore(
+         subtensor=subtensor, wallet=None, subnet_uid=net_uid
+     )
+
+     # Retrieve the metadata from the chain.
+     model_metadata = await metadata_store.retrieve_model_metadata(hotkey_address)
+
+     print(f"Expecting matching model id: {expected_model_id == model_metadata.id}")
+
+
+ # Can only commit data every ~20 minutes.
+ async def test_roundtrip_model_metadata():
+     """Verifies that the ChainModelMetadataStore can roundtrip data on the chain."""
+     model_id = ModelId(
+         namespace="TestPath", name="TestModel", hash="TestHash1", commit="1.0"
+     )
+
+     # Use a different subnet that does not leverage chain storage to avoid conflicts.
+     # TODO: switch to a mocked version when it supports commits.
+     subtensor = bt.subtensor()
+
+     # Uses the .env-configured wallet/hotkey/uid for the test.
+     coldkey = os.getenv("TEST_COLDKEY")
+     hotkey = os.getenv("TEST_HOTKEY")
+     net_uid = int(os.getenv("TEST_SUBNET_UID"))
+
+     wallet = bt.wallet(name=coldkey, hotkey=hotkey)
+
+     metadata_store = ChainModelMetadataStore(
+         subtensor=subtensor, wallet=wallet, subnet_uid=net_uid
+     )
+
+     # Store the metadata on chain.
+     await metadata_store.store_model_metadata(hotkey=hotkey, model_id=model_id)
+
+     # May need to use the underlying publish_metadata function with wait_for_inclusion=True to pass here.
+     # Otherwise it defaults to False, and we only wait for finalization, not necessarily inclusion.
+
+     # Retrieve the metadata from the chain.
+     model_metadata = await metadata_store.retrieve_model_metadata(hotkey)
+
+     print(f"Expecting matching metadata: {model_id == model_metadata.id}")
+
+
+ if __name__ == "__main__":
+     # Can only commit data every ~20 minutes.
+     # asyncio.run(test_roundtrip_model_metadata())
+     # asyncio.run(test_store_model_metadata())
+     asyncio.run(test_retrieve_model_metadata())
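
`retrieve_model_metadata` decodes the on-chain commitment by taking the first field's tuple of character codes and joining them back into the compressed model-id string. A standalone sketch of just that decode step; the `"Raw11"` payload shape below is a hypothetical example of the keyed-by-type structure, not pulled from a real chain response:

```python
def decode_commitment(commitment: dict) -> str:
    """Joins the tuple of character codes stored under the commitment's
    first key (its type tag) back into the committed string."""
    first_key = list(commitment.keys())[0]
    data_tuple = commitment[first_key][0]
    return ''.join(chr(num) for num in data_tuple)


# Hypothetical payload mimicking the {"Raw<N>": ((codes, ...),)} shape.
payload = {"Raw11": (tuple(ord(c) for c in "ns:name:1.0"),)}
```

The decoded string is then handed to `ModelId.from_compressed_str`, which is where any malformed commitment is caught.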
model/storage/disk/__init__.py ADDED
File without changes
model/storage/disk/disk_model_store.py ADDED
@@ -0,0 +1,124 @@
+ import bittensor as bt
+ import datetime
+ import os
+ from typing import Dict
+ from constants import CompetitionParameters
+ from model.data import Model, ModelId
+ from model.storage.disk import utils
+ from model.storage.local_model_store import LocalModelStore
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from pathlib import Path
+
+
+ class DiskModelStore(LocalModelStore):
+     """Local storage based implementation for storing and retrieving a model on disk."""
+
+     def __init__(self, base_dir: str):
+         self.base_dir = base_dir
+         os.makedirs(utils.get_local_miners_dir(base_dir), exist_ok=True)
+
+     def get_path(self, hotkey: str) -> str:
+         """Returns the path to where this store would locate this hotkey."""
+         return utils.get_local_miner_dir(self.base_dir, hotkey)
+
+     def store_model(self, hotkey: str, model: Model, hf_model: AutoModelForCausalLM, hf_tokenizer: AutoTokenizer) -> ModelId:
+         """Stores a trained model locally."""
+         # Get the path where the model should be stored.
+         model_dir = os.path.join(self.get_path(hotkey), model.id.name)
+         hf_model.save_pretrained(model_dir)
+         hf_tokenizer.save_pretrained(model_dir)
+         model.local_repo_dir = model_dir
+
+         return model.id
+
+     def retrieve_model(
+         self, hotkey: str, model_id: ModelId, model_parameters: CompetitionParameters
+     ) -> Model:
+         """Retrieves a trained model locally."""
+         # Get the path where the model should be stored.
+         model_dir = os.path.join(self.get_path(hotkey), model_id.name)
+         return Model(id=model_id, local_repo_dir=model_dir)
+
+     def delete_unreferenced_models(
+         self,
+         valid_models_by_hotkey: Dict[str, ModelId],
+         model_touched_by_hotkey: Dict[str, datetime.datetime],
+         grace_period_seconds: int,
+     ):
+         """Checks across all of local storage and deletes unreferenced models that are out of the grace period."""
+         # TODO: this method is not up to date yet.
+         raise NotImplementedError("This method is not implemented yet.")
+         # Expected directory structure is as follows:
+         # self.base_dir/models/hotkey/models--namespace--name/snapshots/commit/config.json + other files.
+
+         # Create a set of valid model paths up to where we expect to see the actual files.
+         valid_model_paths = set()
+         for hotkey, model_id in valid_models_by_hotkey.items():
+             valid_model_paths.add(
+                 utils.get_local_model_snapshot_dir(self.base_dir, hotkey, model_id)
+             )
+
+         # For each hotkey path on disk, go one level deep.
+         miners_dir = Path(utils.get_local_miners_dir(self.base_dir))
+         hotkey_subfolder_names = [d.name for d in miners_dir.iterdir() if d.is_dir()]
+
+         for hotkey in hotkey_subfolder_names:
+             # Reconstruct the path from the hotkey.
+             hotkey_path = utils.get_local_miner_dir(self.base_dir, hotkey)
+
+             # If it is not in valid_models_by_hotkey and out of the grace period, remove it.
+             if hotkey not in valid_models_by_hotkey:
+                 deleted_hotkey = utils.remove_dir_out_of_grace(
+                     hotkey_path, grace_period_seconds
+                 )
+                 if deleted_hotkey:
+                     bt.logging.trace(
+                         f"Removed directory for unreferenced hotkey: {hotkey}."
+                     )
+             else:
+                 # Check all the models--namespace--name subfolder paths.
+                 hotkey_dir = Path(hotkey_path)
+                 model_subfolder_paths = [
+                     str(d) for d in hotkey_dir.iterdir() if d.is_dir()
+                 ]
+
+                 # Check all the snapshots subfolder paths.
+                 for model_path in model_subfolder_paths:
+                     model_dir = Path(model_path)
+                     snapshot_subfolder_paths = [
+                         str(d) for d in model_dir.iterdir() if d.is_dir()
+                     ]
+
+                     # Check all the commit paths.
+                     for snapshot_path in snapshot_subfolder_paths:
+                         snapshot_dir = Path(snapshot_path)
+                         commit_subfolder_paths = [
+                             str(d) for d in snapshot_dir.iterdir() if d.is_dir()
+                         ]
+
+                         # Reached the end. Check all the actual commit subfolders for the files.
+                         for commit_path in commit_subfolder_paths:
+                             if commit_path not in valid_model_paths:
+                                 deleted_model = utils.remove_dir_out_of_grace(
+                                     commit_path, grace_period_seconds
+                                 )
+                                 if deleted_model:
+                                     bt.logging.trace(
+                                         f"Removing directory for unreferenced model at: {commit_path}."
+                                     )
+                             else:
+                                 last_touched = model_touched_by_hotkey.get(hotkey)
+                                 if last_touched is not None:
+                                     deleted_model = (
+                                         utils.remove_dir_out_of_grace_by_datetime(
+                                             commit_path,
+                                             grace_period_seconds,
+                                             last_touched,
+                                         )
+                                     )
+                                     if deleted_model:
+                                         bt.logging.trace(
+                                             f"Removing directory for stale model at: {commit_path}."
+                                         )
model/storage/disk/utils.py ADDED
@@ -0,0 +1,121 @@
+ import base64
+ import datetime
+ import hashlib
+ import os
+ import shutil
+ from model.data import ModelId
+
+
+ def get_local_miners_dir(base_dir: str) -> str:
+     return os.path.join(base_dir, "models")
+
+
+ def get_local_miner_dir(base_dir: str, hotkey: str) -> str:
+     return os.path.join(get_local_miners_dir(base_dir), hotkey)
+
+
+ # Hugging Face stores models under models--namespace--name/snapshots/commit when downloading.
+ def get_local_model_dir(base_dir: str, hotkey: str, model_id: ModelId) -> str:
+     return os.path.join(
+         get_local_miner_dir(base_dir, hotkey),
+         "models" + "--" + model_id.namespace + "--" + model_id.name,
+     )
+
+
+ def get_local_model_snapshot_dir(base_dir: str, hotkey: str, model_id: ModelId) -> str:
+     return os.path.join(
+         get_local_model_dir(base_dir, hotkey, model_id),
+         "snapshots",
+         model_id.commit,
+     )
+
+
+ def get_hf_download_path(local_path: str, model_id: ModelId) -> str:
+     return os.path.join(
+         local_path,
+         "models" + "--" + model_id.namespace + "--" + model_id.name,
+         "snapshots",
+         model_id.commit,
+     )
+
+
+ def get_newest_datetime_under_path(path: str) -> datetime.datetime:
+     newest_filetime = None
+
+     # Check whether any file at any level was modified more recently than the current newest.
+     for cur_path, dirnames, filenames in os.walk(path):
+         for filename in filenames:
+             file_path = os.path.join(cur_path, filename)
+             try:
+                 mod_time = os.stat(file_path).st_mtime
+                 if newest_filetime is None or mod_time > newest_filetime:
+                     newest_filetime = mod_time
+             except OSError:
+                 pass
+
+     if newest_filetime is None:
+         # No readable files under the path; treat it as never out of grace.
+         return datetime.datetime.max
+
+     return datetime.datetime.fromtimestamp(newest_filetime)
+
+
+ def remove_dir_out_of_grace_by_datetime(path: str, grace_period_seconds: int, last_modified: datetime.datetime) -> bool:
+     """Removes a dir if the last modified time is outside the grace period. Returns whether it was deleted."""
+     grace = datetime.timedelta(seconds=grace_period_seconds)
+
+     if last_modified < datetime.datetime.now() - grace:
+         shutil.rmtree(path=path, ignore_errors=True)
+         return True
+
+     return False
+
+
+ def remove_dir_out_of_grace(path: str, grace_period_seconds: int) -> bool:
+     """Removes a dir if the last modified time is outside the grace period. Returns whether it was deleted."""
+     last_modified = get_newest_datetime_under_path(path)
+     return remove_dir_out_of_grace_by_datetime(path, grace_period_seconds, last_modified)
+
+
+ def realize_symlinks_in_directory(path: str) -> int:
+     """Realizes all symlinks in the given directory, moving each linked file to the link's location. Returns the count moved."""
+     realized_symlinks = 0
+
+     for cur_path, dirnames, filenames in os.walk(path):
+         for filename in filenames:
+             file_path = os.path.abspath(os.path.join(cur_path, filename))
+             # Get the path, resolving symlinks if encountered.
+             real_path = os.path.realpath(file_path)
+             # If different, then move the real file into place.
+             if file_path != real_path:
+                 realized_symlinks += 1
+                 shutil.move(real_path, file_path)
+
+     return realized_symlinks
+
+
+ def get_hash_of_file(path: str) -> str:
+     blocksize = 64 * 1024
+     file_hash = hashlib.sha256()
+     with open(path, "rb") as fp:
+         while True:
+             data = fp.read(blocksize)
+             if not data:
+                 break
+             file_hash.update(data)
+     return base64.b64encode(file_hash.digest()).decode("utf-8")
+
+
+ def get_hash_of_directory(path: str) -> str:
+     dir_hash = hashlib.sha256()
+
+     # Recursively walk everything under the directory for files.
+     for cur_path, dirnames, filenames in os.walk(path):
+         # Ensure we walk future directories in a consistent order.
+         dirnames.sort()
+         # Ensure we walk files in a consistent order.
+         for filename in sorted(filenames):
+             file_path = os.path.join(cur_path, filename)
+             file_hash = get_hash_of_file(file_path)
+             dir_hash.update(file_hash.encode())
+
+     return base64.b64encode(dir_hash.digest()).decode("utf-8")
model/storage/eval_leaderboard.py ADDED
@@ -0,0 +1,185 @@
+ from sqlalchemy import create_engine, Column, Integer, Float, String, DateTime, ForeignKey
+ from sqlalchemy.ext.declarative import declarative_base
+ from sqlalchemy.orm import Session, sessionmaker, relationship
+ from sqlalchemy.exc import OperationalError
+ from contextlib import contextmanager
+ from datetime import datetime
+ import bittensor as bt
+ from typing import Optional, Dict, List
+ import time
+ from vali_api.config import DBHOST, DBNAME, DBUSER, DBPASS
+
+ Base = declarative_base()
+
+ # Global variables for the engine and Session factory.
+ _engine: Optional[object] = None
+ Session: Optional[sessionmaker] = None
+
+
+ def init_database():
+     """Initializes the database connection and creates tables."""
+     global _engine, Session
+
+     if _engine is not None:
+         bt.logging.warning("Database already initialized")
+         return
+
+     try:
+         connection_string = f'mysql://{DBUSER}:{DBPASS}@{DBHOST}/{DBNAME}'
+         _engine = create_engine(connection_string)
+         Session = sessionmaker(bind=_engine)
+         Base.metadata.create_all(_engine)
+         bt.logging.info("Database initialized successfully")
+     except Exception as e:
+         bt.logging.error(f"Failed to initialize database: {str(e)}")
+         raise
+
+
+ def get_session() -> Session:
+     """Gets a database session."""
+     if Session is None:
+         raise RuntimeError("Database not initialized. Call init_database() first.")
+     return Session()
+
+
+ class EvaluationModel(Base):
+     __tablename__ = 'sn21_evals_test'
+
+     eval_id = Column(Integer, primary_key=True)
+     miner_hotkey = Column(String(255))
+     miner_uid = Column(Integer)
+     model_name = Column(String(255))
+     model_type = Column(String(255))
+     eval_date = Column(DateTime)
+     competition_id = Column(String(10))
+
+     results = relationship("EvaluationResult", back_populates="evaluation")
+
+
+ class EvaluationResult(Base):
+     __tablename__ = 'sn21_eval_results_test'
+
+     eval_result_id = Column(Integer, primary_key=True)
+     eval_id = Column(Integer, ForeignKey('sn21_evals_test.eval_id'))
+     task = Column(String(255))
+     result_name = Column(String(255))
+     result = Column(Float)
+     competition_id = Column(String(10))
+     deleted_at = Column(DateTime)
+
+     evaluation = relationship("EvaluationModel", back_populates="results")
+
+
+ class EvalLeaderboardManager:
+     def __init__(self, max_retries=3, retry_delay=1):
+         if Session is None:
+             raise RuntimeError("Database not initialized. Call init_database() first.")
+
+         self.session = get_session()
+         self.max_retries = max_retries
+         self.retry_delay = retry_delay
+
+     @contextmanager
+     def session_scope(self):
+         """Provides a transactional scope around a series of operations."""
+         session = get_session()
+         try:
+             yield session
+             session.commit()
+         except Exception:
+             session.rollback()
+             raise
+         finally:
+             session.close()
+
+     def execute_with_retry(self, operation, *args, **kwargs):
+         """Executes a database operation with retry logic."""
+         for attempt in range(self.max_retries):
+             try:
+                 return operation(*args, **kwargs)
+             except OperationalError:
+                 if attempt < self.max_retries - 1:
+                     bt.logging.warning(f"Database error. Attempt {attempt + 1}/{self.max_retries}. Retrying...")
+                     time.sleep(self.retry_delay)
+                 else:
+                     raise
+             except Exception as e:
+                 bt.logging.error(f"Error executing operation: {str(e)}")
+                 raise
+
+     def get_metrics_timeseries(self) -> Dict[str, List[Dict]]:
+         """
+         Gets time series data for all turn metrics.
+         Returns data in the format {metric_name: [{date: "", models: [{modelType: "", modelName: "", score: 123}, ...]}, ...]}.
+         Groups models by model_type and includes model names in the data.
+         """
+         def _get_timeseries():
+             with self.session_scope() as session:
+                 evals = session.query(EvaluationModel).order_by(EvaluationModel.eval_date).all()
+
+                 metrics = set()
+                 # Get all non-deleted results that aren't the excluded metric.
+                 sample_results = session.query(EvaluationResult).filter(
+                     EvaluationResult.result_name != "exact_match_stderr,flexible-extract",
+                     EvaluationResult.deleted_at.is_(None)  # SQLAlchemy's proper way to check for NULL.
+                 ).all()
+                 for result in sample_results:
+                     metrics.add(result.result_name)
+
+                 timeseries_data = {metric: {} for metric in metrics}
+
+                 for eval_ in evals:
+                     date_str = eval_.eval_date.strftime('%Y-%m-%d')
+                     model_type = eval_.model_type
+                     model_name = eval_.model_name
+                     competition_id = eval_.competition_id
+
+                     # Only process non-deleted results.
+                     for result in [r for r in eval_.results if r.deleted_at is None]:
+                         metric_name = result.result_name
+                         if metric_name in metrics:
+                             if date_str not in timeseries_data[metric_name]:
+                                 timeseries_data[metric_name][date_str] = {
+                                     'date': date_str,
+                                     'models': []
+                                 }
+
+                             # Check if we already have an entry for this model_type.
+                             existing_model = next(
+                                 (m for m in timeseries_data[metric_name][date_str]['models']
+                                  if m['modelType'] == model_type),
+                                 None
+                             )
+
+                             if existing_model:
+                                 # Update the existing entry.
+                                 existing_model['score'] = result.result
+                                 if model_name not in existing_model['modelName']:
+                                     existing_model['modelName'] = model_name
+                             else:
+                                 # Add a new entry.
+                                 timeseries_data[metric_name][date_str]['models'].append({
+                                     'modelType': model_type,
+                                     'modelName': model_name,
+                                     'score': result.result,
+                                     'competition_id': competition_id,
+                                     'task': result.task
+                                 })
+
+                 return {
+                     metric: [
+                         data_point for data_point in sorted(data.values(), key=lambda x: x['date'])
+                         if data_point['models']  # Filter out entries with an empty models list.
+                     ]
+                     for metric, data in timeseries_data.items()
+                 }
+
+         try:
+             return self.execute_with_retry(_get_timeseries)
+         except Exception as e:
+             bt.logging.error(f"Failed to get metrics timeseries: {str(e)}")
+             return {}
+
+     def close(self):
+         """Safely closes the session."""
+         try:
+             self.session.close()
+         except Exception:
+             pass
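
`execute_with_retry` above retries only on `OperationalError` (a transient database failure) and lets every other exception propagate immediately. The same pattern, sketched independently of SQLAlchemy with a hypothetical `TransientError` standing in for the driver's transient exception class:

```python
import time


class TransientError(Exception):
    """Hypothetical stand-in for sqlalchemy.exc.OperationalError."""


def execute_with_retry(operation, max_retries=3, retry_delay=0.0):
    """Retries `operation` on TransientError; other exceptions propagate at once."""
    for attempt in range(max_retries):
        try:
            return operation()
        except TransientError:
            if attempt < max_retries - 1:
                time.sleep(retry_delay)
            else:
                # Retries exhausted: surface the transient error to the caller.
                raise


class Flaky:
    """Test helper that fails a fixed number of times before succeeding."""

    def __init__(self, failures: int):
        self.remaining = failures

    def __call__(self):
        if self.remaining > 0:
            self.remaining -= 1
            raise TransientError("transient")
        return "ok"
```

With `max_retries=3`, an operation that fails twice still succeeds on the third attempt, while one that fails three times raises.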
model/storage/hugging_face/__init__.py ADDED
File without changes
model/storage/hugging_face/hugging_face_model_store.py ADDED
@@ -0,0 +1,175 @@
+ import os
+ from omegaconf import OmegaConf
+ from huggingface_hub import HfApi
+ from model.data import Model, ModelId
+ from model.storage.disk import utils
+ from constants import CompetitionParameters, MAX_HUGGING_FACE_BYTES
+
+ from model.storage.remote_model_store import RemoteModelStore
+
+
+ MODEL_FILE_PT = "meta_model_{epoch}.pt"
+ ADAPTER_FILE_PT = "adapter_{epoch}.pt"
+ CONFIG_FILE = "training_config.yml"
+ CONFIG_FILE_MOSHI = "config.yaml"
+ README_FILE = "README.md"
+ HOTKEY_FILE = "hotkey.txt"
+
+ ### MOSHI ###
+ LM_FILE_PT_MOSHI = "model.safetensors"
+ MIMI_FILE_PT_MOSHI = "tokenizer-e351c8d8-checkpoint125.safetensors"
+ TOKENIZER_FILE_MOSHI = "tokenizer_spm_32k_3.model"
+
+
+ def check_config(ckpt_dir):
+     config_file = os.path.join(ckpt_dir, CONFIG_FILE)
+     cfg = OmegaConf.load(config_file)
+     if cfg.model.use_clip:
+         raise ValueError("Cannot upload checkpoints with CLIP embeddings")
+
+
+ def get_required_files(epoch: int, model_type: str):
+     if model_type == "o1":
+         return [
+             MODEL_FILE_PT.format(epoch=epoch),
+             CONFIG_FILE,
+         ]
+     elif model_type == "v1":
+         return [
+             LM_FILE_PT_MOSHI,
+             MIMI_FILE_PT_MOSHI,
+             TOKENIZER_FILE_MOSHI,
+             CONFIG_FILE_MOSHI,
+         ]
+
+
+ def export_readme(ckpt_dir: str):
+     readme_file = os.path.join(ckpt_dir, README_FILE)
+     with open(readme_file, "w") as f:
+         f.write(
+             """---
+ license: mit
+ tags:
+ - any-to-any
+ - omega
+ - omegalabs
+ - bittensor
+ - agi
+ ---
+
+ This is an Any-to-Any model checkpoint for the OMEGA Labs x Bittensor Any-to-Any subnet.
+
+ Check out the [git repo](https://github.com/omegalabsinc/omegalabs-anytoany-bittensor) and find OMEGA on X: [@omegalabsai](https://x.com/omegalabsai).
+ """
+         )
+
+
+ def export_hotkey(ckpt_dir: str, hotkey: str):
+     hotkey_file = os.path.join(ckpt_dir, HOTKEY_FILE)
+     with open(hotkey_file, "w") as f:
+         f.write(hotkey)
+
+
+ class HuggingFaceModelStore(RemoteModelStore):
+     """Hugging Face based implementation for storing and retrieving a model."""
+
+     @classmethod
+     def assert_access_token_exists(cls) -> str:
+         """Asserts that the access token exists and returns it."""
+         if not os.getenv("HF_ACCESS_TOKEN"):
+             raise ValueError("No Hugging Face access token found to write to the hub.")
+         return os.getenv("HF_ACCESS_TOKEN")
+
+     async def upload_model(
+         self,
+         model: Model,
+         competition_parameters: CompetitionParameters,
+         hotkey: str,
+     ) -> ModelId:
+         """Uploads a trained model to Hugging Face."""
+         token = HuggingFaceModelStore.assert_access_token_exists()
+         api = HfApi(token=token)
+         export_readme(model.local_repo_dir)
+         export_hotkey(model.local_repo_dir, hotkey)
+         hf_repo_id = model.id.namespace + "/" + model.id.name
+         api.create_repo(
+             repo_id=hf_repo_id,
+             exist_ok=True,
+             private=True,
+         )
+         commit_info = api.upload_folder(repo_id=hf_repo_id, folder_path=model.local_repo_dir, use_auth_token=True)
+
+         print(f"Successfully uploaded model repository '{model.local_repo_dir}' to {hf_repo_id}")
+
+         model_id_with_commit = ModelId(
+             namespace=model.id.namespace,
+             name=model.id.name,
+             epoch=model.id.epoch,
+             hash=model.id.hash,
+             commit=commit_info.oid,
+             competition_id=model.id.competition_id,
+         )
+
+         return model_id_with_commit
+         # # TODO: consider skipping the redownload if a hash is already provided.
+         # # To get the hash we need to redownload it to a local tmp directory, after which it can be deleted.
+         # with tempfile.TemporaryDirectory() as temp_dir:
+         #     model_with_hash = await self.download_model(
+         #         model_id_with_commit, temp_dir, competition_parameters
+         #     )
+         #     # Return a ModelId with both the correct commit and hash.
+         #     return model_with_hash.id
+
+     async def download_model(
+         self,
+         model_id: ModelId,
+         local_path: str,
+         model_parameters: CompetitionParameters,
+     ) -> Model:
+         """Retrieves a trained model from Hugging Face."""
+         if not model_id.commit:
+             raise ValueError("No Hugging Face commit id found to read from the hub.")
+
+         repo_id = model_id.namespace + "/" + model_id.name
+
+         # Check ModelInfo for the size of the repo files before downloading.
+         try:
+             token = HuggingFaceModelStore.assert_access_token_exists()
+         except ValueError:
+             token = None
+         api = HfApi(token=token)
+         model_info = api.model_info(
+             repo_id=repo_id, revision=model_id.commit, timeout=10, files_metadata=True
+         )
+         size = sum(repo_file.size for repo_file in model_info.siblings)
+         if size > MAX_HUGGING_FACE_BYTES:
+             raise ValueError(
+                 f"Hugging Face repo over maximum size limit. Size {size}. Limit {MAX_HUGGING_FACE_BYTES}."
+             )
+
+         api.hf_hub_download(
+             repo_id=repo_id,
+             revision=model_id.commit,
+             filename="checkpoint.safetensors",
+             cache_dir=local_path,
+         )
+
+         # Get the directory the model was stored to.
+         model_dir = utils.get_hf_download_path(local_path, model_id)
+
+         # Realize all symlinks in that directory, since the Transformers library does not support avoiding symlinks.
+         utils.realize_symlinks_in_directory(model_dir)
+
+         # Compute the hash of the downloaded model.
+         model_hash = utils.get_hash_of_directory(model_dir)
+         model_id_with_hash = ModelId(
+             namespace=model_id.namespace,
+             name=model_id.name,
+             commit=model_id.commit,
+             hash=model_hash,
+             competition_id=model_id.competition_id,
+         )
+
+         return Model(id=model_id_with_hash, ckpt=model_dir)
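
`download_model` gates the download on the summed file sizes reported by the repo's metadata (`model_info.siblings`) before fetching anything. A standalone sketch of that size check; `RepoFile` and `MAX_BYTES` below are hypothetical stand-ins for the per-file metadata records and `MAX_HUGGING_FACE_BYTES`:

```python
from dataclasses import dataclass
from typing import List

# Hypothetical cap standing in for MAX_HUGGING_FACE_BYTES.
MAX_BYTES = 10 * 1024 ** 3


@dataclass
class RepoFile:
    """Stand-in for the per-file metadata returned for a repo."""
    name: str
    size: int


def check_repo_size(files: List[RepoFile], limit: int = MAX_BYTES) -> int:
    """Returns the total size in bytes, raising if it exceeds the limit."""
    total = sum(f.size for f in files)
    if total > limit:
        raise ValueError(f"Repo over maximum size limit. Size {total}. Limit {limit}.")
    return total
```

Checking the declared sizes first means an oversized repo is rejected without spending bandwidth on the download.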
model/storage/local_model_store.py ADDED
@@ -0,0 +1,30 @@
+ import abc
+ from typing import Dict
+ from model.data import Model, ModelId
+ from constants import CompetitionParameters
+
+
+ class LocalModelStore(abc.ABC):
+     """An abstract base class for storing and retrieving a pre-trained model locally."""
+
+     @abc.abstractmethod
+     def store_model(self, hotkey: str, model: Model) -> ModelId:
+         """Stores a trained model in the appropriate location based on implementation."""
+         pass
+
+     @abc.abstractmethod
+     def get_path(self, hotkey: str) -> str:
+         """Returns the path to the appropriate location based on implementation."""
+         pass
+
+     @abc.abstractmethod
+     def retrieve_model(self, hotkey: str, model_id: ModelId, parameters: CompetitionParameters) -> Model:
+         """Retrieves a trained model from the appropriate location based on implementation."""
+         pass
+
+     @abc.abstractmethod
+     def delete_unreferenced_models(
+         self, valid_models_by_hotkey: Dict[str, ModelId], grace_period_seconds: int
+     ):
+         """Checks across all of local storage and deletes unreferenced models that are out of the grace period."""
+         pass
model/storage/model_metadata_store.py ADDED
@@ -0,0 +1,17 @@
+ import abc
+ from typing import Optional
+ from model.data import ModelId, ModelMetadata
+
+
+ class ModelMetadataStore(abc.ABC):
+     """An abstract base class for storing and retrieving model metadata."""
+
+     @abc.abstractmethod
+     async def store_model_metadata(self, hotkey: str, model_id: ModelId):
+         """Stores model metadata on this subnet for a specific miner."""
+         pass
+
+     @abc.abstractmethod
+     async def retrieve_model_metadata(self, hotkey: str) -> Optional[ModelMetadata]:
+         """Retrieves model metadata + block information on this subnet for a specific miner, if present."""
+         pass
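Since both methods of the ABC above are `async`, an in-memory implementation is a convenient test double. The sketch below is an assumption-laden illustration: it redefines minimal stand-ins for `ModelId`/`ModelMetadata` (the real ones live in `model.data`), and it fakes the chain "block" with a local counter.

```python
import asyncio
from dataclasses import dataclass
from typing import Dict, Optional


# Hypothetical stand-ins for model.data types, for illustration only.
@dataclass(frozen=True)
class ModelId:
    namespace: str
    name: str


@dataclass(frozen=True)
class ModelMetadata:
    id: ModelId
    block: int


class InMemoryMetadataStore:
    """Sketch of a ModelMetadataStore-style backend useful for unit tests."""

    def __init__(self):
        self._metadata: Dict[str, ModelMetadata] = {}
        self._lock = asyncio.Lock()
        self._block = 0

    async def store_model_metadata(self, hotkey: str, model_id: ModelId):
        async with self._lock:
            # Stamp each commitment with a monotonically increasing fake "block".
            self._block += 1
            self._metadata[hotkey] = ModelMetadata(id=model_id, block=self._block)

    async def retrieve_model_metadata(self, hotkey: str) -> Optional[ModelMetadata]:
        async with self._lock:
            return self._metadata.get(hotkey)
```

The `asyncio.Lock` keeps concurrent writers from interleaving the counter update and the dict write.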
model/storage/mysql_model_queue.py ADDED
@@ -0,0 +1,847 @@
+ from sqlalchemy import create_engine, Column, Integer, Float, String, DateTime, Boolean, JSON, func, desc, exists, ForeignKey, or_, and_, case, text
+ from sqlalchemy.exc import OperationalError, SQLAlchemyError
+ from sqlalchemy.ext.declarative import declarative_base
+ from sqlalchemy.dialects.mysql import JSON as MySQLJSON
+ from sqlalchemy.orm import Session, sessionmaker, aliased, relationship
+ from contextlib import contextmanager
+ from collections import defaultdict
+
+ import time
+ import json
+
+ from datetime import datetime, timedelta, timezone
+ import bittensor as bt
+ from typing import Optional
+
+ from model.data import ModelId
+ from vali_api.config import DBHOST, DBNAME, DBUSER, DBPASS, IS_PROD
+
+ Base = declarative_base()
+
+ # Global variables for engine and Session
+ _engine: Optional[object] = None
+ Session: Optional[sessionmaker] = None
+
+ def init_database():
+     """
+     Initialize the database connection and create tables.
+     Must be called before using any database operations.
+     """
+     global _engine, Session
+
+     if _engine is not None:
+         bt.logging.warning("Database already initialized")
+         return
+
+     # Try different MySQL drivers in order of preference
+     drivers_to_try = [
+         ('mysql', 'mysqlclient (MySQLdb)'),
+         ('mysql+pymysql', 'PyMySQL')
+     ]
+
+     for driver, driver_name in drivers_to_try:
+         try:
+             connection_string = f'{driver}://{DBUSER}:{DBPASS}@{DBHOST}/{DBNAME}'
+             bt.logging.info(f"Attempting database connection with {driver_name}")
+
+             _engine = create_engine(connection_string)
+             Session = sessionmaker(bind=_engine)
+
+             # Test the connection
+             with Session() as session:
+                 session.execute(text('SELECT 1'))
+
+             # Create all tables
+             Base.metadata.create_all(_engine)
+             bt.logging.info(f"Database initialized successfully with {driver_name}")
+             return
+
+         except ImportError as e:
+             bt.logging.warning(f"Driver {driver_name} not available: {e}")
+             continue
+         except Exception as e:
+             bt.logging.error(f"Failed to connect with {driver_name}: {e}")
+             if driver == drivers_to_try[-1][0]:  # Last driver in list
+                 raise
+             continue
+
+     raise RuntimeError("Failed to initialize database with any available MySQL driver")
+
+ def get_session() -> Session:
+     """
+     Get a database session. Raises an exception if the database is not initialized.
+     """
+     if Session is None:
+         raise RuntimeError("Database not initialized. Call init_database() first.")
+     return Session()
+
+ def get_table_name(base_name: str) -> str:
+     """Helper function to get the correct table name, with a suffix if not in production."""
+     return f"{base_name}{'_test' if not IS_PROD else ''}"
+
+ class ModelQueue(Base):
+     __tablename__ = get_table_name('sn21_model_queue')
+
+     hotkey = Column(String(255), primary_key=True)
+     uid = Column(String(255), primary_key=True, index=True)
+     block = Column(Integer, index=True)
+     competition_id = Column(String(255), index=True)
+     model_metadata = Column(JSON)
+     is_new = Column(Boolean, default=True)
+     is_being_scored = Column(Boolean, default=False)
+     is_being_scored_by = Column(String(255), default=None)
+     scoring_updated_at = Column(DateTime, default=None)
+     updated_at = Column(DateTime, default=datetime.utcnow)
+
+     # Relationship to use dynamic table name (lambda function)
+     scores = relationship(
+         "ScoreHistory",
+         back_populates="model",
+         foreign_keys="[ScoreHistory.hotkey, ScoreHistory.uid]",
+         primaryjoin=lambda: and_(
+             ModelQueue.hotkey == ScoreHistory.hotkey,
+             ModelQueue.uid == ScoreHistory.uid
+         )
+     )
+
+     def __repr__(self):
+         return f"<ModelQueue(hotkey='{self.hotkey}', uid='{self.uid}', competition_id='{self.competition_id}', is_new={self.is_new})>"
+
+ class ScoreHistory(Base):
+     __tablename__ = get_table_name('sn21_score_history')
+
+     id = Column(Integer, primary_key=True)
+     hotkey = Column(String(255), ForeignKey(f"{get_table_name('sn21_model_queue')}.hotkey", ondelete='SET NULL'), index=True, nullable=True)
+     uid = Column(String(255), ForeignKey(f"{get_table_name('sn21_model_queue')}.uid", ondelete='SET NULL'), index=True, nullable=True)
+     competition_id = Column(String(255), index=True)
+     model_metadata = Column(JSON)
+     score = Column(Float)
+     scored_at = Column(DateTime, default=datetime.utcnow)
+     block = Column(Integer)
+     model_hash = Column(String(255))
+     scorer_hotkey = Column(String(255), index=True)
+     is_archived = Column(Boolean, default=False)
+     metric_scores = Column(MySQLJSON, nullable=True)
+     wandb_run_id = Column(String(255), nullable=True)
+     wandb_run_url = Column(String(512), nullable=True)
+
+     # Relationship to ModelQueue using dynamic table name (lambda function)
+     model = relationship(
+         "ModelQueue",
+         back_populates="scores",
+         foreign_keys=[hotkey, uid],
+         primaryjoin=lambda: and_(
+             ModelQueue.hotkey == ScoreHistory.hotkey,
+             ModelQueue.uid == ScoreHistory.uid
+         )
+     )
+
+     def __repr__(self):
+         return f"<ScoreHistory(hotkey='{self.hotkey}', uid='{self.uid}', score={self.score}, scored_at={self.scored_at}, model_metadata={self.model_metadata}, is_archived={self.is_archived})>"
+
+ class ModelIdEncoder(json.JSONEncoder):
+     def default(self, obj):
+         if isinstance(obj, ModelId):
+             return {
+                 'namespace': obj.namespace,
+                 'name': obj.name,
+                 'epoch': obj.epoch,
+                 'commit': obj.commit,
+                 'hash': obj.hash,
+                 'competition_id': obj.competition_id
+             }
+         return super().default(obj)
+
+ class ModelQueueManager:
+     def __init__(self, max_scores_per_model=5, rescore_interval_hours=24, max_retries=3, retry_delay=1):
+         if Session is None:
+             raise RuntimeError("Database not initialized. Call init_database() first.")
+
+         self.session = get_session()
+         self.max_scores_per_model = max_scores_per_model
+         self.rescore_interval = timedelta(hours=rescore_interval_hours)
+         self.max_retries = max_retries
+         self.retry_delay = retry_delay
+
+     @contextmanager
+     def session_scope(self):
+         """Provide a transactional scope around a series of operations."""
+         session = get_session()
+         try:
+             yield session
+             session.commit()
+         except Exception:
+             session.rollback()
+             raise
+         finally:
+             session.close()
+
+     def reset_session(self):
+         """Reset the session in case of connection issues."""
+         try:
+             self.session.close()
+         except Exception:
+             pass
+         try:
+             self.session = get_session()
+         except RuntimeError as e:
+             bt.logging.error(f"Failed to reset session: {str(e)}")
+             raise
+
+     def execute_with_retry(self, operation, *args, **kwargs):
+         """Execute an operation with retry logic."""
+         for attempt in range(self.max_retries):
+             try:
+                 return operation(*args, **kwargs)
+             except OperationalError as e:
+                 if "Lost connection" in str(e) and attempt < self.max_retries - 1:
+                     bt.logging.warning(f"Lost connection to MySQL. Attempt {attempt + 1}/{self.max_retries}. Retrying...")
+                     self.reset_session()
+                     time.sleep(self.retry_delay)
+                 else:
+                     raise
+             except SQLAlchemyError:
+                 if attempt < self.max_retries - 1:
+                     bt.logging.warning(f"Database error. Attempt {attempt + 1}/{self.max_retries}. Retrying...")
+                     self.reset_session()
+                     time.sleep(self.retry_delay)
+                 else:
+                     raise
+
+     def store_updated_model(self, uid, hotkey, model_metadata, updated):
+         """
+         Store or update model metadata with retry logic.
+
+         Args:
+             uid (str): Model UID
+             hotkey (str): Model hotkey
+             model_metadata: Model metadata object
+             updated (bool): Whether this is an update
+
+         Returns:
+             bool: Success status
+         """
+         def _store_model():
+             with self.session_scope() as session:
+                 try:
+                     # Query existing model with lock
+                     existing_model = session.query(ModelQueue).filter_by(
+                         hotkey=hotkey,
+                         uid=uid
+                     ).with_for_update().first()
+
+                     # Serialize metadata
+                     serialized_metadata = json.dumps(model_metadata.__dict__, cls=ModelIdEncoder)
+
+                     if existing_model:
+                         if existing_model.model_metadata != serialized_metadata or existing_model.block != model_metadata.block:
+                             bt.logging.debug(f"Updating existing model metadata for UID={uid}, Hotkey={hotkey}. Old metadata: {existing_model.model_metadata}, New metadata: {serialized_metadata}")
+                             existing_model.model_metadata = serialized_metadata
+                             existing_model.is_new = True
+                             existing_model.block = model_metadata.block
+                             existing_model.updated_at = datetime.utcnow()
+                     else:
+                         # Create new model entry
+                         new_model = ModelQueue(
+                             hotkey=hotkey,
+                             uid=uid,
+                             competition_id=model_metadata.id.competition_id,
+                             model_metadata=serialized_metadata,
+                             is_new=True,
+                             block=model_metadata.block
+                         )
+                         session.add(new_model)
+                         bt.logging.debug(f"Stored new model for UID={uid}, Hotkey={hotkey} in database. Is new = {updated}")
+
+                     return True
+
+                 except Exception as e:
+                     bt.logging.error(f"Error in _store_model: {str(e)}")
+                     bt.logging.error(f"Model metadata: {model_metadata}")
+                     raise
+
+         try:
+             return self.execute_with_retry(_store_model)
+         except Exception as e:
+             bt.logging.error(f"Failed to store model after {self.max_retries} attempts: {str(e)}")
+             return False
+
+     def get_next_model_to_score(self, competition_id: str):
+         """
+         Get the next model to score, with retry logic.
+
+         The prioritization logic ensures:
+         1. New models (highest priority)
+         2. Models never scored with non-zero scores
+         3a. High-scoring models not scored for over a week
+         3b. Models not scored for more than 7 days (safety net for winning models)
+         4. Models eligible by standard criteria (not scored in 5 days or < 5 scores)
+         5. Everything else (lowest priority)
+
+         Zero-scored models that are frequently scored are downgraded in priority
+         to prevent them from consuming too many resources.
+         """
+         def _get_next_model():
+             with self.session_scope() as session:
+                 try:
+                     now = datetime.utcnow()
+
+                     # ---- START: Query to find overall highest score ----
+                     overall_max_score_value = session.query(func.max(ScoreHistory.score)).filter(
+                         ScoreHistory.competition_id == competition_id,
+                         ScoreHistory.is_archived == False,
+                         ScoreHistory.score > 0  # Consider only positive scores as relevant for "highest"
+                     ).scalar()
+
+                     if overall_max_score_value is not None:
+                         bt.logging.info(f"Overall highest positive score in competition '{competition_id}' is: {overall_max_score_value:.4f}")
+                     else:
+                         bt.logging.info(f"No positive scores found for competition '{competition_id}' to determine an overall highest score.")
+                     # ---- END: Query ----
+
+                     # Get latest score timestamp and count for each model
+                     score_subquery = session.query(
+                         ScoreHistory.hotkey,
+                         ScoreHistory.uid,
+                         func.count(ScoreHistory.id).label('score_count'),
+                         func.max(ScoreHistory.scored_at).label('latest_scored_at'),  # Get the latest score timestamp
+                         func.max(ScoreHistory.score).label('max_score')  # Get the maximum score
+                     ).filter(
+                         ScoreHistory.is_archived == False,
+                         ScoreHistory.competition_id == competition_id,
+                         ScoreHistory.score > 0  # Only consider non-zero scores
+                     ).group_by(
+                         ScoreHistory.hotkey,
+                         ScoreHistory.uid
+                     ).subquery()
+
+                     # Also track all scores (including zeros) for high-frequency zero score detection
+                     all_scores_subquery = session.query(
+                         ScoreHistory.hotkey,
+                         ScoreHistory.uid,
+                         func.count(ScoreHistory.id).label('all_score_count'),
+                         func.max(ScoreHistory.scored_at).label('latest_all_scored_at'),
+                         func.sum(case((ScoreHistory.score > 0, 1), else_=0)).label('non_zero_count')
+                     ).filter(
+                         ScoreHistory.is_archived == False,
+                         ScoreHistory.competition_id == competition_id
+                     ).group_by(
+                         ScoreHistory.hotkey,
+                         ScoreHistory.uid
+                     ).subquery()
+
+                     five_days_ago = now - timedelta(days=5)
+                     weekly_rescore_threshold_time = now - timedelta(days=7)  # Define a 7-day threshold
+
+                     # Check if we have new models before proceeding
+                     have_new_models = session.query(ModelQueue).filter(
+                         ModelQueue.is_being_scored == False,
+                         ModelQueue.competition_id == competition_id,
+                         ModelQueue.is_new == True
+                     ).first() is not None
+
+                     # Check if we have never-scored models
+                     never_scored_count = session.query(func.count(ModelQueue.uid)).filter(
+                         ModelQueue.is_being_scored == False,
+                         ModelQueue.competition_id == competition_id,
+                         ~exists().where(
+                             and_(
+                                 ScoreHistory.hotkey == ModelQueue.hotkey,
+                                 ScoreHistory.uid == ModelQueue.uid,
+                                 ScoreHistory.score > 0
+                             )
+                         )
+                     ).scalar()
+
+                     # If no new models and no never-scored models, prioritize high scoring models not scored recently
+                     if not have_new_models:
+                         # ---- START: Modified logic for dynamic high-score threshold ----
+                         if overall_max_score_value is not None and overall_max_score_value > 0:  # Ensure we have a valid max score
+                             dynamic_high_score_threshold = overall_max_score_value * 0.97
+                             bt.logging.info(f"Using dynamic high-score threshold for competition '{competition_id}': >= {dynamic_high_score_threshold:.4f} (based on overall max of {overall_max_score_value:.4f})")
+
+                             # First try to get a high-scoring model not scored in over a week
+                             top_model = session.query(ModelQueue).join(
+                                 score_subquery,
+                                 and_(
+                                     ModelQueue.hotkey == score_subquery.c.hotkey,
+                                     ModelQueue.uid == score_subquery.c.uid
+                                 )
+                             ).filter(
+                                 ModelQueue.is_being_scored == False,
+                                 ModelQueue.competition_id == competition_id,
+                                 score_subquery.c.latest_scored_at < weekly_rescore_threshold_time,
+                                 score_subquery.c.max_score >= dynamic_high_score_threshold  # Use dynamic threshold
+                             ).order_by(
+                                 score_subquery.c.max_score.desc()  # Highest score first
+                             ).with_for_update().first()
+
+                             if top_model:
+                                 # Create a dictionary with the model's attributes
+                                 model_data = {
+                                     'hotkey': top_model.hotkey,
+                                     'uid': top_model.uid,
+                                     'block': top_model.block,
+                                     'competition_id': top_model.competition_id,
+                                     'model_metadata': top_model.model_metadata,
+                                     'is_new': top_model.is_new,
+                                     'is_being_scored': top_model.is_being_scored,
+                                     'is_being_scored_by': top_model.is_being_scored_by,
+                                     'scoring_updated_at': top_model.scoring_updated_at,
+                                     'updated_at': top_model.updated_at
+                                 }
+                                 bt.logging.debug(f"Found high-scoring model (dynamic threshold) to score: hotkey={model_data['hotkey']}, uid={model_data['uid']}")
+                                 return model_data
+                         else:
+                             bt.logging.info(f"Skipping dynamic high-score prioritization for competition '{competition_id}' as no overall positive max score is available or it's zero.")
+                         # ---- END: Modified logic ----
+
+                     # Otherwise, use the standard prioritization logic with the zero-score detection
+                     next_model = session.query(ModelQueue).outerjoin(
+                         score_subquery,
+                         and_(
+                             ModelQueue.hotkey == score_subquery.c.hotkey,
+                             ModelQueue.uid == score_subquery.c.uid
+                         )
+                     ).outerjoin(
+                         all_scores_subquery,
+                         and_(
+                             ModelQueue.hotkey == all_scores_subquery.c.hotkey,
+                             ModelQueue.uid == all_scores_subquery.c.uid
+                         )
+                     ).filter(
+                         ModelQueue.is_being_scored == False,
+                         ModelQueue.competition_id == competition_id
+                     ).order_by(
+                         desc(ModelQueue.is_new),  # 1. Prioritize new models
+                         (score_subquery.c.score_count == None).desc(),  # 2. Prioritize models never scored (non-zero)
+                         case(  # 3. Prioritize models not scored for more than 7 days (safety net)
+                             (and_(score_subquery.c.latest_scored_at != None, score_subquery.c.latest_scored_at < weekly_rescore_threshold_time), 0),
+                             else_=1
+                         ),
+                         # 4. Decrease priority for models with all zero scores and frequent scoring
+                         case(
+                             (and_(
+                                 all_scores_subquery.c.all_score_count > 10,  # Has many scores
+                                 all_scores_subquery.c.non_zero_count == 0,  # All scores are zero
+                                 all_scores_subquery.c.latest_all_scored_at > five_days_ago  # Scored recently
+                             ), 1),
+                             else_=0
+                         ),
+                         case(  # 5. Prioritize models eligible by standard criteria
+                             (or_(
+                                 score_subquery.c.latest_scored_at == None,
+                                 score_subquery.c.latest_scored_at <= five_days_ago,
+                                 score_subquery.c.score_count < 5
+                             ), 0),
+                             else_=1
+                         ),
+                         func.rand()  # 6. Random tie-breaker
+                     ).with_for_update().first()
+
+                     if next_model:
+                         # Create a dictionary with the model's attributes
+                         model_data = {
+                             'hotkey': next_model.hotkey,
+                             'uid': next_model.uid,
+                             'block': next_model.block,
+                             'competition_id': next_model.competition_id,
+                             'model_metadata': next_model.model_metadata,
+                             'is_new': next_model.is_new,
+                             'is_being_scored': next_model.is_being_scored,
+                             'is_being_scored_by': next_model.is_being_scored_by,
+                             'scoring_updated_at': next_model.scoring_updated_at,
+                             'updated_at': next_model.updated_at
+                         }
+                         bt.logging.debug(f"Found next model to score: hotkey={model_data['hotkey']}, uid={model_data['uid']}")
+                         return model_data
+                     else:
+                         bt.logging.debug("No models available for scoring")
+                         return None
+
+                 except Exception as e:
+                     bt.logging.error(f"Error in _get_next_model: {str(e)}")
+                     raise
+
+         try:
+             return self.execute_with_retry(_get_next_model)
+         except Exception as e:
+             bt.logging.error(f"Failed to get next model after {self.max_retries} attempts: {str(e)}")
+             return None
+
+     def mark_model_as_being_scored(self, model_hotkey, model_uid, scorer_hotkey):
+         """Mark a model as being scored, with retry logic."""
+         def _mark_model():
+             with self.session_scope() as session:
+                 model = session.query(ModelQueue).filter_by(
+                     hotkey=model_hotkey,
+                     uid=model_uid
+                 ).with_for_update().first()
+
+                 if model and not model.is_being_scored:
+                     model.is_being_scored = True
+                     model.is_being_scored_by = scorer_hotkey
+                     model.scoring_updated_at = datetime.utcnow()
+                     return True
+                 return False
+
+         try:
+             return self.execute_with_retry(_mark_model)
+         except Exception as e:
+             bt.logging.error(f"Failed to mark model as being scored after {self.max_retries} attempts: {str(e)}")
+             return False
+
+     def submit_score(self, model_hotkey, model_uid, scorer_hotkey, model_hash, score, metric_scores):
+         """Submit a score with retry logic, record it in the history, and release the model in the queue."""
+         def _submit_score():
+             with self.session_scope() as session:
+                 try:
+                     model = session.query(ModelQueue).filter_by(
+                         hotkey=model_hotkey,
+                         uid=model_uid
+                     ).with_for_update().first()
+
+                     if not model:
+                         bt.logging.error(f"No model found for hotkey {model_hotkey} and uid {model_uid}")
+                         return False
+
+                     """
+                     # temporarily allow scoring from any hotkey
+                     new_score = ScoreHistory(
+                         hotkey=model_hotkey,
+                         uid=model_uid,
+                         competition_id=model.competition_id,
+                         score=score,
+                         block=model.block,
+                         model_hash=model_hash,
+                         scorer_hotkey=scorer_hotkey,
+                         model_metadata=model.model_metadata
+                     )
+                     session.add(new_score)
+                     model.is_new = False
+                     model.is_being_scored = False
+                     model.is_being_scored_by = None
+                     model.scoring_updated_at = None
+                     model.updated_at = datetime.now(timezone.utc)
+                     bt.logging.info(f"Successfully submitted score for model {model_hotkey} by {scorer_hotkey}")
+                     return True
+                     """
+
+                     if model.is_being_scored and model.is_being_scored_by == scorer_hotkey:
+                         # Extract wandb fields from metric_scores if present
+                         wandb_run_id = None
+                         wandb_run_url = None
+                         if metric_scores and isinstance(metric_scores, dict):
+                             wandb_run_id = metric_scores.get('wandb_run_id')
+                             wandb_run_url = metric_scores.get('wandb_run_url')
+
+                         new_score = ScoreHistory(
+                             hotkey=model_hotkey,
+                             uid=model_uid,
+                             competition_id=model.competition_id,
+                             score=score,
+                             block=model.block,
+                             model_hash=model_hash,
+                             scorer_hotkey=scorer_hotkey,
+                             model_metadata=model.model_metadata,
+                             metric_scores=metric_scores,
+                             wandb_run_id=wandb_run_id,
+                             wandb_run_url=wandb_run_url
+                         )
+                         session.add(new_score)
+                         model.is_new = False
+                         model.is_being_scored = False
+                         model.is_being_scored_by = None
+                         model.scoring_updated_at = None
+                         model.updated_at = datetime.now(timezone.utc)
+                         bt.logging.info(f"Successfully submitted score for model {model_hotkey} by {scorer_hotkey}")
+                         return True
+                     else:
+                         bt.logging.error(f"Failed to submit score for model {model_hotkey} by {scorer_hotkey}. "
+                                          f"Model: {model}, is_being_scored: {model.is_being_scored}, "
+                                          f"is_being_scored_by: {model.is_being_scored_by}")
+                         return False
+
+                 except Exception as e:
+                     bt.logging.error(f"Error in _submit_score: {str(e)}")
+                     raise
+
+         try:
+             return self.execute_with_retry(_submit_score)
+         except Exception as e:
+             bt.logging.error(f"Failed to submit score after {self.max_retries} attempts: {str(e)}")
+             return False
+
+     def reset_stale_scoring_tasks(self, max_scoring_time_minutes=15):
+         """Reset stale scoring tasks with retry logic."""
+         def _reset_stale_tasks():
+             with self.session_scope() as session:
+                 try:
+                     stale_time = datetime.utcnow() - timedelta(minutes=max_scoring_time_minutes)
+                     stale_models = session.query(ModelQueue).filter(
+                         ModelQueue.is_being_scored == True,
+                         ModelQueue.scoring_updated_at < stale_time
+                     ).with_for_update().all()
+
+                     reset_count = 0
+                     for model in stale_models:
+                         model.is_being_scored = False
+                         model.is_being_scored_by = None
+                         model.scoring_updated_at = None
+                         reset_count += 1
+                         bt.logging.info(f"Reset scoring task for stale model: hotkey={model.hotkey}, uid={model.uid}")
+
+                     return reset_count
+
+                 except Exception as e:
+                     bt.logging.error(f"Error in _reset_stale_tasks: {str(e)}")
+                     raise
+
+         try:
+             return self.execute_with_retry(_reset_stale_tasks)
+         except Exception as e:
+             bt.logging.error(f"Failed to reset stale tasks after {self.max_retries} attempts: {str(e)}")
+             return 0
+
+     def get_recent_model_scores(self, scores_per_model):
+         """
+         Get recent scores for all models.
+
+         Args:
+             scores_per_model (int): Number of recent scores to fetch per model
+
+         Returns:
+             dict: Dictionary of model scores grouped by UID
+         """
+         def _get_recent_scores():
+             with self.session_scope() as session:
+                 try:
+                     # First, create a subquery that ranks scores by timestamp for each model
+                     ranked_scores = (
+                         session.query(
+                             ScoreHistory,
+                             func.row_number().over(
+                                 partition_by=(ScoreHistory.hotkey, ScoreHistory.uid),
+                                 order_by=desc(ScoreHistory.scored_at)
+                             ).label('score_rank')
+                         )
+                         .filter(ScoreHistory.is_archived == False)
+                         .filter(ScoreHistory.score != 0)
+                         .subquery()
+                     )
+
+                     # Get the most recent scores for each model
+                     recent_scores = session.query(ranked_scores).filter(
+                         ranked_scores.c.score_rank <= scores_per_model
+                     ).subquery('recent_scores')
+
+                     # Join with ModelQueue to get additional model information
+                     results = session.query(
+                         ModelQueue.uid,
+                         ModelQueue.hotkey,
+                         ModelQueue.competition_id,
+                         ModelQueue.model_metadata,
+                         recent_scores.c.score,
+                         recent_scores.c.scored_at,
+                         recent_scores.c.block,
+                         recent_scores.c.model_hash,
+                         recent_scores.c.scorer_hotkey,
+                         recent_scores.c.score_rank
+                     ).outerjoin(
+                         recent_scores,
+                         and_(
+                             ModelQueue.hotkey == recent_scores.c.hotkey,
+                             ModelQueue.uid == recent_scores.c.uid,
+                         )
+                     ).order_by(
+                         ModelQueue.uid,
+                         ModelQueue.hotkey,
+                         recent_scores.c.scored_at.desc()
+                     ).all()
+
+                     scores_by_uid = defaultdict(lambda: defaultdict(list))
+
+                     for result in results:
+                         if result.score is not None:
+                             # Create a unique key for each hotkey+uid combination
+                             model_key = f"{result.hotkey}_{result.uid}"
+
+                             scores_by_uid[result.uid][model_key].append({
+                                 'hotkey': result.hotkey,
+                                 'competition_id': result.competition_id,
+                                 'model_metadata': result.model_metadata,
+                                 'score': result.score,
+                                 'scored_at': result.scored_at.isoformat() if result.scored_at else None,
+                                 'block': result.block,
+                                 'model_hash': result.model_hash,
+                                 'scorer_hotkey': result.scorer_hotkey,
+                                 'rank': result.score_rank
+                             })
+                         else:
+                             # Handle models with no scores
+                             model_key = f"{result.hotkey}_{result.uid}"
+                             if not scores_by_uid[result.uid][model_key]:  # Only add if no scores exist
+                                 scores_by_uid[result.uid][model_key].append({
+                                     'hotkey': result.hotkey,
+                                     'competition_id': None,
+                                     'model_metadata': result.model_metadata,
+                                     'score': None,
+                                     'scored_at': None,
+                                     'block': None,
+                                     'model_hash': None,
+                                     'scorer_hotkey': None,
+                                     'rank': None
+                                 })
+
+                     # Convert defaultdict to regular dict for return
+                     return {
+                         uid: dict(models)
+                         for uid, models in scores_by_uid.items()
+                     }
+
+                 except Exception as e:
+                     bt.logging.error(f"Error in _get_recent_scores: {str(e)}")
+                     raise
+
+         try:
+             return self.execute_with_retry(_get_recent_scores)
+         except Exception as e:
+             bt.logging.error(f"Failed to get recent scores after {self.max_retries} attempts: {str(e)}")
+             return {}
+
+     def get_all_model_scores(self):
+         """Get all model scores with retry logic."""
+         def _get_all_scores():
+             with self.session_scope() as session:
+                 try:
+                     # First, get the latest score timestamps
+                     latest_scores = session.query(
+                         ScoreHistory.hotkey,
+                         ScoreHistory.uid,
+                         func.max(ScoreHistory.scored_at).label('latest_score_time')
+                     ).filter(
+                         ScoreHistory.is_archived == False
+                     ).group_by(
+                         ScoreHistory.hotkey,
+                         ScoreHistory.uid
+                     ).subquery('latest_scores')
+
+                     # Get score details
+                     latest_score_details = session.query(
+                         ScoreHistory
+                     ).join(
+                         latest_scores,
+                         and_(
+                             ScoreHistory.hotkey == latest_scores.c.hotkey,
+                             ScoreHistory.uid == latest_scores.c.uid,
+                             ScoreHistory.scored_at == latest_scores.c.latest_score_time
+                         )
+                     ).subquery('latest_score_details')
+
+                     # Get final results
+                     results = session.query(
+                         ModelQueue.uid,
+                         ModelQueue.hotkey,
+                         ModelQueue.competition_id,
+                         latest_score_details.c.score,
+                         latest_score_details.c.scored_at,
+                         latest_score_details.c.block,
+                         latest_score_details.c.model_hash,
+                         latest_score_details.c.scorer_hotkey
+                     ).outerjoin(
+                         latest_score_details,
+                         and_(
+                             ModelQueue.hotkey == latest_score_details.c.hotkey,
+                             ModelQueue.uid == latest_score_details.c.uid
+                         )
+                     ).all()
+
+                     scores_by_uid = defaultdict(list)
+                     for result in results:
+                         if result.score is not None:
+                             scores_by_uid[result.uid].append({
+                                 'hotkey': result.hotkey,
+                                 'competition_id': result.competition_id,
+                                 'score': result.score,
+                                 'scored_at': result.scored_at.isoformat() if result.scored_at else None,
+                                 'block': result.block,
+                                 'model_hash': result.model_hash,
+                             })
+                         else:
+                             scores_by_uid[result.uid].append({
+                                 'hotkey': result.hotkey,
+                                 'competition_id': result.competition_id,
+                                 'score': None,
+                                 'scored_at': None,
+                                 'block': None,
+                                 'model_hash': None,
+                             })
+
+                     return dict(scores_by_uid)
+
+                 except Exception as e:
+                     bt.logging.error(f"Error in _get_all_scores: {str(e)}")
+                     raise
+
+         try:
+             return self.execute_with_retry(_get_all_scores)
+         except Exception as e:
+             bt.logging.error(f"Failed to get all scores after {self.max_retries} attempts: {str(e)}")
+             return {}
+
+     def archive_scores_for_deregistered_models(self, registered_hotkey_uid_pairs):
+         """Archive scores for deregistered models with retry logic."""
+         def _archive_scores():
+             with self.session_scope() as session:
+                 try:
+                     all_models = session.query(
+                         ModelQueue.hotkey,
+                         ModelQueue.uid
+                     ).with_for_update().all()
+
+                     deregistered_models = set(
+                         (model.hotkey, model.uid) for model in all_models
+                     ) - set(registered_hotkey_uid_pairs)
+
+                     for hotkey, uid in deregistered_models:
+                         # Mark scores as archived
+                         archive_result = session.query(ScoreHistory).filter_by(
+                             hotkey=hotkey,
+                             uid=uid,
+                             is_archived=False
+                         ).update(
+                             {"is_archived": True},
+                             synchronize_session=False
+                         )
+
+                         # Remove from ModelQueue
+                         delete_result = session.query(ModelQueue).filter_by(
+                             hotkey=hotkey,
+                             uid=uid
+                         ).delete(synchronize_session=False)
+
+                         bt.logging.debug(
+                             f"Processed deregistered model - Hotkey: {hotkey}, "
+                             f"UID: {uid}, Archived scores: {archive_result}, "
+                             f"Removed from queue: {delete_result}"
+                         )
+
+                     return len(deregistered_models)
+
+                 except Exception as e:
+                     bt.logging.error(f"Error in _archive_scores: {str(e)}")
+                     raise
+
+         try:
+             result = self.execute_with_retry(_archive_scores)
+             bt.logging.info(f"Archived scores and removed {result} deregistered models from the queue.")
+             return result
+         except Exception as e:
+             bt.logging.error(f"Failed to archive scores after {self.max_retries} attempts: {str(e)}")
+             return 0
+
+     def close(self):
+         """Safely close the session."""
+         try:
+             self.session.close()
+         except Exception:
+             pass
model/storage/remote_model_store.py ADDED
@@ -0,0 +1,17 @@
+ import abc
+
+ from constants import CompetitionParameters
+ from model.data import Model, ModelId
+
+
+ class RemoteModelStore(abc.ABC):
+     """An abstract base class for storing and retrieving a pre-trained model."""
+
+     @abc.abstractmethod
+     async def upload_model(self, model: Model, parameters: CompetitionParameters) -> ModelId:
+         """Uploads a trained model to the appropriate location based on implementation."""
+         pass
+
+     @abc.abstractmethod
+     async def download_model(self, model_id: ModelId, local_path: str, parameters: CompetitionParameters) -> Model:
+         """Retrieves a trained model from the appropriate location and stores it at the given path."""
+         pass
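The abstract store above only fixes the async `upload_model`/`download_model` contract; concrete subclasses supply the storage backend. A minimal runnable sketch of that pattern is below — note that `ModelId`, `Model`, and the in-memory backend here are hypothetical stand-ins for illustration, not the repo's real `model.data` types:

```python
# Sketch of the RemoteModelStore pattern with stand-in types.
# ModelId/Model below are simplified stand-ins for model.data's classes,
# and InMemoryModelStore is a toy backend, not part of the repo.
import abc
import asyncio
from dataclasses import dataclass


@dataclass(frozen=True)
class ModelId:  # stand-in for model.data.ModelId
    namespace: str
    name: str


@dataclass
class Model:  # stand-in for model.data.Model
    id: ModelId
    weights: bytes


class RemoteModelStore(abc.ABC):
    """Abstract base class for storing and retrieving a pre-trained model."""

    @abc.abstractmethod
    async def upload_model(self, model: Model) -> ModelId: ...

    @abc.abstractmethod
    async def download_model(self, model_id: ModelId) -> Model: ...


class InMemoryModelStore(RemoteModelStore):
    """Toy implementation that keeps models in a dict instead of remote storage."""

    def __init__(self):
        self._store: dict[ModelId, Model] = {}

    async def upload_model(self, model: Model) -> ModelId:
        self._store[model.id] = model
        return model.id

    async def download_model(self, model_id: ModelId) -> Model:
        return self._store[model_id]


async def demo() -> bytes:
    store = InMemoryModelStore()
    mid = await store.upload_model(Model(ModelId("org", "tts-v1"), b"\x00\x01"))
    return (await store.download_model(mid)).weights


print(asyncio.run(demo()))  # b'\x00\x01'
```

A real subclass would replace the dict with, e.g., object-storage calls, while callers keep awaiting the same two methods.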
model/storage/reputation_store.py ADDED
@@ -0,0 +1,244 @@
+ from contextlib import contextmanager
+ from datetime import datetime
+ import time
+ from typing import Optional
+
+ import bittensor as bt
+ from sqlalchemy import create_engine, Column, Integer, Float, String, DateTime
+ from sqlalchemy.exc import OperationalError, SQLAlchemyError
+ from sqlalchemy.ext.declarative import declarative_base
+ from sqlalchemy.orm import sessionmaker
+
+ from vali_api.config import DBHOST, DBNAME, DBUSER, DBPASS, IS_PROD
+
+
+ Base = declarative_base()
+
+ # Global variables for the engine and session factory
+ _engine: Optional[object] = None
+ Session: Optional[sessionmaker] = None
+
+
+ def init_database():
+     """
+     Initialize the database connection and create tables.
+     Must be called before using any database operations.
+     """
+     global _engine, Session
+
+     if _engine is not None:
+         bt.logging.warning("Database already initialized")
+         return
+
+     try:
+         connection_string = f'mysql://{DBUSER}:{DBPASS}@{DBHOST}/{DBNAME}'
+         _engine = create_engine(connection_string)
+         Session = sessionmaker(bind=_engine)
+
+         # Create all tables
+         Base.metadata.create_all(_engine)
+         bt.logging.info("Database initialized successfully")
+
+     except Exception as e:
+         bt.logging.error(f"Failed to initialize database: {str(e)}")
+         raise
+
+
+ def get_session():
+     """Get a database session. Raises an exception if the database is not initialized."""
+     if Session is None:
+         raise RuntimeError("Database not initialized. Call init_database() first.")
+     return Session()
+
+
+ def get_table_name(base_name: str) -> str:
+     """Return the table name, suffixed with '_test' when not in production."""
+     return f"{base_name}{'_test' if not IS_PROD else ''}"
+
+
+ class BaselineScore(Base):
+     __tablename__ = get_table_name('sn21_baseline_scores')
+     id = Column(Integer, primary_key=True)
+     competition_id = Column(String(255), nullable=False, index=True)
+     score = Column(Float, nullable=False)
+     created_at = Column(DateTime, default=datetime.utcnow, index=True)
+
+     def __repr__(self):
+         return f"<BaselineScore(competition_id='{self.competition_id}', score={self.score}, created_at='{self.created_at}')>"
+
+
+ class MinerReputation(Base):
+     __tablename__ = get_table_name('sn21_miner_reputations')
+     hotkey = Column(String(255), primary_key=True)
+     reputation = Column(Float, default=0.5, nullable=False)
+     last_updated = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
+
+     def __repr__(self):
+         return f"<MinerReputation(hotkey='{self.hotkey}', reputation={self.reputation})>"
+
+
+ class ReputationHistory(Base):
+     __tablename__ = get_table_name('sn21_reputation_history')
+     id = Column(Integer, primary_key=True)
+     hotkey = Column(String(255), nullable=False, index=True)
+     timestamp = Column(DateTime, nullable=False, index=True)
+     reputation = Column(Float, nullable=False)
+
+     def __repr__(self):
+         return f"<ReputationHistory(hotkey='{self.hotkey}', timestamp='{self.timestamp}', reputation={self.reputation})>"
+
+
+ class ReputationStore:
+     def __init__(self, max_retries=3, retry_delay=1):
+         # Ensure the DB is initialized before creating a session
+         if Session is None:
+             raise RuntimeError("Database not initialized. Call init_database() first.")
+         self.max_retries = max_retries
+         self.retry_delay = retry_delay
+         self.session = get_session()
+
+     @contextmanager
+     def session_scope(self):
+         """Provide a transactional scope around a series of operations."""
+         session = get_session()
+         try:
+             yield session
+             session.commit()
+         except Exception:
+             session.rollback()
+             raise
+         finally:
+             session.close()
+
+     def reset_session(self):
+         """Reset the session in case of connection issues."""
+         try:
+             self.session.close()
+         except Exception:
+             pass
+         try:
+             self.session = get_session()
+         except Exception as e:
+             bt.logging.error(f"Failed to reset session: {str(e)}")
+             raise
+
+     def execute_with_retry(self, operation, *args, **kwargs):
+         """Execute an operation, retrying on transient database errors."""
+         for attempt in range(self.max_retries):
+             try:
+                 return operation(*args, **kwargs)
+             except OperationalError as e:
+                 if "Lost connection" in str(e) and attempt < self.max_retries - 1:
+                     bt.logging.warning(f"Lost connection to MySQL. Attempt {attempt + 1}/{self.max_retries}. Retrying...")
+                     self.reset_session()
+                     time.sleep(self.retry_delay)
+                 else:
+                     raise
+             except SQLAlchemyError:
+                 if attempt < self.max_retries - 1:
+                     bt.logging.warning(f"Database error. Attempt {attempt + 1}/{self.max_retries}. Retrying...")
+                     self.reset_session()
+                     time.sleep(self.retry_delay)
+                 else:
+                     raise
+
+     def get_latest_baseline_score(self, competition_id):
+         def _get():
+             with self.session_scope() as session:
+                 latest_baseline = (
+                     session.query(BaselineScore)
+                     .filter(BaselineScore.competition_id == competition_id)
+                     .order_by(BaselineScore.created_at.desc())
+                     .first()
+                 )
+                 if latest_baseline is None:
+                     return None
+                 return {
+                     "competition_id": latest_baseline.competition_id,
+                     "score": latest_baseline.score,
+                     "created_at": latest_baseline.created_at,
+                 }
+
+         return self.execute_with_retry(_get)
+
+     def get_all_reputations(self):
+         def _get():
+             with self.session_scope() as session:
+                 records = session.query(MinerReputation).all()
+                 return {
+                     record.hotkey: {
+                         "reputation": record.reputation,
+                         "last_updated": record.last_updated.isoformat() if record.last_updated else None
+                     }
+                     for record in records
+                 }
+
+         return self.execute_with_retry(_get)
+
+     def get_reputation(self, hotkey):
+         def _get():
+             with self.session_scope() as session:
+                 record = session.query(MinerReputation).filter(MinerReputation.hotkey == hotkey).first()
+                 if not record:
+                     return None
+                 return {
+                     "hotkey": record.hotkey,
+                     "reputation": record.reputation,
+                     "last_updated": record.last_updated.isoformat() if record.last_updated else None
+                 }
+
+         return self.execute_with_retry(_get)
+
+
+ def main():
+     """Demonstrate the ReputationStore's three get methods."""
+     try:
+         # Initialize the database
+         print("Initializing database...")
+         init_database()
+         print("Database initialized successfully!")
+
+         # Create a ReputationStore instance
+         print("\nCreating ReputationStore instance...")
+         reputation_store = ReputationStore()
+         print("ReputationStore created successfully!")
+
+         # Test 1: Get latest baseline score
+         print("\n=== Testing get_latest_baseline_score ===")
+         competition_id = "v1"
+         baseline_score = reputation_store.get_latest_baseline_score(competition_id)
+         if baseline_score:
+             print(f"Latest baseline score for competition '{competition_id}':")
+             print(f"  Score: {baseline_score['score']}")
+             print(f"  Created at: {baseline_score['created_at']}")
+         else:
+             print(f"No baseline score found for competition '{competition_id}'")
+
+         # Test 2: Get all reputations (print only the first record)
+         print("\n=== Testing get_all_reputations ===")
+         all_reputations = reputation_store.get_all_reputations()
+         if all_reputations:
+             print(f"Found {len(all_reputations)} miner reputations:")
+             for hotkey, data in all_reputations.items():
+                 print(f"  Hotkey: {hotkey}")
+                 print(f"  Reputation: {data['reputation']}")
+                 print(f"  Last Updated: {data['last_updated']}")
+                 break
+         else:
+             print("No miner reputations found in database")
+
+         # Test 3: Get a specific reputation
+         print("\n=== Testing get_reputation ===")
+         test_hotkey = "test_hotkey_123"
+         reputation = reputation_store.get_reputation(test_hotkey)
+         if reputation:
+             print(f"Reputation for hotkey '{test_hotkey}':")
+             print(f"  Hotkey: {reputation['hotkey']}")
+             print(f"  Reputation: {reputation['reputation']}")
+             print(f"  Last Updated: {reputation['last_updated']}")
+         else:
+             print(f"No reputation found for hotkey '{test_hotkey}'")
+
+         print("\n=== All tests completed successfully! ===")
+
+     except Exception as e:
+         print(f"Error occurred: {str(e)}")
+         import traceback
+         traceback.print_exc()
+
+
+ if __name__ == "__main__":
+     main()
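The `execute_with_retry` method above retries transient database errors, resetting the session between attempts and re-raising once the attempt budget is spent. The core pattern can be sketched standalone (here `TransientError` stands in for SQLAlchemy's `OperationalError`, and `on_retry` for the session reset):

```python
# Standalone sketch of the retry pattern used by ReputationStore.execute_with_retry:
# retry an operation on transient errors, running a reset hook between attempts,
# and re-raise after the final attempt fails.
import time


class TransientError(Exception):
    """Stand-in for a recoverable DB error such as OperationalError."""


def execute_with_retry(operation, max_retries=3, retry_delay=0.0, on_retry=None):
    for attempt in range(max_retries):
        try:
            return operation()
        except TransientError:
            if attempt < max_retries - 1:
                if on_retry:
                    on_retry()  # e.g. reset the DB session
                time.sleep(retry_delay)
            else:
                raise


# Fails twice, then succeeds on the third attempt.
calls = {"n": 0}


def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise TransientError("lost connection")
    return "ok"


print(execute_with_retry(flaky))  # ok
```

With `max_retries=3` the operation runs at most three times; a third consecutive `TransientError` would propagate to the caller, matching the `else: raise` branches in the store.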
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ transformers==4.48.3
+ pydantic==2.11.4
+ numpy==2.2.5
+ torch==2.6.0
+ torchaudio==2.6.0
+ torchvision==0.21.0
+ fastapi==0.115.12
+ uvicorn==0.34.2
+ librosa==0.11.0
+ openai-whisper==20240930
+ soundfile==0.13.1
+ accelerate==0.26.0
+ voxcpm
server.py ADDED
@@ -0,0 +1,211 @@
+ import base64
+ import io
+ import time
+ import traceback
+
+ import librosa
+ import numpy as np
+ import torch
+ import uvicorn
+ import whisper
+ from fastapi import FastAPI, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel, Field
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from voxcpm import VoxCPM
+
+
+ asr_model = whisper.load_model("models/wpt/wpt.pt")
+
+ model_name = "models/Llama-3.2-1B-Instruct"
+ tok = AutoTokenizer.from_pretrained(model_name)
+ lm = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     torch_dtype=torch.bfloat16,
+     device_map="cuda",
+ ).eval()
+
+ tts = VoxCPM.from_pretrained(
+     "models/VoxCPM-0.5B",
+     local_files_only=True,
+     load_denoiser=True,
+     zipenhancer_model_id="models/iic/speech_zipenhancer_ans_multiloss_16k_base"
+ )
+
+
+ def chat(system_prompt: str, user_prompt: str) -> str:
+     """Run one chat turn through the local LLM and return the reply text."""
+     print("LLM init...")
+     messages = [
+         {"role": "system", "content": system_prompt},
+         {"role": "user", "content": user_prompt},
+     ]
+     inputs = tok.apply_chat_template(
+         messages,
+         add_generation_prompt=True,
+         return_tensors="pt",
+         return_dict=True
+     )
+     input_ids = inputs["input_ids"].to(lm.device)
+     attention_mask = inputs["attention_mask"].to(lm.device)
+
+     with torch.inference_mode():
+         output_ids = lm.generate(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             pad_token_id=tok.eos_token_id,
+             max_new_tokens=2048,
+             do_sample=True,
+             temperature=0.2,
+             repetition_penalty=1.1,
+             top_k=100,
+             top_p=0.95,
+         )
+
+     answer = tok.decode(
+         output_ids[0][input_ids.shape[-1]:],
+         skip_special_tokens=True,
+         clean_up_tokenization_spaces=True,
+     )
+     print("LLM answer done.")
+     return answer.strip()
+
+
+ def gt(audio: np.ndarray, sr: int) -> str:
+     """Transcribe audio with Whisper, resampling to 16 kHz if needed."""
+     print("Starting ASR transcription...")
+     ss = audio.squeeze().astype(np.float32)
+     if sr != 16_000:
+         ss = librosa.resample(ss, orig_sr=sr, target_sr=16_000)
+
+     result = asr_model.transcribe(ss, fp16=False, language=None)
+     transcribed_text = result["text"].strip()
+     print(f"ASR done. Transcribed: '{transcribed_text}'")
+     return transcribed_text
+
+
+ def sample(rr: str) -> str:
+     """Plain (non-chat) continuation of a text prompt."""
+     if rr.strip() == "":
+         rr = "Hello "
+
+     inputs = tok(rr, return_tensors="pt").to(lm.device)
+
+     with torch.inference_mode():
+         out_ids = lm.generate(
+             **inputs,
+             max_new_tokens=2048,
+             do_sample=True,
+             temperature=0.2,
+             repetition_penalty=1.1,
+             top_k=100,
+             top_p=0.95,
+         )
+
+     return tok.decode(
+         out_ids[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True
+     )
+
+
+ INITIALIZATION_STATUS = {"model_loaded": True, "error": None}
+
+
+ class GenerateRequest(BaseModel):
+     audio_data: str = Field(..., description="Base64-encoded .npy payload of float32 audio samples")
+     sample_rate: int = Field(..., description="Sample rate of the input audio in Hz")
+
+
+ class GenerateResponse(BaseModel):
+     audio_data: str = Field(..., description="Base64-encoded .npy payload of float32 audio samples")
+
+
+ app = FastAPI(title="V1", version="0.1")
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+
+ def b64(data: str) -> np.ndarray:
+     """Decode a base64-encoded .npy payload into a NumPy array."""
+     raw = base64.b64decode(data)
+     return np.load(io.BytesIO(raw), allow_pickle=False)
+
+
+ def ab64(arr: np.ndarray, sr: int) -> str:
+     """Resample 16 kHz model output to the client's rate and base64-encode it."""
+     buf = io.BytesIO()
+     resampled = librosa.resample(arr, orig_sr=16000, target_sr=sr)
+     np.save(buf, resampled.astype(np.float32))
+     return base64.b64encode(buf.getvalue()).decode()
+
+
+ @app.get("/api/v1/health")
+ def health_check():
+     return {
+         "status": "healthy",
+         "model_loaded": INITIALIZATION_STATUS["model_loaded"],
+         "error": INITIALIZATION_STATUS["error"],
+     }
+
+
+ @app.post("/api/v1/v2v", response_model=GenerateResponse)
+ def generate_audio(req: GenerateRequest):
+     print("=== V2V Request Started ===")
+     audio_np = b64(req.audio_data)
+     if audio_np.ndim == 1:
+         audio_np = audio_np.reshape(1, -1)
+     print(f"Audio shape: {audio_np.shape}, Sample rate: {req.sample_rate}")
+
+     system_prompt = (
+         "You are a helpful assistant who tries to help answer the user's question. "
+         "This is a part of a voice assistant system, so don't generate anything other than pure text."
+     )
+
+     try:
+         text = gt(audio_np, req.sample_rate)
+         response_text = chat(system_prompt, user_prompt=text)
+         print(f"LLM response length (chars): {len(response_text)}")
+         print(f"LLM response: '{response_text}'")
+
+         start_time = time.perf_counter()
+         audio_out = tts.generate(
+             text=response_text,
+             prompt_wav_path=None,
+             prompt_text=None,
+             cfg_value=2.0,
+             inference_timesteps=10,
+             normalize=True,
+             denoise=True,
+             retry_badcase=True,
+             retry_badcase_max_times=3,
+             retry_badcase_ratio_threshold=6.0,
+         )
+         print("TTS generation complete.")
+         end_time = time.perf_counter()
+         print(f"TTS generation took {end_time - start_time:.2f} seconds.")
+         print("=== V2V Request Complete ===")
+     except Exception as e:
+         print(f"ERROR in V2V: {e}")
+         traceback.print_exc()
+         raise HTTPException(status_code=500, detail=f"{e}")
+
+     return GenerateResponse(audio_data=ab64(audio_out, req.sample_rate))
+
+
+ @app.post("/api/v1/v2t")
+ def generate_text(req: GenerateRequest):
+     audio_np = b64(req.audio_data)
+     if audio_np.ndim == 1:
+         audio_np = audio_np.reshape(1, -1)
+
+     try:
+         text = gt(audio_np, req.sample_rate)
+         print(f"Transcribed text: {text}")
+         system_prompt = "You are a helpful assistant who tries to help answer the user's question."
+         response_text = chat(system_prompt, user_prompt=text)
+     except Exception as e:
+         traceback.print_exc()
+         raise HTTPException(status_code=500, detail=f"{e}")
+
+     return {"text": response_text}
+
+
+ if __name__ == "__main__":
+     uvicorn.run("server:app", host="0.0.0.0", port=8000, reload=False)
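On the wire, `server.py`'s `b64`/`ab64` helpers expect `audio_data` to be the base64 encoding of a NumPy `.npy` buffer of float32 samples. A client-side sketch of that packaging (the encode/decode pair mirrors the server's helpers; the localhost URL in the note below is illustrative):

```python
# Client-side packaging for the /api/v1/v2t and /api/v1/v2v endpoints:
# the server decodes audio_data with np.load on a base64-decoded buffer,
# so the client serializes its waveform with np.save before base64-encoding.
import base64
import io

import numpy as np


def encode_audio(audio: np.ndarray) -> str:
    """Serialize a float32 waveform to the base64 .npy string the server expects."""
    buf = io.BytesIO()
    np.save(buf, audio.astype(np.float32))
    return base64.b64encode(buf.getvalue()).decode()


def decode_audio(payload: str) -> np.ndarray:
    """Inverse of encode_audio; matches the server's b64() helper."""
    raw = base64.b64decode(payload)
    return np.load(io.BytesIO(raw), allow_pickle=False)


# One second of silence at 16 kHz round-trips unchanged.
wave = np.zeros(16_000, dtype=np.float32)
payload = {"audio_data": encode_audio(wave), "sample_rate": 16_000}
assert np.array_equal(decode_audio(payload["audio_data"]), wave)
```

The resulting dict is the JSON body for either endpoint, e.g. with any HTTP client: `requests.post("http://localhost:8000/api/v1/v2t", json=payload)`; a `/api/v1/v2v` response's `audio_data` field decodes back to a waveform with the same `decode_audio`.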
spk_001.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:79de3a5775f8880c0bf3e950b103f03b257db630224fab265a309d82753b1aa5
+ size 480044