Mingke977 committed on
Commit
fda2d4a
·
verified ·
1 Parent(s): f794b33

Add files using upload-large-folder tool

Files changed (50)
  1. .gitignore +2 -0
  2. README.md +378 -0
  3. chat_template.jinja +103 -0
  4. config.json +58 -0
  5. configuration_deepseek.py +247 -0
  6. docs/deploy_guidance.md +42 -0
  7. model.safetensors.index.json +0 -0
  8. modeling_deepseek.py +1030 -0
  9. tokenizer.json +0 -0
  10. tokenizer_config.json +34 -0
  11. venv/bin/Activate.ps1 +247 -0
  12. venv/bin/activate +69 -0
  13. venv/bin/activate.csh +26 -0
  14. venv/bin/activate.fish +69 -0
  15. venv/bin/hf +10 -0
  16. venv/bin/httpx +10 -0
  17. venv/bin/markdown-it +10 -0
  18. venv/bin/pip +10 -0
  19. venv/bin/pip3 +10 -0
  20. venv/bin/pip3.10 +10 -0
  21. venv/bin/pygmentize +10 -0
  22. venv/bin/tiny-agents +10 -0
  23. venv/bin/tqdm +10 -0
  24. venv/bin/typer +10 -0
  25. venv/lib/python3.10/site-packages/_distutils_hack/__init__.py +132 -0
  26. venv/lib/python3.10/site-packages/_distutils_hack/__pycache__/__init__.cpython-310.pyc +0 -0
  27. venv/lib/python3.10/site-packages/_distutils_hack/__pycache__/override.cpython-310.pyc +0 -0
  28. venv/lib/python3.10/site-packages/_distutils_hack/override.py +1 -0
  29. venv/lib/python3.10/site-packages/_yaml/__init__.py +33 -0
  30. venv/lib/python3.10/site-packages/_yaml/__pycache__/__init__.cpython-310.pyc +0 -0
  31. venv/lib/python3.10/site-packages/annotated_doc-0.0.4.dist-info/INSTALLER +1 -0
  32. venv/lib/python3.10/site-packages/annotated_doc-0.0.4.dist-info/METADATA +145 -0
  33. venv/lib/python3.10/site-packages/annotated_doc-0.0.4.dist-info/RECORD +11 -0
  34. venv/lib/python3.10/site-packages/annotated_doc-0.0.4.dist-info/WHEEL +4 -0
  35. venv/lib/python3.10/site-packages/annotated_doc-0.0.4.dist-info/entry_points.txt +4 -0
  36. venv/lib/python3.10/site-packages/annotated_doc-0.0.4.dist-info/licenses/LICENSE +21 -0
  37. venv/lib/python3.10/site-packages/annotated_doc/__init__.py +3 -0
  38. venv/lib/python3.10/site-packages/annotated_doc/__pycache__/__init__.cpython-310.pyc +0 -0
  39. venv/lib/python3.10/site-packages/annotated_doc/__pycache__/main.cpython-310.pyc +0 -0
  40. venv/lib/python3.10/site-packages/annotated_doc/main.py +36 -0
  41. venv/lib/python3.10/site-packages/annotated_doc/py.typed +0 -0
  42. venv/lib/python3.10/site-packages/anyio/__init__.py +111 -0
  43. venv/lib/python3.10/site-packages/anyio/from_thread.py +578 -0
  44. venv/lib/python3.10/site-packages/anyio/functools.py +375 -0
  45. venv/lib/python3.10/site-packages/anyio/lowlevel.py +196 -0
  46. venv/lib/python3.10/site-packages/anyio/py.typed +0 -0
  47. venv/lib/python3.10/site-packages/anyio/pytest_plugin.py +302 -0
  48. venv/lib/python3.10/site-packages/anyio/to_interpreter.py +246 -0
  49. venv/lib/python3.10/site-packages/typing_extensions.py +0 -0
  50. venv/pyvenv.cfg +3 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
+ .joycode/
+ venv/
README.md ADDED
@@ -0,0 +1,378 @@
+ ---
+ language:
+ - zh
+ - en
+ pipeline_tag: text-generation
+ library_name: transformers
+ ---
+ <div align="center">
+ <picture>
+ <img src="figures/joyai-logo.png" width="30%" alt="JoyAI-LLM Flash">
+ </picture>
+ </div>
+ <hr>
+
+ <div align="center" style="line-height: 1;">
+ <a href="https://huggingface.co/jdopensource" target="_blank"><img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-JD-ffc107?color=ffc107&logoColor=white"/></a>
+ <a href="https://huggingface.co/jdopensource/JoyAI-LLM-Flash/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-Modified_MIT-f5de53?&color=f5de53"/></a>
+ </div>
+
+ ## 1. Model Introduction
+
+ JoyAI-LLM-Flash is a state-of-the-art medium-sized instruct language model with 3 billion activated parameters and 48 billion total parameters. It was pretrained on 20 trillion text tokens with the Muon optimizer, followed by large-scale supervised fine-tuning (SFT), direct preference optimization (DPO), and reinforcement learning (RL) across diverse environments. JoyAI-LLM-Flash achieves strong performance on frontier knowledge, reasoning, and coding tasks, as well as agentic capabilities.
+
+ ### Key Features
+
+ - Fiber Bundle RL: introduces fiber bundle theory into reinforcement learning through a novel optimization framework, FiberPO. The method is designed for large-scale, heterogeneous agent training and improves stability and robustness under complex data distributions.
+ - Training-Inference Collaboration: applies the Muon optimizer with dense MTP and develops optimization techniques that resolve instabilities at scale, delivering 1.3× to 1.7× the throughput of the non-MTP version.
+ - Agentic Intelligence: designed for tool use, reasoning, and autonomous problem-solving.
+
+ ## 2. Model Summary
+
+ | | |
+ | :-----------------------------------------: | :----------------------: |
+ | **Architecture** | Mixture-of-Experts (MoE) |
+ | **Total Parameters** | 48B |
+ | **Activated Parameters** | 3B |
+ | **Number of Layers** (Dense layer included) | 40 |
+ | **Number of Dense Layers** | 1 |
+ | **Attention Hidden Dimension** | 2048 |
+ | **MoE Hidden Dimension** (per Expert) | 768 |
+ | **Number of Attention Heads** | 32 |
+ | **Number of Experts** | 256 |
+ | **Selected Experts per Token** | 8 |
+ | **Number of Shared Experts** | 1 |
+ | **Vocabulary Size** | 129K |
+ | **Context Length** | 128K |
+ | **Attention Mechanism** | MLA |
+ | **Activation Function** | SwiGLU |
+
+
+ ## 3. Evaluation Results
+
+ <table>
+ <thead>
+ <tr>
+ <th align="center">Benchmark</th>
+ <th align="center"><sup>JoyAI-LLM Flash</sup></th>
+ <th align="center"><sup>Qwen3-30B-A3B-Instruct-2507</sup></th>
+ <th align="center"><sup>GLM-4.7-Flash<br>(Non-thinking)</sup></th>
+ </tr>
+ </thead>
+ <tbody>
+
+ <tr>
+ <td align="center" colspan=4><strong>Knowledge &amp; Alignment</strong></td>
+ </tr>
+ <tr>
+ <td align="center" style="vertical-align: middle">MMLU</td>
+ <td align="center" style="vertical-align: middle"><strong>89.50</strong></td>
+ <td align="center" style="vertical-align: middle">86.87</td>
+ <td align="center" style="vertical-align: middle">80.53</td>
+ </tr>
+ <tr>
+ <td align="center" style="vertical-align: middle">MMLU-Pro</td>
+ <td align="center" style="vertical-align: middle"><strong>81.02</strong></td>
+ <td align="center" style="vertical-align: middle">73.88</td>
+ <td align="center" style="vertical-align: middle">63.62</td>
+ </tr>
+ <tr>
+ <td align="center" style="vertical-align: middle">CMMLU</td>
+ <td align="center" style="vertical-align: middle"><strong>87.03</strong></td>
+ <td align="center" style="vertical-align: middle">85.88</td>
+ <td align="center" style="vertical-align: middle">75.85</td>
+ </tr>
+ <tr>
+ <td align="center" style="vertical-align: middle">GPQA-Diamond</td>
+ <td align="center" style="vertical-align: middle"><strong>74.43</strong></td>
+ <td align="center" style="vertical-align: middle">68.69</td>
+ <td align="center" style="vertical-align: middle">39.90</td>
+ </tr>
+ <tr>
+ <td align="center" style="vertical-align: middle">SuperGPQA</td>
+ <td align="center" style="vertical-align: middle"><strong>55.00</strong></td>
+ <td align="center" style="vertical-align: middle">52.00</td>
+ <td align="center" style="vertical-align: middle">32.00</td>
+ </tr>
+ <tr>
+ <td align="center" style="vertical-align: middle">LiveBench</td>
+ <td align="center" style="vertical-align: middle"><strong>72.90</strong></td>
+ <td align="center" style="vertical-align: middle">59.70</td>
+ <td align="center" style="vertical-align: middle">43.10</td>
+ </tr>
+ <tr>
+ <td align="center" style="vertical-align: middle">IFEval</td>
+ <td align="center" style="vertical-align: middle"><strong>86.69</strong></td>
+ <td align="center" style="vertical-align: middle">83.18</td>
+ <td align="center" style="vertical-align: middle">82.44</td>
+ </tr>
+ <tr>
+ <td align="center" style="vertical-align: middle">AlignBench</td>
+ <td align="center" style="vertical-align: middle"><strong>8.24</strong></td>
+ <td align="center" style="vertical-align: middle">8.07</td>
+ <td align="center" style="vertical-align: middle">6.85</td>
+ </tr>
+ <tr>
+ <td align="center" style="vertical-align: middle">HellaSwag</td>
+ <td align="center" style="vertical-align: middle"><strong>91.79</strong></td>
+ <td align="center" style="vertical-align: middle">89.90</td>
+ <td align="center" style="vertical-align: middle">60.84</td>
+ </tr>
+
+ <tr>
+ <td align="center" colspan=4><strong>Coding</strong></td>
+ </tr>
+ <tr>
+ <td align="center" style="vertical-align: middle">HumanEval</td>
+ <td align="center" style="vertical-align: middle"><strong>96.34</strong></td>
+ <td align="center" style="vertical-align: middle">95.12</td>
+ <td align="center" style="vertical-align: middle">74.39</td>
+ </tr>
+ <tr>
+ <td align="center" style="vertical-align: middle">LiveCodeBench</td>
+ <td align="center" style="vertical-align: middle"><strong>65.60</strong></td>
+ <td align="center" style="vertical-align: middle">39.71</td>
+ <td align="center" style="vertical-align: middle">27.43</td>
+ </tr>
+ <tr>
+ <td align="center" style="vertical-align: middle">SciCode</td>
+ <td align="center" style="vertical-align: middle"><strong>3.08/22.92</strong></td>
+ <td align="center" style="vertical-align: middle"><strong>3.08/22.92</strong></td>
+ <td align="center" style="vertical-align: middle">3.08/15.11</td>
+ </tr>
+ <tr>
+ <td align="center" colspan=4><strong>Mathematics</strong></td>
+ </tr>
+ <tr>
+ <td align="center" style="vertical-align: middle">GSM8K</td>
+ <td align="center" style="vertical-align: middle"><strong>95.83</strong></td>
+ <td align="center" style="vertical-align: middle">79.83</td>
+ <td align="center" style="vertical-align: middle">81.88</td>
+ </tr>
+ <tr>
+ <td align="center" style="vertical-align: middle">AIME2025</td>
+ <td align="center" style="vertical-align: middle"><strong>65.83</strong></td>
+ <td align="center" style="vertical-align: middle">62.08</td>
+ <td align="center" style="vertical-align: middle">24.17</td>
+ </tr>
+ <tr>
+ <td align="center" style="vertical-align: middle">MATH 500</td>
+ <td align="center" style="vertical-align: middle"><strong>97.10</strong></td>
+ <td align="center" style="vertical-align: middle">89.80</td>
+ <td align="center" style="vertical-align: middle">90.90</td>
+ </tr>
+
+ <tr>
+ <td align="center" colspan=4><strong>Agentic</strong></td>
+ </tr>
+ <tr>
+ <td align="center" style="vertical-align: middle">SWE-bench Verified</td>
+ <td align="center" style="vertical-align: middle"><strong>60.60</strong></td>
+ <td align="center" style="vertical-align: middle">24.44</td>
+ <td align="center" style="vertical-align: middle">51.60</td>
+ </tr>
+ <tr>
+ <td align="center" style="vertical-align: middle">Tau2-Retail</td>
+ <td align="center" style="vertical-align: middle"><strong>67.55</strong></td>
+ <td align="center" style="vertical-align: middle">53.51</td>
+ <td align="center" style="vertical-align: middle">62.28</td>
+ </tr>
+ <tr>
+ <td align="center" style="vertical-align: middle">Tau2-Airline</td>
+ <td align="center" style="vertical-align: middle"><strong>54.00</strong></td>
+ <td align="center" style="vertical-align: middle">32.00</td>
+ <td align="center" style="vertical-align: middle">52.00</td>
+ </tr>
+ <tr>
+ <td align="center" style="vertical-align: middle">Tau2-Telecom</td>
+ <td align="center" style="vertical-align: middle">79.83</td>
+ <td align="center" style="vertical-align: middle">4.39</td>
+ <td align="center" style="vertical-align: middle"><strong>88.60</strong></td>
+ </tr>
+
+ <tr>
+ <td align="center" colspan=4><strong>Long Context</strong></td>
+ </tr>
+ <tr>
+ <td align="center" style="vertical-align: middle">RULER</td>
+ <td align="center" style="vertical-align: middle"><strong>95.60</strong></td>
+ <td align="center" style="vertical-align: middle">89.66</td>
+ <td align="center" style="vertical-align: middle">56.12</td>
+ </tr>
+ </tbody>
+ </table>
+
+
+ ## 4. Deployment
+
+ > [!Note]
+ > You can access the JoyAI-LLM Flash API at https://docs.jdcloud.com/cn/jdaip/chat; an OpenAI/Anthropic-compatible API is provided.
+ > Currently, JoyAI-LLM-Flash-FP8 is recommended to run on the following inference engines:
+
+ * vLLM
+ * SGLang
+
+ The minimum version requirement for `transformers` is `4.57.1`.
+
+ Deployment examples can be found in the [Model Deployment Guide](docs/deploy_guidance.md).
+
+
+
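Before deploying, it can help to confirm that the installed `transformers` meets the stated minimum. A minimal preflight sketch (the `parse` helper below is illustrative, not part of this repo):

```python
from importlib.metadata import PackageNotFoundError, version

MIN_VERSION = (4, 57, 1)


def parse(v: str) -> tuple:
    """Parse the leading numeric dot-separated components of a version string."""
    parts = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:  # stop at non-numeric suffixes like "dev0"
            break
        parts.append(int(digits))
    return tuple(parts)


try:
    installed = parse(version("transformers"))
    if installed < MIN_VERSION:
        print(f"transformers {installed} is older than the required {MIN_VERSION}")
    else:
        print("transformers version OK")
except PackageNotFoundError:
    print("transformers is not installed")
```

Tuple comparison handles the common cases; a full resolver such as `packaging.version` is more robust for exotic version strings.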
+ ## 5. Model Usage
+
+ The usage demos below show how to call our official API.
+
+ For third-party APIs deployed with vLLM or SGLang, please note:
+
+ > [!Note]
+ > Recommended sampling parameters: `temperature=0.6`, `top_p=1.0`
+
+ ### Chat Completion
+
+ This simple chat-completion script shows how to call the JoyAI-Flash API.
+
+ ```python
+ from openai import OpenAI
+
+ client = OpenAI(base_url="http://IP:PORT/v1", api_key="EMPTY")
+
+
+ def simple_chat(client: OpenAI):
+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 {
+                     "type": "text",
+                     "text": "which one is bigger, 9.11 or 9.9? think carefully.",
+                 }
+             ],
+         },
+     ]
+     model_name = client.models.list().data[0].id
+     response = client.chat.completions.create(
+         model=model_name, messages=messages, stream=False, max_tokens=4096
+     )
+     print(f"response: {response.choices[0].message.content}")
+
+
+ if __name__ == "__main__":
+     simple_chat(client)
+ ```
+
+
+ ### Tool Call Completion
+
+ This simple tool-call script shows how to call the JoyAI-Flash API.
+
+ ```python
+ import json
+
+ from openai import OpenAI
+
+ client = OpenAI(base_url="http://IP:PORT/v1", api_key="EMPTY")
+
+
+ def my_calculator(expression: str) -> str:
+     # Demo only: never eval() untrusted input in production.
+     return str(eval(expression))
+
+
+ def rewrite(text: str) -> str:
+     return str(text)
+
+
+ def simple_tool_call(client: OpenAI):
+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 {
+                     "type": "text",
+                     "text": "use my functions to compute the results for the equations: 6+1",
+                 },
+             ],
+         },
+     ]
+     tools = [
+         {
+             "type": "function",
+             "function": {
+                 "name": "my_calculator",
+                 "description": "A calculator that can evaluate a mathematical equation and compute its results.",
+                 "parameters": {
+                     "type": "object",
+                     "properties": {
+                         "expression": {
+                             "type": "string",
+                             "description": "The mathematical expression to evaluate.",
+                         },
+                     },
+                     "required": ["expression"],
+                 },
+             },
+         },
+         {
+             "type": "function",
+             "function": {
+                 "name": "rewrite",
+                 "description": "Rewrite a given text for improved clarity",
+                 "parameters": {
+                     "type": "object",
+                     "properties": {
+                         "text": {
+                             "type": "string",
+                             "description": "The input text to rewrite",
+                         }
+                     },
+                     "required": ["text"],
+                 },
+             },
+         },
+     ]
+     model_name = client.models.list().data[0].id
+     response = client.chat.completions.create(
+         model=model_name,
+         messages=messages,
+         temperature=1.0,
+         max_tokens=1024,
+         tools=tools,
+         tool_choice="auto",
+     )
+     tool_calls = response.choices[0].message.tool_calls
+
+     results = []
+     for tool_call in tool_calls:
+         function_name = tool_call.function.name
+         function_args = json.loads(tool_call.function.arguments)
+         if function_name == "my_calculator":
+             results.append(my_calculator(**function_args))
+         elif function_name == "rewrite":
+             results.append(rewrite(**function_args))
+     messages.append({"role": "assistant", "tool_calls": tool_calls})
+     for tool_call, result in zip(tool_calls, results):
+         messages.append(
+             {
+                 "role": "tool",
+                 "tool_call_id": tool_call.id,
+                 "name": tool_call.function.name,
+                 "content": result,
+             }
+         )
+     response = client.chat.completions.create(
+         model=model_name,
+         messages=messages,
+         temperature=1.0,
+         max_tokens=1024,
+     )
+     print(response.choices[0].message.content)
+
+
+ if __name__ == "__main__":
+     simple_tool_call(client)
+ ```
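The calculator tool above uses `eval` for brevity, which is unsafe for untrusted model output. A minimal safer alternative using the standard-library `ast` module (the `safe_eval` name and the operator whitelist are our own illustrative choices):

```python
import ast
import operator

# Whitelisted binary operators for plain arithmetic expressions.
_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.Pow: operator.pow,
}


def safe_eval(expression: str) -> float:
    """Evaluate a pure-arithmetic expression without calling eval()."""

    def walk(node):
        if isinstance(node, ast.Expression):
            return walk(node.body)
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
            return _OPS[type(node.op)](walk(node.left), walk(node.right))
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -walk(node.operand)
        raise ValueError(f"unsupported expression: {expression!r}")

    return walk(ast.parse(expression, mode="eval"))
```

Anything outside the whitelist (names, calls, attribute access) raises `ValueError` instead of executing, so a hostile `expression` argument cannot run arbitrary code.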
+
+ ---
+
+ ## 6. License
+
+ Both the code repository and the model weights are released under the [Modified MIT License](LICENSE).
chat_template.jinja ADDED
@@ -0,0 +1,103 @@
+ {%- macro render_extra_keys(json_dict, handled_keys) -%}
+ {%- if json_dict is mapping -%}
+ {%- for json_key in json_dict if json_key not in handled_keys -%}
+ {%- if json_dict[json_key] is mapping or (json_dict[json_key] is sequence and json_dict[json_key] is not string) -%}
+ {{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | tojson | safe) ~ '</' ~ json_key ~ '>' -}}
+ {%- else -%}
+ {{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '</' ~ json_key ~ '>' -}}
+ {%- endif -%}
+ {%- endfor -%}
+ {%- endif -%}
+ {%- endmacro -%}
+
+ {%- if not add_generation_prompt is defined -%}{%- set add_generation_prompt = false -%}{%- endif -%}
+
+ {%- set ns = namespace(system_prompt='', is_first_sp=true, is_last_user=false) -%}
+ {%- set default_system = "You are JoyAI , a large language model trained by JD(京东)that can interact with a computer to solve tasks. Answer as concisely as possible." -%}
+ {%- set ns.system_prompt = default_system -%}
+
+ {%- for message in messages -%}
+ {%- if message['role'] == 'system' -%}
+ {%- if ns.is_first_sp -%}
+ {%- set ns.system_prompt = message['content'] -%}
+ {%- set ns.is_first_sp = false -%}
+ {%- else -%}
+ {%- set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] -%}
+ {%- endif -%}
+ {%- endif -%}
+ {%- endfor -%}
+
+ {{- bos_token -}}{{- ns.system_prompt -}}
+ {%- if tools is iterable and tools | length > 0 -%}
+ {{- "\n\n# Tools\n\nYou have access to the following functions:\n\n" }}
+ {{- "<tools>" }}
+ {%- for tool in tools %}
+ {%- if tool.function is defined %}
+ {%- set tool = tool.function %}
+ {%- endif %}
+ {{- "\n<function>\n<name>" ~ tool.name ~ "</name>" }}
+ {%- if tool.description is defined %}
+ {{- '\n<description>' ~ (tool.description | trim) ~ '</description>' }}
+ {%- endif %}
+ {{- '\n<parameters>' }}
+ {%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %}
+ {%- for param_name, param_fields in tool.parameters.properties|items %}
+ {{- '\n<parameter>' }}
+ {{- '\n<name>' ~ param_name ~ '</name>' }}
+ {%- if param_fields.type is defined %}
+ {{- '\n<type>' ~ (param_fields.type | string) ~ '</type>' }}
+ {%- endif %}
+ {%- if param_fields.description is defined %}
+ {{- '\n<description>' ~ (param_fields.description | trim) ~ '</description>' }}
+ {%- endif %}
+ {%- set handled_keys = ['name', 'type', 'description'] %}
+ {{- render_extra_keys(param_fields, handled_keys) }}
+ {{- '\n</parameter>' }}
+ {%- endfor %}
+ {%- endif %}
+ {% set handled_keys = ['type', 'properties'] %}
+ {{- render_extra_keys(tool.parameters, handled_keys) }}
+ {{- '\n</parameters>' }}
+ {%- set handled_keys = ['type', 'name', 'description', 'parameters'] %}
+ {{- render_extra_keys(tool, handled_keys) }}
+ {{- '\n</function>' }}
+ {%- endfor %}
+ {{- "\n</tools>" }}
+ {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }}
+ {%- endif %}
+ {%- for message in messages -%}
+ {%- if message['role'] == 'user' -%}
+ {%- set ns.is_last_user = true -%}
+ {{- '<|User|>' + message['content'] -}}
+ {%- elif message['role'] == 'assistant' -%}
+ {%- if ns.is_last_user -%}
+ {{ '<|Assistant|>' }}
+ {%- endif -%}
+ {%- set ns.is_last_user = false -%}
+ {%- set content = message.get('content') | default('', true) -%}
+ {{ '<|end_of_thought|>' + content }}
+ {%- if message['tool_calls'] is defined and message['tool_calls'] is not none -%}
+ {%- for tool in message['tool_calls'] -%}
+ {%- if tool.function is defined %}{% set tool = tool.function %}{% endif -%}
+ {{- '\n<tool_call>\n<function=' + tool.name + '>\n' -}}
+ {%- if tool.arguments is defined -%}
+ {%- if tool.arguments is string -%}{%- set args_data = tool.arguments | from_json -%}{%- else -%}{%- set args_data = tool.arguments -%}{%- endif -%}
+ {%- for args_name, args_value in args_data.items() -%}
+ {{- '<parameter=' + args_name + '>\n' -}}
+ {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string -%}
+ {{- args_value -}}{{- '\n</parameter>\n' -}}
+ {%- endfor -%}
+ {%- endif -%}
+ {{- '</function>\n</tool_call>' -}}
+ {%- endfor -%}
+ {%- endif -%}
+ {{ '<|end▁of▁sentence|>' }}
+ {%- elif message['role'] == 'tool' -%}
+ {%- set ns.is_last_user = true -%}
+ {{ '\n<tool_response>\n' + message['content'] + '\n</tool_response>' }}
+ {%- endif -%}
+ {%- endfor -%}
+
+ {%- if add_generation_prompt -%}
+ {{ '<|Assistant|>' }}{{ '<|end_of_thought|>' }}
+ {%- endif -%}
config.json ADDED
@@ -0,0 +1,58 @@
+ {
+   "architectures": [
+     "DeepseekV3ForCausalLM"
+   ],
+   "attention_bias": false,
+   "attention_dropout": 0.0,
+   "auto_map": {
+     "AutoConfig": "configuration_deepseek.DeepseekV3Config",
+     "AutoModel": "modeling_deepseek.DeepseekV3Model",
+     "AutoModelForCausalLM": "modeling_deepseek.DeepseekV3ForCausalLM"
+   },
+   "bos_token_id": 0,
+   "eos_token_id": 1,
+   "ep_size": 1,
+   "first_k_dense_replace": 1,
+   "hidden_act": "silu",
+   "hidden_size": 2048,
+   "initializer_range": 0.02,
+   "intermediate_size": 7168,
+   "kv_lora_rank": 512,
+   "max_position_embeddings": 131072,
+   "model_type": "joyai_llm_flash",
+   "moe_intermediate_size": 768,
+   "moe_layer_freq": 1,
+   "n_group": 1,
+   "n_routed_experts": 256,
+   "n_shared_experts": 1,
+   "norm_topk_prob": true,
+   "num_attention_heads": 32,
+   "num_experts_per_tok": 8,
+   "num_hidden_layers": 40,
+   "num_key_value_heads": 32,
+   "num_nextn_predict_layers": 1,
+   "q_lora_rank": 1536,
+   "qk_nope_head_dim": 128,
+   "qk_rope_head_dim": 64,
+   "quantization_config": {
+     "activation_scheme": "dynamic",
+     "fmt": "e4m3",
+     "quant_method": "fp8",
+     "weight_block_size": [
+       128,
+       128
+     ]
+   },
+   "rms_norm_eps": 1e-06,
+   "rope_theta": 32000000,
+   "routed_scaling_factor": 2.5,
+   "scoring_func": "sigmoid",
+   "tie_word_embeddings": false,
+   "topk_group": 1,
+   "topk_method": "noaux_tc",
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.44.2",
+   "use_cache": true,
+   "v_head_dim": 128,
+   "vocab_size": 129280
+ }
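The MLA settings in this config imply a much smaller KV cache than standard multi-head attention: MLA caches one shared compressed latent per token (`kv_lora_rank`) plus the decoupled RoPE key (`qk_rope_head_dim`), instead of full per-head K/V. A back-of-the-envelope sketch using the values above (illustrative accounting only; actual engine memory layouts differ):

```python
# Values taken from the config above.
cfg = {
    "num_attention_heads": 32,
    "qk_nope_head_dim": 128,
    "qk_rope_head_dim": 64,
    "v_head_dim": 128,
    "kv_lora_rank": 512,
}

# Standard MHA would cache one full K and one full V vector per head per token.
mha_cache_per_token = cfg["num_attention_heads"] * (
    cfg["qk_nope_head_dim"] + cfg["qk_rope_head_dim"] + cfg["v_head_dim"]
)  # 32 * 320 = 10240 elements

# MLA caches the shared compressed KV latent plus the decoupled RoPE key.
mla_cache_per_token = cfg["kv_lora_rank"] + cfg["qk_rope_head_dim"]  # 576 elements

print(f"MHA:  {mha_cache_per_token} elements/token")
print(f"MLA:  {mla_cache_per_token} elements/token")
print(f"reduction: {mha_cache_per_token / mla_cache_per_token:.1f}x")
```

At the 128K context length, this per-token reduction is what makes long-context serving of the model practical on modest hardware.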
configuration_deepseek.py ADDED
@@ -0,0 +1,247 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 bzantium and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on the DeepSeekV3 implementations from the DeepSeek AI team. (https://huggingface.co/deepseek-ai/DeepSeek-V3)
5
+
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """DeepSeekV3 model configuration"""
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+ from transformers.modeling_rope_utils import rope_config_validation
21
+
22
+
23
+ DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
24
+
25
+
26
+ class DeepseekV3Config(PretrainedConfig):
27
+ r"""
28
+ This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek
29
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
30
+ defaults will yield a similar configuration to that of the DeepSeek-V3.
31
+ e.g. [bzantium/tiny-deepseek-v3](https://huggingface.co/bzantium/tiny-deepseek-v3)
32
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
33
+ documentation from [`PretrainedConfig`] for more information.
34
+
35
+
36
+ Args:
37
+ vocab_size (`int`, *optional*, defaults to 129280):
38
+ Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the
39
+ `inputs_ids` passed when calling [`DeepseekV3Model`]
40
+ hidden_size (`int`, *optional*, defaults to 7168):
41
+ Dimension of the hidden representations.
42
+ intermediate_size (`int`, *optional*, defaults to 18432):
43
+ Dimension of the MLP representations.
44
+ moe_intermediate_size (`int`, *optional*, defaults to 2048):
45
+ Dimension of the MoE representations.
46
+ num_hidden_layers (`int`, *optional*, defaults to 61):
47
+ Number of hidden layers in the Transformer decoder.
48
+ num_attention_heads (`int`, *optional*, defaults to 128):
49
+ Number of attention heads for each attention layer in the Transformer decoder.
50
+ num_key_value_heads (`int`, *optional*, defaults to 128):
51
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
52
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
53
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
54
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
55
+ by meanpooling all the original heads within that group. For more details checkout [this
56
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
57
+ `num_attention_heads`.
58
+ n_shared_experts (`int`, *optional*, defaults to 1):
59
+ Number of shared experts.
60
+ n_routed_experts (`int`, *optional*, defaults to 256):
61
+ Number of routed experts.
62
+ routed_scaling_factor (`float`, *optional*, defaults to 2.5):
63
+ Scaling factor or routed experts.
64
+ kv_lora_rank (`int`, *optional*, defaults to 512):
65
+ Rank of the LoRA matrices for key and value projections.
66
+ q_lora_rank (`int`, *optional*, defaults to 1536):
67
+ Rank of the LoRA matrices for query projections.
68
+ qk_rope_head_dim (`int`, *optional*, defaults to 64):
69
+ Dimension of the query/key heads that use rotary position embeddings.
70
+ v_head_dim (`int`, *optional*, defaults to 128):
71
+ Dimension of the value heads.
72
+ qk_nope_head_dim (`int`, *optional*, defaults to 128):
73
+ Dimension of the query/key heads that don't use rotary position embeddings.
74
+ n_group (`int`, *optional*, defaults to 8):
75
+ Number of groups for routed experts.
76
+ topk_group (`int`, *optional*, defaults to 4):
77
+ Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups).
78
+ num_experts_per_tok (`int`, *optional*, defaults to 8):
79
+ Number of selected experts, None means dense model.
80
+ first_k_dense_replace (`int`, *optional*, defaults to 3):
81
+ Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
82
+ \--k dense layers--/
83
+ norm_topk_prob (`bool`, *optional*, defaults to `True`):
84
+ Whether to normalize the weights of the routed experts.
85
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
86
+ The non-linear activation function (function or string) in the decoder.
87
+ max_position_embeddings (`int`, *optional*, defaults to 4096):
88
+ The maximum sequence length that this model might ever be used with.
89
+ initializer_range (`float`, *optional*, defaults to 0.02):
90
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
91
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
92
+ The epsilon used by the rms normalization layers.
93
+ use_cache (`bool`, *optional*, defaults to `True`):
94
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
95
+ relevant if `config.is_decoder=True`.
96
+ pad_token_id (`int`, *optional*):
97
+ Padding token id.
98
+ bos_token_id (`int`, *optional*, defaults to 0):
99
+ Beginning of stream token id.
100
+ eos_token_id (`int`, *optional*, defaults to 1):
101
+ End of stream token id.
102
+ pretraining_tp (`int`, *optional*, defaults to 1):
103
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
104
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
105
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
106
+ issue](https://github.com/pytorch/pytorch/issues/76232).
107
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
108
+ Whether to tie the input and output word embeddings.
109
+ rope_theta (`float`, *optional*, defaults to 10000.0):
110
+ The base period of the RoPE embeddings.
111
+ rope_scaling (`Dict`, *optional*):
112
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
113
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
114
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
115
+ `max_position_embeddings` to the expected new maximum.
116
+ rope_interleave (`bool`, *optional*, defaults to `True`):
117
+ Whether to interleave the rotary position embeddings.
118
+ attention_bias (`bool`, *optional*, defaults to `False`):
119
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
120
+ attention_dropout (`float`, *optional*, defaults to 0.0):
121
+ The dropout ratio for the attention probabilities.
122
+
123
+ ```python
124
+ >>> from transformers import DeepseekV3Model, DeepseekV3Config
125
+
126
+ >>> # Initializing a Deepseek-V3 style configuration
127
+ >>> configuration = DeepseekV3Config()
128
+
129
+ >>> # Initializing a model from the configuration
+ >>> model = DeepseekV3Model(configuration)
130
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
131
+ ```"""
132
+
133
+ model_type = "deepseek_v3"
134
+ keys_to_ignore_at_inference = ["past_key_values"]
135
+ base_model_tp_plan = { # TODO: only replicate attention layers when > first_k_dense_replace
136
+ "layers.*.mlp.experts.*.gate_proj": "local_colwise",
137
+ "layers.*.mlp.experts.*.up_proj": "local_colwise",
138
+ "layers.*.mlp.experts.*.down_proj": "local_rowwise",
139
+ "layers.*.mlp.experts.*": "local", # each expert is wrapped in a module list
140
+ "layers.*.mlp.shared_experts.gate_proj": "local_colwise",
141
+ "layers.*.mlp.shared_experts.up_proj": "local_colwise",
142
+ "layers.*.mlp.shared_experts.down_proj": "local_rowwise",
143
+ "layers.*.mlp.shared_experts": "local",
144
+ "layers.*.mlp.gate_proj": "local_colwise",
145
+ "layers.*.mlp.up_proj": "local_colwise",
146
+ "layers.*.mlp.down_proj": "local_rowwise",
147
+ "layers.*.mlp": "gather", # This is the only moment where results are gathered
148
+ }
149
+ base_model_pp_plan = {
150
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
151
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
152
+ "norm": (["hidden_states"], ["hidden_states"]),
153
+ }
154
+
155
+ def __init__(
156
+ self,
157
+ vocab_size=129280,
158
+ hidden_size=7168,
159
+ intermediate_size=18432,
160
+ moe_intermediate_size=2048,
161
+ num_hidden_layers=61,
162
+ num_attention_heads=128,
163
+ num_key_value_heads=128,
164
+ n_shared_experts=1,
165
+ n_routed_experts=256,
166
+ routed_scaling_factor=2.5,
167
+ kv_lora_rank=512,
168
+ q_lora_rank=1536,
169
+ qk_rope_head_dim=64,
170
+ v_head_dim=128,
171
+ qk_nope_head_dim=128,
172
+ n_group=8,
173
+ topk_group=4,
174
+ num_experts_per_tok=8,
175
+ first_k_dense_replace=3,
176
+ norm_topk_prob=True,
177
+ hidden_act="silu",
178
+ max_position_embeddings=4096,
179
+ initializer_range=0.02,
180
+ rms_norm_eps=1e-6,
181
+ use_cache=True,
182
+ pad_token_id=None,
183
+ bos_token_id=0,
184
+ eos_token_id=1,
185
+ pretraining_tp=1,
186
+ tie_word_embeddings=False,
187
+ rope_theta=10000.0,
188
+ rope_scaling=None,
189
+ rope_interleave=True,
190
+ attention_bias=False,
191
+ attention_dropout=0.0,
192
+ **kwargs,
193
+ ):
194
+ self.vocab_size = vocab_size
195
+ self.max_position_embeddings = max_position_embeddings
196
+ self.hidden_size = hidden_size
197
+ self.intermediate_size = intermediate_size
198
+ self.moe_intermediate_size = moe_intermediate_size
199
+ self.num_hidden_layers = num_hidden_layers
200
+ self.num_attention_heads = num_attention_heads
201
+ self.n_shared_experts = n_shared_experts
202
+ self.n_routed_experts = n_routed_experts
203
+ self.routed_scaling_factor = routed_scaling_factor
204
+ self.kv_lora_rank = kv_lora_rank
205
+ self.q_lora_rank = q_lora_rank
206
+ self.qk_rope_head_dim = qk_rope_head_dim
207
+ self.v_head_dim = v_head_dim
208
+ self.qk_nope_head_dim = qk_nope_head_dim
209
+ self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
210
+ self.head_dim = qk_rope_head_dim
211
+ self.n_group = n_group
212
+ self.topk_group = topk_group
213
+ self.num_experts_per_tok = num_experts_per_tok
214
+ self.first_k_dense_replace = first_k_dense_replace
215
+ self.norm_topk_prob = norm_topk_prob
216
+ self.rope_interleave = rope_interleave
217
+
218
+ # for backward compatibility
219
+ if num_key_value_heads is None:
220
+ num_key_value_heads = num_attention_heads
221
+
222
+ self.num_key_value_heads = num_key_value_heads
223
+ self.hidden_act = hidden_act
224
+ self.initializer_range = initializer_range
225
+ self.rms_norm_eps = rms_norm_eps
226
+ self.pretraining_tp = pretraining_tp
227
+ self.use_cache = use_cache
228
+ self.rope_theta = rope_theta
229
+ self.rope_scaling = rope_scaling
230
+ self.attention_bias = attention_bias
231
+ self.attention_dropout = attention_dropout
232
+ # Validate the correctness of rotary position embeddings parameters
233
+ # BC: if there is a 'type' field, copy it to 'rope_type'.
234
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
235
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
236
+ rope_config_validation(self)
237
+
238
+ super().__init__(
239
+ pad_token_id=pad_token_id,
240
+ bos_token_id=bos_token_id,
241
+ eos_token_id=eos_token_id,
242
+ tie_word_embeddings=tie_word_embeddings,
243
+ **kwargs,
244
+ )
245
+
246
+
247
+ __all__ = ["DeepseekV3Config"]
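
For orientation, the interplay between `first_k_dense_replace` and `num_hidden_layers` determines which decoder layers use a dense MLP and which use MoE, and `qk_head_dim` is derived from the two query/key head dimensions. A minimal pure-Python sketch using the defaults above (no `transformers` import needed; this mirrors the `layer_idx >= first_k_dense_replace` check in the model):

```python
# Sketch: derive the dense/MoE layer layout from the config defaults above.
num_hidden_layers = 61
first_k_dense_replace = 3
qk_nope_head_dim, qk_rope_head_dim = 128, 64

# Layers with index < first_k_dense_replace use a dense MLP; the rest use MoE.
layer_kinds = [
    "dense" if i < first_k_dense_replace else "moe"
    for i in range(num_hidden_layers)
]

# As computed in __init__: qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim

print(layer_kinds[:4])  # ['dense', 'dense', 'dense', 'moe']
print(qk_head_dim)      # 192
```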
docs/deploy_guidance.md ADDED
@@ -0,0 +1,42 @@
1
+ # Model Deployment Guide
2
+
3
+ > [!NOTE]
4
+ > This guide offers a selection of deployment command examples for JoyAI-LLM Flash; they may not be the optimal configuration. Given the rapid evolution of inference engines, we recommend consulting their official documentation for the latest updates to ensure peak performance.
5
+
6
+ > Support for JoyAI-LLM Flash’s dense MTP architecture is currently being integrated into vLLM and SGLang. Until these PRs are merged into a stable release, please use the nightly Docker image for access to these features.
7
+
8
+ ## vLLM Deployment
9
+
10
+ Here is an example of serving this model on a single GPU via vLLM:
11
+
12
+ 1. Pull the Docker image:
13
+ ```bash
14
+ docker pull jdopensource/joyai-llm-vllm:v0.15.1-joyai_llm_flash
15
+ ```
16
+ 2. Launch the JoyAI-LLM Flash model with dense MTP (also quantized to FP8):
17
+ ```bash
18
+ vllm serve jdopensource/JoyAI-LLM-Flash-FP8 -tp 1 --trust-remote-code \
19
+ --tool-call-parser qwen3_coder --enable-auto-tool-choice \
20
+ --speculative-config $'{"method": "mtp", "num_speculative_tokens": 3}'
21
+ ```
22
+ **Key notes**
23
+ - `--tool-call-parser qwen3_coder`: Required to enable tool calling.
24
+
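+ Once the server above is running, it can be queried through vLLM's OpenAI-compatible API. A minimal sketch of the request payload (the port 8000 and the endpoint path are assumptions based on vLLM's defaults, not part of this guide):

```python
import json

# Hypothetical request body for the vLLM OpenAI-compatible endpoint started above,
# e.g. POST http://localhost:8000/v1/chat/completions (default port assumed).
payload = {
    "model": "jdopensource/JoyAI-LLM-Flash-FP8",
    "messages": [{"role": "user", "content": "Write a haiku about GPUs."}],
    "max_tokens": 64,
}
body = json.dumps(payload)
print(body[:40])
```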
25
+ ## SGLang Deployment
26
+
27
+ Similarly, here is an example of serving on a single GPU via SGLang:
28
+
29
+ 1. Pull the Docker image:
30
+ ```bash
31
+ docker pull jdopensource/joyai-llm-sglang:v0.5.8-joyai_llm_flash
32
+ ```
33
+ 2. Launch the JoyAI-LLM Flash model with dense MTP (also quantized to FP8):
34
+
35
+ ```bash
36
+ python3 -m sglang.launch_server --model-path jdopensource/JoyAI-LLM-Flash-FP8 --tp-size 1 --trust-remote-code \
37
+ --tool-call-parser qwen3_coder \
38
+ --speculative-algorithm EAGLE \
39
+ --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4
40
+ ```
41
+ **Key notes:**
42
+ - `--tool-call-parser qwen3_coder`: Required to enable tool calling.
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_deepseek.py ADDED
@@ -0,0 +1,1030 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/deepseek_v3/modular_deepseek_v3.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_deepseek_v3.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ import math
8
+ from functools import partial
9
+ from typing import Callable, Optional, Tuple, Union
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch import nn
14
+ from transformers.activations import ACT2FN
15
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
16
+ from transformers.generation import GenerationMixin
17
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
18
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
19
+ from transformers.modeling_outputs import (
20
+ BaseModelOutputWithPast,
21
+ CausalLMOutputWithPast,
22
+ )
23
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
24
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
25
+ from transformers.processing_utils import Unpack
26
+ from transformers.utils import (
27
+ LossKwargs,
28
+ add_start_docstrings,
29
+ add_start_docstrings_to_model_forward,
30
+ can_return_tuple,
31
+ is_torch_flex_attn_available,
32
+ logging,
33
+ replace_return_docstrings,
34
+ )
35
+ from transformers.utils.deprecation import deprecate_kwarg
36
+
37
+ from .configuration_deepseek import DeepseekV3Config
38
+
39
+ if is_torch_flex_attn_available():
40
+ from torch.nn.attention.flex_attention import BlockMask
41
+ from transformers.integrations.flex_attention import make_flex_block_causal_mask
42
+
43
+
44
+ logger = logging.get_logger(__name__)
45
+ _CONFIG_FOR_DOC = "DeepseekV3Config"
46
+
47
+
48
+ class DeepseekV3RMSNorm(nn.Module):
49
+ def __init__(self, hidden_size, eps=1e-6):
50
+ """
51
+ DeepseekV3RMSNorm is equivalent to T5LayerNorm
52
+ """
53
+ super().__init__()
54
+ self.weight = nn.Parameter(torch.ones(hidden_size))
55
+ self.variance_epsilon = eps
56
+
57
+ def forward(self, hidden_states):
58
+ input_dtype = hidden_states.dtype
59
+ hidden_states = hidden_states.to(torch.float32)
60
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
61
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
62
+ return self.weight * hidden_states.to(input_dtype)
63
+
64
+ def extra_repr(self):
65
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
66
+
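+ The normalization above can be checked with a small pure-Python replica of the same formula (no torch needed; the learnable weight is taken as all-ones, matching its initialization):

```python
import math

def rms_norm(xs, eps=1e-6):
    # x / sqrt(mean(x^2) + eps), with the weight initialized to ones as in the module
    variance = sum(x * x for x in xs) / len(xs)
    inv = 1.0 / math.sqrt(variance + eps)
    return [x * inv for x in xs]

out = rms_norm([3.0, 4.0])  # mean of squares = 12.5
print(out)                  # ~[0.8485, 1.1314]; the output has RMS ~1
```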
67
+
68
+ class DeepseekV3RotaryEmbedding(nn.Module):
69
+ def __init__(self, config: DeepseekV3Config, device=None):
70
+ super().__init__()
71
+ # BC: "rope_type" was originally "type"
72
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
73
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
74
+ else:
75
+ self.rope_type = "default"
76
+ self.max_seq_len_cached = config.max_position_embeddings
77
+ self.original_max_seq_len = config.max_position_embeddings
78
+
79
+ self.config = config
80
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
81
+
82
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
83
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
84
+ self.original_inv_freq = self.inv_freq
85
+
86
+ @torch.no_grad()
87
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
88
+ def forward(self, x, position_ids):
89
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
90
+ position_ids_expanded = position_ids[:, None, :].float()
91
+
92
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
93
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
94
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
95
+ emb = torch.cat((freqs, freqs), dim=-1)
96
+ cos = emb.cos() * self.attention_scaling
97
+ sin = emb.sin() * self.attention_scaling
98
+
99
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
100
+
101
+
102
+ class DeepseekV3MLP(nn.Module):
103
+ def __init__(self, config, hidden_size=None, intermediate_size=None):
104
+ super().__init__()
105
+ self.config = config
106
+ self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
107
+ self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
108
+
109
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
110
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
111
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
112
+ self.act_fn = ACT2FN[config.hidden_act]
113
+
114
+ def forward(self, x):
115
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
116
+ return down_proj
117
+
118
+
119
+ class DeepseekV3TopkRouter(nn.Module):
120
+ def __init__(self, config):
121
+ super().__init__()
122
+ self.config = config
123
+ self.top_k = config.num_experts_per_tok
124
+ self.n_routed_experts = config.n_routed_experts
125
+ self.routed_scaling_factor = config.routed_scaling_factor
126
+ self.n_group = config.n_group
127
+ self.topk_group = config.topk_group
128
+ self.norm_topk_prob = config.norm_topk_prob
129
+
130
+ self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
131
+ self.register_buffer("e_score_correction_bias", torch.zeros((self.n_routed_experts)))
132
+
133
+ @torch.no_grad()
134
+ def get_topk_indices(self, scores):
135
+ scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
136
+ group_scores = (
137
+ scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
138
+ .topk(2, dim=-1)[0]
139
+ .sum(dim=-1)
140
+ )
141
+ group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
142
+ group_mask = torch.zeros_like(group_scores)
143
+ group_mask.scatter_(1, group_idx, 1)
144
+ score_mask = (
145
+ group_mask.unsqueeze(-1)
146
+ .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
147
+ .reshape(-1, self.n_routed_experts)
148
+ )
149
+ scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
150
+ topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
151
+ return topk_indices
152
+
153
+ def forward(self, hidden_states):
154
+ hidden_states = hidden_states.view(-1, self.config.hidden_size)
155
+ router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
156
+ scores = router_logits.sigmoid()
157
+ topk_indices = self.get_topk_indices(scores)
158
+ topk_weights = scores.gather(1, topk_indices)
159
+ if self.norm_topk_prob:
160
+ denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
161
+ topk_weights /= denominator
162
+ topk_weights = topk_weights * self.routed_scaling_factor
163
+ return topk_indices, topk_weights
164
+
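+ The two-stage selection in `get_topk_indices` (first pick `topk_group` groups by the sum of their top-2 expert scores, then pick `top_k` experts restricted to those groups) can be sketched in pure Python for a single token (toy sizes, not the real 256-expert config):

```python
# 8 experts in n_group=4 groups of 2; keep topk_group=2 groups, top_k=2 experts.
scores = [0.1, 0.9, 0.2, 0.3, 0.8, 0.7, 0.05, 0.15]
n_group, topk_group, top_k = 4, 2, 2
group_size = len(scores) // n_group

# Stage 1: score each group by the sum of its top-2 experts.
group_scores = [
    sum(sorted(scores[g * group_size:(g + 1) * group_size], reverse=True)[:2])
    for g in range(n_group)
]
kept_groups = sorted(range(n_group), key=lambda g: group_scores[g], reverse=True)[:topk_group]

# Stage 2: top-k experts restricted to the kept groups.
allowed = [i for g in kept_groups for i in range(g * group_size, (g + 1) * group_size)]
topk = sorted(allowed, key=lambda i: scores[i], reverse=True)[:top_k]
print(sorted(topk))  # [1, 4]: the best experts inside the two best groups
```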
165
+
166
+ class DeepseekV3MoE(nn.Module):
167
+ """
168
+ A mixed expert module containing shared experts.
169
+ """
170
+
171
+ def __init__(self, config):
172
+ super().__init__()
173
+ self.config = config
174
+ self.experts = nn.ModuleList(
175
+ [
176
+ DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
177
+ for _ in range(config.n_routed_experts)
178
+ ]
179
+ )
180
+ self.gate = DeepseekV3TopkRouter(config)
181
+ self.shared_experts = DeepseekV3MLP(
182
+ config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
183
+ )
184
+
185
+ def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
186
+ r"""
187
+ CALL FOR CONTRIBUTION! This loop is not optimized: the expert weights should be fused
188
+ so that we don't have to iterate over every expert here (DeepSeek has 256 routed experts).
189
+ """
190
+ final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
191
+ expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
192
+ expert_mask = expert_mask.permute(2, 0, 1)
193
+
194
+ for expert_idx in range(len(self.experts)):
195
+ expert = self.experts[expert_idx]
196
+ mask = expert_mask[expert_idx]
197
+ token_indices, weight_indices = torch.where(mask)
198
+
199
+ if token_indices.numel() > 0:
200
+ expert_weights = topk_weights[token_indices, weight_indices]
201
+ expert_input = hidden_states[token_indices]
202
+ expert_output = expert(expert_input)
203
+ weighted_output = expert_output * expert_weights.unsqueeze(-1)
204
+ final_hidden_states.index_add_(0, token_indices, weighted_output)
205
+
206
+ # in the original DeepSeek, the outputs of the experts are gathered once we leave this module;
207
+ # thus the MoE module is itself an IsolatedParallel module,
208
+ # and all experts are "local", meaning we shard but do not gather
209
+ return final_hidden_states.type(hidden_states.dtype)
210
+
211
+ def forward(self, hidden_states):
212
+ residuals = hidden_states
213
+ orig_shape = hidden_states.shape
214
+ topk_indices, topk_weights = self.gate(hidden_states)
215
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
216
+ hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
217
+ hidden_states = hidden_states + self.shared_experts(residuals)
218
+ return hidden_states
219
+
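+ The dispatch loop in `moe` above reduces to a weighted sum of each selected expert's output. A toy pure-Python version for one token, with scalar "experts" that just scale their input (illustrative only; the real experts are `DeepseekV3MLP` modules):

```python
# Toy experts: expert i scales its input by (i + 1).
experts = [lambda x, k=k: (k + 1) * x for k in range(4)]

hidden = 2.0
topk_indices = [1, 3]        # this token is routed to experts 1 and 3
topk_weights = [0.25, 0.75]  # normalized routing weights from the gate

# Weighted combination, as done via index_add_ in the module above.
out = sum(w * experts[i](hidden) for i, w in zip(topk_indices, topk_weights))
print(out)  # 0.25*4.0 + 0.75*8.0 = 7.0
```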
220
+
221
+ def rotate_half(x):
222
+ """Rotates half the hidden dims of the input."""
223
+ x1 = x[..., : x.shape[-1] // 2]
224
+ x2 = x[..., x.shape[-1] // 2 :]
225
+ return torch.cat((-x2, x1), dim=-1)
226
+
227
+
228
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
229
+ """Applies Rotary Position Embedding to the query and key tensors.
230
+
231
+ Args:
232
+ q (`torch.Tensor`): The query tensor.
233
+ k (`torch.Tensor`): The key tensor.
234
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
235
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
236
+ position_ids (`torch.Tensor`, *optional*):
237
+ Deprecated and unused.
238
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
239
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
240
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
241
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
242
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
243
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
244
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
245
+ Returns:
246
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
247
+ """
248
+ cos = cos.unsqueeze(unsqueeze_dim)
249
+ sin = sin.unsqueeze(unsqueeze_dim)
250
+ q_embed = (q * cos) + (rotate_half(q) * sin)
251
+ k_embed = (k * cos) + (rotate_half(k) * sin)
252
+ return q_embed, k_embed
253
+
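+ The rotation above is the standard 2D rotation applied pairwise: with `rotate_half`, a pair (q0, q1) maps to (q0·cosθ − q1·sinθ, q1·cosθ + q0·sinθ). A pure-Python check on a single rotary pair:

```python
import math

def rotate_half(x):
    # Mirrors the tensor version: [-x2, x1] for the two halves of the last dim.
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    return [-v for v in x2] + x1

def apply_rope(q, theta):
    cos, sin = math.cos(theta), math.sin(theta)
    return [v * cos + r * sin for v, r in zip(q, rotate_half(q))]

q = [1.0, 0.0]                   # head_dim = 2, i.e. one rotary pair
out = apply_rope(q, math.pi / 2)
print(out)                       # ~[0.0, 1.0]: the pair rotated by 90 degrees
```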
254
+
255
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
256
+ """
257
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
258
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
259
+ """
260
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
261
+ if n_rep == 1:
262
+ return hidden_states
263
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
264
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
265
+
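+ `repeat_kv` duplicates each KV head `n_rep` times contiguously (A, B becomes A, A, B, B), not interleaved round-robin. A list-based sketch of the same expand+reshape:

```python
# Each KV head is repeated n_rep times in place, matching the expand+reshape above.
kv_heads = ["A", "B"]  # stand-ins for (seqlen, head_dim) tensors per head
n_rep = 3              # num_attention_heads // num_key_value_heads

repeated = [h for h in kv_heads for _ in range(n_rep)]
print(repeated)  # ['A', 'A', 'A', 'B', 'B', 'B']
```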
266
+
267
+ def eager_attention_forward(
268
+ module: nn.Module,
269
+ query: torch.Tensor,
270
+ key: torch.Tensor,
271
+ value: torch.Tensor,
272
+ attention_mask: Optional[torch.Tensor],
273
+ scaling: float,
274
+ dropout: float = 0.0,
275
+ **kwargs,
276
+ ):
277
+ key_states = repeat_kv(key, module.num_key_value_groups)
278
+ value_states = repeat_kv(value, module.num_key_value_groups)
279
+
280
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
281
+ if attention_mask is not None:
282
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
283
+ attn_weights = attn_weights + causal_mask
284
+
285
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
286
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
287
+ attn_output = torch.matmul(attn_weights, value_states)
288
+ attn_output = attn_output.transpose(1, 2).contiguous()
289
+
290
+ return attn_output, attn_weights
291
+
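+ The core of `eager_attention_forward` (scaled scores, additive causal mask, softmax, weighted sum of values) can be replicated for a single head in pure Python (seq_len=2, head_dim=1, so the 1/sqrt(head_dim) scaling is 1):

```python
import math

def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

q = [[1.0], [1.0]]
k = [[1.0], [1.0]]
v = [[10.0], [20.0]]
neg_inf = float("-inf")
causal_mask = [[0.0, neg_inf], [0.0, 0.0]]  # position 0 cannot attend to position 1

attn = []
for i, qi in enumerate(q):
    scores = [qi[0] * kj[0] + causal_mask[i][j] for j, kj in enumerate(k)]
    weights = softmax(scores)
    attn.append(sum(w * vj[0] for w, vj in zip(weights, v)))
print(attn)  # [10.0, 15.0]: masked row sees only v0, full row averages v0 and v1
```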
292
+
293
+ def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
294
+ r"""
295
+ TODO let's just use the original freqcis computation to not have the view
296
+ transpose + reshape! This is not optimized!
297
+ Applies Rotary Position Embedding to the query and key tensors.
298
+
299
+ Args:
300
+ q (`torch.Tensor`): The query tensor.
301
+ k (`torch.Tensor`): The key tensor.
302
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
303
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
304
+ position_ids (`torch.Tensor`):
305
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
306
+ used to pass offsetted position ids when working with a KV-cache.
307
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
308
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
309
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
310
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
311
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
312
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
313
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
314
+ Returns:
315
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
316
+ """
317
+ cos = cos.unsqueeze(unsqueeze_dim)
318
+ sin = sin.unsqueeze(unsqueeze_dim)
319
+
320
+ b, h, s, d = q.shape
321
+ q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
322
+
323
+ b, h, s, d = k.shape
324
+ k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
325
+
326
+ q_embed = (q * cos) + (rotate_half(q) * sin)
327
+ k_embed = (k * cos) + (rotate_half(k) * sin)
328
+ return q_embed, k_embed
329
+
330
+
331
+ def yarn_get_mscale(scale=1, mscale=1):
332
+ if scale <= 1:
333
+ return 1.0
334
+ return 0.1 * mscale * math.log(scale) + 1.0
335
+
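+ `yarn_get_mscale` grows logarithmically with the scaling factor, and when `mscale_all_dim` is set the attention `scaling` above is multiplied by mscale squared. A quick numeric check:

```python
import math

def yarn_get_mscale(scale=1, mscale=1):
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

m = yarn_get_mscale(scale=4, mscale=1)
print(round(m, 4))  # 1.1386 = 0.1 * ln(4) + 1
print(m * m)        # the mscale**2 factor applied to self.scaling above
```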
336
+
337
+ class DeepseekV3Attention(nn.Module):
338
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
339
+
340
+ def __init__(self, config: DeepseekV3Config, layer_idx: int):
341
+ super().__init__()
342
+ self.config = config
343
+ self.layer_idx = layer_idx
344
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
345
+ self.attention_dropout = config.attention_dropout
346
+ self.num_heads = config.num_attention_heads
347
+ self.rope_theta = config.rope_theta
348
+ self.q_lora_rank = config.q_lora_rank
349
+ self.qk_rope_head_dim = config.qk_rope_head_dim
350
+ self.kv_lora_rank = config.kv_lora_rank
351
+ self.v_head_dim = config.v_head_dim
352
+ self.qk_nope_head_dim = config.qk_nope_head_dim
353
+ self.qk_head_dim = config.qk_head_dim
354
+
355
+ self.is_causal = True
356
+ self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
357
+ self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
358
+ self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)
359
+
360
+ self.kv_a_proj_with_mqa = nn.Linear(
361
+ config.hidden_size,
362
+ self.kv_lora_rank + self.qk_rope_head_dim,
363
+ bias=config.attention_bias,
364
+ )
365
+ self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank)
366
+ self.kv_b_proj = nn.Linear(
367
+ self.kv_lora_rank,
368
+ self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
369
+ bias=False,
370
+ )
371
+
372
+ self.o_proj = nn.Linear(
373
+ self.num_heads * self.v_head_dim,
374
+ config.hidden_size,
375
+ bias=config.attention_bias,
376
+ )
377
+
378
+ self.scaling = self.qk_head_dim ** (-0.5)
379
+ if self.config.rope_scaling is not None:
380
+ mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
381
+ scaling_factor = self.config.rope_scaling["factor"]
382
+ if mscale_all_dim:
383
+ mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
384
+ self.scaling = self.scaling * mscale * mscale
385
+
386
+ def forward(
387
+ self,
388
+ hidden_states: torch.Tensor,
389
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
390
+ attention_mask: Optional[torch.Tensor],
391
+ past_key_value: Optional[Cache] = None,
392
+ cache_position: Optional[torch.LongTensor] = None,
393
+ **kwargs: Unpack[FlashAttentionKwargs],
394
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
395
+ batch_size, seq_length = hidden_states.shape[:-1]
396
+ query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
397
+ key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)
398
+
399
+ q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(query_shape).transpose(1, 2)
400
+ q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
401
+
402
+ compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
403
+ k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
404
+
405
+ k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)
406
+ k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
407
+
408
+ k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
409
+
410
+ cos, sin = position_embeddings
411
+ if self.config.rope_interleave: # support using interleaved weights for efficiency
412
+ q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
413
+ else:
414
+ q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
415
+ k_rot = k_rot.expand(*k_pass.shape[:-1], -1)
416
+
417
+ query_states = torch.cat((q_pass, q_rot), dim=-1)
418
+ key_states = torch.cat((k_pass, k_rot), dim=-1)
419
+
420
+ if past_key_value is not None:
421
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
422
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
423
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
424
+
425
+ if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
426
+ value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])
427
+
428
+ attention_interface: Callable = eager_attention_forward
429
+ if self.config._attn_implementation != "eager":
430
+ if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
431
+ logger.warning_once(
432
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
433
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
434
+ )
435
+ else:
436
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
437
+
438
+ attn_output, attn_weights = attention_interface(
439
+ self,
440
+ query_states,
441
+ key_states,
442
+ value_states,
443
+ attention_mask,
444
+ dropout=0.0 if not self.training else self.attention_dropout,
445
+ scaling=self.scaling,
446
+ **kwargs,
447
+ )
448
+
449
+ if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
450
+ attn_output = attn_output[:, :, :, : self.v_head_dim]
451
+
452
+ attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
453
+ attn_output = self.o_proj(attn_output)
454
+ return attn_output, attn_weights
455
+
456
+
457
+ class DeepseekV3DecoderLayer(nn.Module):
458
+ def __init__(self, config: DeepseekV3Config, layer_idx: int):
459
+ super().__init__()
460
+ self.hidden_size = config.hidden_size
461
+
462
+ self.self_attn = DeepseekV3Attention(config=config, layer_idx=layer_idx)
463
+
464
+ if layer_idx >= config.first_k_dense_replace:
465
+ self.mlp = DeepseekV3MoE(config)
466
+ else:
467
+ self.mlp = DeepseekV3MLP(config)
468
+
469
+ self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
470
+ self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
471
+
472
+ def forward(
473
+ self,
474
+ hidden_states: torch.Tensor,
475
+ attention_mask: Optional[torch.Tensor] = None,
476
+ position_ids: Optional[torch.LongTensor] = None,
477
+ past_key_value: Optional[Cache] = None,
478
+ output_attentions: Optional[bool] = False,
479
+ use_cache: Optional[bool] = False,
480
+ cache_position: Optional[torch.LongTensor] = None,
481
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
482
+ **kwargs: Unpack[FlashAttentionKwargs],
483
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
484
+ residual = hidden_states
485
+
486
+ hidden_states = self.input_layernorm(hidden_states)
487
+
488
+ # Self Attention
489
+ hidden_states, self_attn_weights = self.self_attn(
490
+ hidden_states=hidden_states,
491
+ attention_mask=attention_mask,
492
+ position_ids=position_ids,
493
+ past_key_value=past_key_value,
494
+ output_attentions=output_attentions,
495
+ use_cache=use_cache,
496
+ cache_position=cache_position,
497
+ position_embeddings=position_embeddings,
498
+ **kwargs,
499
+ )
500
+ hidden_states = residual + hidden_states
501
+
502
+ # Fully Connected
503
+ residual = hidden_states
504
+ hidden_states = self.post_attention_layernorm(hidden_states)
505
+ hidden_states = self.mlp(hidden_states)
506
+ hidden_states = residual + hidden_states
507
+
508
+ outputs = (hidden_states,)
509
+ if output_attentions:
510
+ outputs += (self_attn_weights,)
511
+
512
+ return outputs
513
+
514
+
515
+ DEEPSEEK_V3_START_DOCSTRING = r"""
516
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
517
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
518
+ etc.)
519
+
520
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
521
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
522
+ and behavior.
523
+
524
+ Parameters:
525
+ config ([`DeepseekV3Config`]):
526
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
527
+ load the weights associated with the model, only the configuration. Check out the
528
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
529
+ """
530
+
531
+
532
+ @add_start_docstrings(
533
+ "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
534
+ DEEPSEEK_V3_START_DOCSTRING,
535
+ )
536
+ class DeepseekV3PreTrainedModel(PreTrainedModel):
537
+ config_class = DeepseekV3Config
538
+ base_model_prefix = "model"
539
+ supports_gradient_checkpointing = True
540
+ _no_split_modules = ["DeepseekV3DecoderLayer"]
541
+ _skip_keys_device_placement = ["past_key_values"]
542
+ _supports_flash_attn_2 = True
543
+ _supports_sdpa = True
544
+ _supports_flex_attn = True
545
+ _supports_cache_class = True
546
+ _supports_quantized_cache = True
547
+ _supports_static_cache = True
548
+ _supports_attention_backend = True
549
+
550
+ def _init_weights(self, module):
551
+ std = self.config.initializer_range
552
+ if isinstance(module, nn.Linear):
553
+ module.weight.data.normal_(mean=0.0, std=std)
554
+ if module.bias is not None:
555
+ module.bias.data.zero_()
556
+ elif isinstance(module, nn.Embedding):
557
+ module.weight.data.normal_(mean=0.0, std=std)
558
+ if module.padding_idx is not None:
559
+ module.weight.data[module.padding_idx].zero_()
560
+ elif isinstance(module, DeepseekV3TopkRouter):
561
+ module.weight.data.normal_(mean=0.0, std=std)
562
+ elif isinstance(module, nn.Parameter):
563
+ module.data.normal_(mean=0.0, std=std)
564
+
565
+
566
+ DEEPSEEK_V3_INPUTS_DOCSTRING = r"""
567
+ Args:
568
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
569
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
570
+ it.
571
+
572
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
573
+ [`PreTrainedTokenizer.__call__`] for details.
574
+
575
+ [What are input IDs?](../glossary#input-ids)
576
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
577
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
578
+
579
+ - 1 for tokens that are **not masked**,
580
+ - 0 for tokens that are **masked**.
581
+
582
+ [What are attention masks?](../glossary#attention-mask)
583
+
584
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
585
+ [`PreTrainedTokenizer.__call__`] for details.
586
+
587
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
588
+ `past_key_values`).
589
+
590
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
591
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
592
+ information on the default strategy.
593
+
594
+ - 1 indicates the head is **not masked**,
595
+ - 0 indicates the head is **masked**.
596
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
597
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
598
+ config.n_positions - 1]`.
599
+
600
+ [What are position IDs?](../glossary#position-ids)
601
+ past_key_values (`Cache`, *optional*):
602
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
603
+ blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
604
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
605
+
606
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
607
+
608
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
609
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
610
+ of shape `(batch_size, sequence_length)`.
611
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
612
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
613
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
614
+ model's internal embedding lookup matrix.
615
+ use_cache (`bool`, *optional*):
616
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
617
+ `past_key_values`).
618
+ output_attentions (`bool`, *optional*):
619
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
620
+ tensors for more detail.
621
+ output_hidden_states (`bool`, *optional*):
622
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
623
+ more detail.
624
+ return_dict (`bool`, *optional*):
625
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
626
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
627
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
628
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
629
+ the complete sequence length.
630
+ """
631
+
632
+
633
+ @add_start_docstrings(
634
+ "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
635
+ DEEPSEEK_V3_START_DOCSTRING,
636
+ )
637
+ class DeepseekV3Model(DeepseekV3PreTrainedModel):
638
+ """
639
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV3DecoderLayer`]
640
+
641
+ Args:
642
+ config: DeepseekV3Config
643
+ """
644
+
645
+ _keys_to_ignore_on_load_unexpected = [r"model\.layers\.61.*"]
646
+
647
+ def __init__(self, config: DeepseekV3Config):
648
+ super().__init__(config)
649
+ self.padding_idx = config.pad_token_id
650
+ self.vocab_size = config.vocab_size
651
+
652
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
653
+ self.layers = nn.ModuleList(
654
+ [DeepseekV3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
655
+ )
656
+ self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
657
+ self.rotary_emb = DeepseekV3RotaryEmbedding(config=config)
658
+ self.gradient_checkpointing = False
659
+
660
+ # Initialize weights and apply final processing
661
+ self.post_init()
662
+
663
+ def get_input_embeddings(self):
664
+ return self.embed_tokens
665
+
666
+ def set_input_embeddings(self, value):
667
+ self.embed_tokens = value
668
+
669
+ @can_return_tuple
670
+ @add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING)
671
+ def forward(
672
+ self,
673
+ input_ids: Optional[torch.LongTensor] = None,
674
+ attention_mask: Optional[torch.Tensor] = None,
675
+ position_ids: Optional[torch.LongTensor] = None,
676
+ past_key_values: Optional[Cache] = None,
677
+ inputs_embeds: Optional[torch.FloatTensor] = None,
678
+ use_cache: Optional[bool] = None,
679
+ output_attentions: Optional[bool] = None,
680
+ output_hidden_states: Optional[bool] = None,
681
+ cache_position: Optional[torch.LongTensor] = None,
682
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
683
+ ) -> BaseModelOutputWithPast:
684
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
685
+ output_hidden_states = (
686
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
687
+ )
688
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
689
+
690
+ if (input_ids is None) ^ (inputs_embeds is not None):
691
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
692
+
693
+ if self.gradient_checkpointing and self.training and use_cache:
694
+ logger.warning_once(
695
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
696
+ )
697
+ use_cache = False
698
+
699
+ # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
700
+ if not isinstance(past_key_values, (type(None), Cache)):
701
+ raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
702
+
703
+ if inputs_embeds is None:
704
+ inputs_embeds = self.embed_tokens(input_ids)
705
+
706
+ if use_cache and past_key_values is None:
707
+ past_key_values = DynamicCache()
708
+
709
+ if cache_position is None:
710
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
711
+ cache_position = torch.arange(
712
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
713
+ )
714
+
715
+ if position_ids is None:
716
+ position_ids = cache_position.unsqueeze(0)
717
+
718
+ causal_mask = self._update_causal_mask(
719
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
720
+ )
721
+
722
+ hidden_states = inputs_embeds
723
+
724
+ # create position embeddings to be shared across the decoder layers
725
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
726
+
727
+ # decoder layers
728
+ all_hidden_states = () if output_hidden_states else None
729
+ all_self_attns = () if output_attentions else None
730
+
731
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
732
+ if output_hidden_states:
733
+ all_hidden_states += (hidden_states,)
734
+
735
+ if self.gradient_checkpointing and self.training:
736
+ layer_outputs = self._gradient_checkpointing_func(
737
+ partial(decoder_layer.__call__, **flash_attn_kwargs),
738
+ hidden_states,
739
+ causal_mask,
740
+ position_ids,
741
+ past_key_values,
742
+ output_attentions,
743
+ use_cache,
744
+ cache_position,
745
+ position_embeddings,
746
+ )
747
+ else:
748
+ layer_outputs = decoder_layer(
749
+ hidden_states,
750
+ attention_mask=causal_mask,
751
+ position_ids=position_ids,
752
+ past_key_value=past_key_values,
753
+ output_attentions=output_attentions,
754
+ use_cache=use_cache,
755
+ cache_position=cache_position,
756
+ position_embeddings=position_embeddings,
757
+ **flash_attn_kwargs,
758
+ )
759
+
760
+ hidden_states = layer_outputs[0]
761
+
762
+ if output_attentions:
763
+ all_self_attns += (layer_outputs[1],)
764
+
765
+ hidden_states = self.norm(hidden_states)
766
+
767
+ # add hidden states from the last decoder layer
768
+ if output_hidden_states:
769
+ all_hidden_states += (hidden_states,)
770
+
771
+ return BaseModelOutputWithPast(
772
+ last_hidden_state=hidden_states,
773
+ past_key_values=past_key_values if use_cache else None,
774
+ hidden_states=all_hidden_states,
775
+ attentions=all_self_attns,
776
+ )
777
+
778
+ def _update_causal_mask(
779
+ self,
780
+ attention_mask: torch.Tensor,
781
+ input_tensor: torch.Tensor,
782
+ cache_position: torch.Tensor,
783
+ past_key_values: Cache,
784
+ output_attentions: bool = False,
785
+ ):
786
+ if self.config._attn_implementation == "flash_attention_2":
787
+ if attention_mask is not None and (attention_mask == 0.0).any():
788
+ return attention_mask
789
+ return None
790
+ if self.config._attn_implementation == "flex_attention":
791
+ if isinstance(attention_mask, torch.Tensor):
792
+ attention_mask = make_flex_block_causal_mask(attention_mask)
793
+ if isinstance(attention_mask, BlockMask):
794
+ return attention_mask
795
+
796
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
797
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
798
+ # to infer the attention mask.
799
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
800
+ using_static_cache = isinstance(past_key_values, StaticCache)
801
+
802
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
803
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
804
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
805
+ attention_mask,
806
+ inputs_embeds=input_tensor,
807
+ past_key_values_length=past_seen_tokens,
808
+ is_training=self.training,
809
+ ):
810
+ return None
811
+
812
+ dtype, device = input_tensor.dtype, input_tensor.device
813
+ sequence_length = input_tensor.shape[1]
814
+ if using_static_cache:
815
+ target_length = past_key_values.get_max_cache_shape()
816
+ else:
817
+ target_length = (
818
+ attention_mask.shape[-1]
819
+ if isinstance(attention_mask, torch.Tensor)
820
+ else past_seen_tokens + sequence_length + 1
821
+ )
822
+
823
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
824
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
825
+ attention_mask,
826
+ sequence_length=sequence_length,
827
+ target_length=target_length,
828
+ dtype=dtype,
829
+ device=device,
830
+ cache_position=cache_position,
831
+ batch_size=input_tensor.shape[0],
832
+ )
833
+
834
+ if (
835
+ self.config._attn_implementation == "sdpa"
836
+ and attention_mask is not None
837
+ and attention_mask.device.type in ["cuda", "xpu"]
838
+ and not output_attentions
839
+ ):
840
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
841
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
842
+ # Details: https://github.com/pytorch/pytorch/issues/110213
843
+ min_dtype = torch.finfo(dtype).min
844
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
845
+
846
+ return causal_mask
847
+
848
+ @staticmethod
849
+ def _prepare_4d_causal_attention_mask_with_cache_position(
850
+ attention_mask: torch.Tensor,
851
+ sequence_length: int,
852
+ target_length: int,
853
+ dtype: torch.dtype,
854
+ device: torch.device,
855
+ cache_position: torch.Tensor,
856
+ batch_size: int,
857
+ **kwargs,
858
+ ):
859
+ """
860
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
861
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
862
+
863
+ Args:
864
+ attention_mask (`torch.Tensor`):
865
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
866
+ `(batch_size, 1, query_length, key_value_length)`.
867
+ sequence_length (`int`):
868
+ The sequence length being processed.
869
+ target_length (`int`):
870
+ The target length: when generating with static cache, the mask should be as long as the static cache,
871
+ to account for the 0 padding, the part of the cache that is not filled yet.
872
+ dtype (`torch.dtype`):
873
+ The dtype to use for the 4D attention mask.
874
+ device (`torch.device`):
875
+ The device to place the 4D attention mask on.
876
+ cache_position (`torch.Tensor`):
877
+ Indices depicting the position of the input sequence tokens in the sequence.
878
+ batch_size (`int`):
879
+ Batch size.
880
+ """
881
+ if attention_mask is not None and attention_mask.dim() == 4:
882
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
883
+ causal_mask = attention_mask
884
+ else:
885
+ min_dtype = torch.finfo(dtype).min
886
+ causal_mask = torch.full(
887
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
888
+ )
889
+ if sequence_length != 1:
890
+ causal_mask = torch.triu(causal_mask, diagonal=1)
891
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
892
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
893
+ if attention_mask is not None:
894
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
895
+ mask_length = attention_mask.shape[-1]
896
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
897
+ causal_mask.device
898
+ )
899
+ padding_mask = padding_mask == 0
900
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
901
+ padding_mask, min_dtype
902
+ )
903
+
904
+ return causal_mask
905
+
906
+
907
+ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
908
+
909
+
910
+ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):
911
+ _tied_weights_keys = ["lm_head.weight"]
912
+ _tp_plan = {"lm_head": "colwise_rep"}
913
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
914
+
915
+ def __init__(self, config):
916
+ super().__init__(config)
917
+ self.model = DeepseekV3Model(config)
918
+ self.vocab_size = config.vocab_size
919
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
920
+
921
+ # Initialize weights and apply final processing
922
+ self.post_init()
923
+
924
+ def get_input_embeddings(self):
925
+ return self.model.embed_tokens
926
+
927
+ def set_input_embeddings(self, value):
928
+ self.model.embed_tokens = value
929
+
930
+ def get_output_embeddings(self):
931
+ return self.lm_head
932
+
933
+ def set_output_embeddings(self, new_embeddings):
934
+ self.lm_head = new_embeddings
935
+
936
+ def set_decoder(self, decoder):
937
+ self.model = decoder
938
+
939
+ def get_decoder(self):
940
+ return self.model
941
+
942
+ @can_return_tuple
943
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
944
+ @add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING)
945
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
946
+ def forward(
947
+ self,
948
+ input_ids: Optional[torch.LongTensor] = None,
949
+ attention_mask: Optional[torch.Tensor] = None,
950
+ position_ids: Optional[torch.LongTensor] = None,
951
+ past_key_values: Optional[Cache] = None,
952
+ inputs_embeds: Optional[torch.FloatTensor] = None,
953
+ labels: Optional[torch.LongTensor] = None,
954
+ use_cache: Optional[bool] = None,
955
+ output_attentions: Optional[bool] = None,
956
+ output_hidden_states: Optional[bool] = None,
957
+ cache_position: Optional[torch.LongTensor] = None,
958
+ logits_to_keep: Union[int, torch.Tensor] = 0,
959
+ **kwargs: Unpack[KwargsForCausalLM],
960
+ ) -> CausalLMOutputWithPast:
961
+ r"""
962
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
963
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
964
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
965
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
966
+
967
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
968
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
969
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
970
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
971
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
972
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
973
+
974
+ Returns:
975
+
976
+ Example:
977
+
978
+ ```python
979
+ >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM
980
+
981
+ >>> model = DeepseekV3ForCausalLM.from_pretrained("deepseek-ai/DeepSeek-V3")
982
+ >>> tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V3")
983
+
984
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
985
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
986
+
987
+ >>> # Generate
988
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
989
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
990
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
991
+ ```"""
992
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
993
+ output_hidden_states = (
994
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
995
+ )
996
+
997
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
998
+ outputs: BaseModelOutputWithPast = self.model(
999
+ input_ids=input_ids,
1000
+ attention_mask=attention_mask,
1001
+ position_ids=position_ids,
1002
+ past_key_values=past_key_values,
1003
+ inputs_embeds=inputs_embeds,
1004
+ use_cache=use_cache,
1005
+ output_attentions=output_attentions,
1006
+ output_hidden_states=output_hidden_states,
1007
+ cache_position=cache_position,
1008
+ **kwargs,
1009
+ )
1010
+
1011
+ hidden_states = outputs.last_hidden_state
1012
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1013
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1014
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
1015
+
1016
+ loss = None
1017
+ if labels is not None:
1018
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
1019
+
1020
+ return CausalLMOutputWithPast(
1021
+ loss=loss,
1022
+ logits=logits,
1023
+ past_key_values=outputs.past_key_values,
1024
+ hidden_states=outputs.hidden_states,
1025
+ attentions=outputs.attentions,
1026
+ )
1027
+
1028
+
1029
+ __all__ = ["DeepseekV3PreTrainedModel", "DeepseekV3Model", "DeepseekV3ForCausalLM"]
1030
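The 4D mask construction in `_prepare_4d_causal_attention_mask_with_cache_position` above can be sketched outside the model. The numpy version below is an illustrative re-derivation, not the Transformers implementation: `MIN = -1e9` stands in for `torch.finfo(dtype).min`, and the function name and signature are simplified for exposition. It reproduces the two steps the method performs: open the causal (lower-triangular, cache-aware) region, then re-mask slots that point at padding tokens.

```python
import numpy as np

MIN = -1e9  # stand-in for torch.finfo(dtype).min

def build_causal_mask(attention_mask, sequence_length, target_length, cache_position):
    # Start fully masked, then open the causal region.
    batch_size = attention_mask.shape[0]
    causal = np.full((sequence_length, target_length), MIN, dtype=np.float32)
    if sequence_length != 1:
        causal = np.triu(causal, k=1)  # keep MIN strictly above the diagonal
    # With a KV cache, absolute key position j is visible iff j <= cache_position[i].
    causal = causal * (np.arange(target_length) > cache_position.reshape(-1, 1))
    causal = np.broadcast_to(
        causal[None, None], (batch_size, 1, sequence_length, target_length)
    ).copy()
    # Fold in the 2D padding mask: slots that are causally open (0) but point
    # at padding tokens (attention_mask == 0) get re-masked to MIN.
    mask_length = attention_mask.shape[-1]
    padding = (causal[:, :, :, :mask_length] + attention_mask[:, None, None, :]) == 0
    causal[:, :, :, :mask_length] = np.where(padding, MIN, causal[:, :, :, :mask_length])
    return causal
```

For a prompt of length 3 with no cache and no padding, each query row attends only to itself and earlier positions; with left padding (`[0, 1, 1]`) the first key column is masked for every row, matching the `_unmask_unattended` discussion in the SDPA branch above.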
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,34 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "bos_token": {
5
+ "__type": "AddedToken",
6
+ "content": "<|begin▁of▁sentence|>",
7
+ "lstrip": false,
8
+ "normalized": true,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "clean_up_tokenization_spaces": false,
13
+ "eos_token": {
14
+ "__type": "AddedToken",
15
+ "content": "<|end▁of▁sentence|>",
16
+ "lstrip": false,
17
+ "normalized": true,
18
+ "rstrip": false,
19
+ "single_word": false
20
+ },
21
+ "legacy": true,
22
+ "model_max_length": 131072,
23
+ "pad_token": {
24
+ "__type": "AddedToken",
25
+ "content": "<|▁pad▁|>",
26
+ "lstrip": false,
27
+ "normalized": true,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ },
31
+ "sp_model_kwargs": {},
32
+ "unk_token": null,
33
+ "tokenizer_class": "LlamaTokenizerFast"
34
+ }
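The `"add_bos_token": true` / `"add_eos_token": false` pair in the config above means encodings are prefixed with the BOS token and not suffixed with EOS. The sketch below is a toy illustration of that wrapping behaviour only (it is not the actual `LlamaTokenizerFast` logic; `wrap_special_tokens` is a hypothetical helper):

```python
# Toy illustration of the add_bos_token / add_eos_token defaults above.
BOS = "<|begin▁of▁sentence|>"
EOS = "<|end▁of▁sentence|>"

def wrap_special_tokens(tokens, add_bos_token=True, add_eos_token=False):
    # Defaults mirror tokenizer_config.json: BOS prepended, EOS not appended.
    out = list(tokens)
    if add_bos_token:
        out.insert(0, BOS)
    if add_eos_token:
        out.append(EOS)
    return out
```

With the defaults, `wrap_special_tokens(["Hello"])` yields the BOS token followed by the input; EOS is only appended when explicitly requested (e.g. when preparing training targets).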
venv/bin/Activate.ps1 ADDED
@@ -0,0 +1,247 @@
1
+ <#
2
+ .Synopsis
3
+ Activate a Python virtual environment for the current PowerShell session.
4
+
5
+ .Description
6
+ Pushes the python executable for a virtual environment to the front of the
7
+ $Env:PATH environment variable and sets the prompt to signify that you are
8
+ in a Python virtual environment. Makes use of the command line switches as
9
+ well as the `pyvenv.cfg` file values present in the virtual environment.
10
+
11
+ .Parameter VenvDir
12
+ Path to the directory that contains the virtual environment to activate. The
13
+ default value for this is the parent of the directory that the Activate.ps1
14
+ script is located within.
15
+
16
+ .Parameter Prompt
17
+ The prompt prefix to display when this virtual environment is activated. By
18
+ default, this prompt is the name of the virtual environment folder (VenvDir)
19
+ surrounded by parentheses and followed by a single space (ie. '(.venv) ').
20
+
21
+ .Example
22
+ Activate.ps1
23
+ Activates the Python virtual environment that contains the Activate.ps1 script.
24
+
25
+ .Example
26
+ Activate.ps1 -Verbose
27
+ Activates the Python virtual environment that contains the Activate.ps1 script,
28
+ and shows extra information about the activation as it executes.
29
+
30
+ .Example
31
+ Activate.ps1 -VenvDir C:\Users\MyUser\Common\.venv
32
+ Activates the Python virtual environment located in the specified location.
33
+
34
+ .Example
35
+ Activate.ps1 -Prompt "MyPython"
36
+ Activates the Python virtual environment that contains the Activate.ps1 script,
37
+ and prefixes the current prompt with the specified string (surrounded in
38
+ parentheses) while the virtual environment is active.
39
+
40
+ .Notes
41
+ On Windows, it may be required to enable this Activate.ps1 script by setting the
42
+ execution policy for the user. You can do this by issuing the following PowerShell
43
+ command:
44
+
45
+ PS C:\> Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser
46
+
47
+ For more information on Execution Policies:
48
+ https://go.microsoft.com/fwlink/?LinkID=135170
49
+
50
+ #>
51
+ Param(
52
+ [Parameter(Mandatory = $false)]
53
+ [String]
54
+ $VenvDir,
55
+ [Parameter(Mandatory = $false)]
56
+ [String]
57
+ $Prompt
58
+ )
59
+
60
+ <# Function declarations --------------------------------------------------- #>
61
+
62
+ <#
63
+ .Synopsis
64
+ Remove all shell session elements added by the Activate script, including the
65
+ addition of the virtual environment's Python executable from the beginning of
66
+ the PATH variable.
67
+
68
+ .Parameter NonDestructive
69
+ If present, do not remove this function from the global namespace for the
70
+ session.
71
+
72
+ #>
73
+ function global:deactivate ([switch]$NonDestructive) {
74
+ # Revert to original values
75
+
76
+ # The prior prompt:
77
+ if (Test-Path -Path Function:_OLD_VIRTUAL_PROMPT) {
78
+ Copy-Item -Path Function:_OLD_VIRTUAL_PROMPT -Destination Function:prompt
79
+ Remove-Item -Path Function:_OLD_VIRTUAL_PROMPT
80
+ }
81
+
82
+ # The prior PYTHONHOME:
83
+ if (Test-Path -Path Env:_OLD_VIRTUAL_PYTHONHOME) {
84
+ Copy-Item -Path Env:_OLD_VIRTUAL_PYTHONHOME -Destination Env:PYTHONHOME
85
+ Remove-Item -Path Env:_OLD_VIRTUAL_PYTHONHOME
86
+ }
87
+
88
+ # The prior PATH:
89
+ if (Test-Path -Path Env:_OLD_VIRTUAL_PATH) {
90
+ Copy-Item -Path Env:_OLD_VIRTUAL_PATH -Destination Env:PATH
91
+ Remove-Item -Path Env:_OLD_VIRTUAL_PATH
92
+ }
93
+
94
+ # Just remove the VIRTUAL_ENV altogether:
95
+ if (Test-Path -Path Env:VIRTUAL_ENV) {
96
+ Remove-Item -Path env:VIRTUAL_ENV
97
+ }
98
+
99
+ # Just remove VIRTUAL_ENV_PROMPT altogether.
100
+ if (Test-Path -Path Env:VIRTUAL_ENV_PROMPT) {
101
+ Remove-Item -Path env:VIRTUAL_ENV_PROMPT
102
+ }
103
+
104
+ # Just remove the _PYTHON_VENV_PROMPT_PREFIX altogether:
105
+ if (Get-Variable -Name "_PYTHON_VENV_PROMPT_PREFIX" -ErrorAction SilentlyContinue) {
106
+ Remove-Variable -Name _PYTHON_VENV_PROMPT_PREFIX -Scope Global -Force
107
+ }
108
+
109
+ # Leave deactivate function in the global namespace if requested:
110
+ if (-not $NonDestructive) {
111
+ Remove-Item -Path function:deactivate
112
+ }
113
+ }
114
+
115
+ <#
116
+ .Description
117
+ Get-PyVenvConfig parses the values from the pyvenv.cfg file located in the
+ given folder, and returns them in a map.
+
+ For each line in the pyvenv.cfg file, if that line can be parsed into exactly
+ two strings separated by `=` (with any amount of whitespace surrounding the =)
+ then it is considered a `key = value` line. The left hand string is the key,
+ the right hand is the value.
+
+ If the value starts with a `'` or a `"` then the first and last character is
+ stripped from the value before being captured.
+
+ .Parameter ConfigDir
+ Path to the directory that contains the `pyvenv.cfg` file.
+ #>
+ function Get-PyVenvConfig(
+     [String]
+     $ConfigDir
+ ) {
+     Write-Verbose "Given ConfigDir=$ConfigDir, obtain values in pyvenv.cfg"
+
+     # Ensure the file exists, and issue a warning if it doesn't (but still allow the function to continue).
+     $pyvenvConfigPath = Join-Path -Resolve -Path $ConfigDir -ChildPath 'pyvenv.cfg' -ErrorAction Continue
+
+     # An empty map will be returned if no config file is found.
+     $pyvenvConfig = @{ }
+
+     if ($pyvenvConfigPath) {
+
+         Write-Verbose "File exists, parse `key = value` lines"
+         $pyvenvConfigContent = Get-Content -Path $pyvenvConfigPath
+
+         $pyvenvConfigContent | ForEach-Object {
+             $keyval = $PSItem -split "\s*=\s*", 2
+             if ($keyval[0] -and $keyval[1]) {
+                 $val = $keyval[1]
+
+                 # Remove extraneous quotations around a string value.
+                 if ("'""".Contains($val.Substring(0, 1))) {
+                     $val = $val.Substring(1, $val.Length - 2)
+                 }
+
+                 $pyvenvConfig[$keyval[0]] = $val
+                 Write-Verbose "Adding Key: '$($keyval[0])'='$val'"
+             }
+         }
+     }
+     return $pyvenvConfig
+ }
+
+
+ <# Begin Activate script --------------------------------------------------- #>
+
+ # Determine the containing directory of this script
+ $VenvExecPath = Split-Path -Parent $MyInvocation.MyCommand.Definition
+ $VenvExecDir = Get-Item -Path $VenvExecPath
+
+ Write-Verbose "Activation script is located in path: '$VenvExecPath'"
+ Write-Verbose "VenvExecDir Fullname: '$($VenvExecDir.FullName)"
+ Write-Verbose "VenvExecDir Name: '$($VenvExecDir.Name)"
+
+ # Set values required in priority: CmdLine, ConfigFile, Default
+ # First, get the location of the virtual environment, it might not be
+ # VenvExecDir if specified on the command line.
+ if ($VenvDir) {
+     Write-Verbose "VenvDir given as parameter, using '$VenvDir' to determine values"
+ }
+ else {
+     Write-Verbose "VenvDir not given as a parameter, using parent directory name as VenvDir."
+     $VenvDir = $VenvExecDir.Parent.FullName.TrimEnd("\\/")
+     Write-Verbose "VenvDir=$VenvDir"
+ }
+
+ # Next, read the `pyvenv.cfg` file to determine any required value such
+ # as `prompt`.
+ $pyvenvCfg = Get-PyVenvConfig -ConfigDir $VenvDir
+
+ # Next, set the prompt from the command line, or the config file, or
+ # just use the name of the virtual environment folder.
+ if ($Prompt) {
+     Write-Verbose "Prompt specified as argument, using '$Prompt'"
+ }
+ else {
+     Write-Verbose "Prompt not specified as argument to script, checking pyvenv.cfg value"
+     if ($pyvenvCfg -and $pyvenvCfg['prompt']) {
+         Write-Verbose "  Setting based on value in pyvenv.cfg='$($pyvenvCfg['prompt'])'"
+         $Prompt = $pyvenvCfg['prompt'];
+     }
+     else {
+         Write-Verbose "  Setting prompt based on parent's directory's name. (Is the directory name passed to venv module when creating the virtual environment)"
+         Write-Verbose "  Got leaf-name of $VenvDir='$(Split-Path -Path $venvDir -Leaf)'"
+         $Prompt = Split-Path -Path $venvDir -Leaf
+     }
+ }
+
+ Write-Verbose "Prompt = '$Prompt'"
+ Write-Verbose "VenvDir='$VenvDir'"
+
+ # Deactivate any currently active virtual environment, but leave the
+ # deactivate function in place.
+ deactivate -nondestructive
+
+ # Now set the environment variable VIRTUAL_ENV, used by many tools to determine
+ # that there is an activated venv.
+ $env:VIRTUAL_ENV = $VenvDir
+
+ if (-not $Env:VIRTUAL_ENV_DISABLE_PROMPT) {
+
+     Write-Verbose "Setting prompt to '$Prompt'"
+
+     # Set the prompt to include the env name
+     # Make sure _OLD_VIRTUAL_PROMPT is global
+     function global:_OLD_VIRTUAL_PROMPT { "" }
+     Copy-Item -Path function:prompt -Destination function:_OLD_VIRTUAL_PROMPT
+     New-Variable -Name _PYTHON_VENV_PROMPT_PREFIX -Description "Python virtual environment prompt prefix" -Scope Global -Option ReadOnly -Visibility Public -Value $Prompt
+
+     function global:prompt {
+         Write-Host -NoNewline -ForegroundColor Green "($_PYTHON_VENV_PROMPT_PREFIX) "
+         _OLD_VIRTUAL_PROMPT
+     }
+     $env:VIRTUAL_ENV_PROMPT = $Prompt
+ }
+
+ # Clear PYTHONHOME
+ if (Test-Path -Path Env:PYTHONHOME) {
+     Copy-Item -Path Env:PYTHONHOME -Destination Env:_OLD_VIRTUAL_PYTHONHOME
+     Remove-Item -Path Env:PYTHONHOME
+ }
+
+ # Add the venv to the PATH
+ Copy-Item -Path Env:PATH -Destination Env:_OLD_VIRTUAL_PATH
+ $Env:PATH = "$VenvExecDir$([System.IO.Path]::PathSeparator)$Env:PATH"
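The `Get-PyVenvConfig` function above follows simple parsing rules: split each line on the first `=`, trim surrounding whitespace, and strip one pair of surrounding quotes from the value. As a minimal sketch (a hypothetical helper, not part of any of the files in this commit), the same rules in Python look like:

```python
# Minimal Python sketch of the pyvenv.cfg parsing rules implemented by
# Get-PyVenvConfig: split each line on the first "=", trim whitespace,
# and strip one pair of surrounding quotation characters from the value.
def parse_pyvenv_cfg(text: str) -> dict:
    config = {}
    for line in text.splitlines():
        key, sep, value = line.partition("=")
        if not sep:
            continue  # not a `key = value` line
        key, value = key.strip(), value.strip()
        if not key or not value:
            continue
        # Remove extraneous quotation marks around a string value.
        if value[0] in ("'", '"'):
            value = value[1:-1]
        config[key] = value
    return config
```

As in the PowerShell version, a missing file or a malformed line simply yields no entry, so callers can fall back to defaults.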
venv/bin/activate ADDED
@@ -0,0 +1,69 @@
+ # This file must be used with "source bin/activate" *from bash*
+ # you cannot run it directly
+
+ deactivate () {
+     # reset old environment variables
+     if [ -n "${_OLD_VIRTUAL_PATH:-}" ] ; then
+         PATH="${_OLD_VIRTUAL_PATH:-}"
+         export PATH
+         unset _OLD_VIRTUAL_PATH
+     fi
+     if [ -n "${_OLD_VIRTUAL_PYTHONHOME:-}" ] ; then
+         PYTHONHOME="${_OLD_VIRTUAL_PYTHONHOME:-}"
+         export PYTHONHOME
+         unset _OLD_VIRTUAL_PYTHONHOME
+     fi
+
+     # This should detect bash and zsh, which have a hash command that must
+     # be called to get it to forget past commands.  Without forgetting
+     # past commands the $PATH changes we made may not be respected
+     if [ -n "${BASH:-}" -o -n "${ZSH_VERSION:-}" ] ; then
+         hash -r 2> /dev/null
+     fi
+
+     if [ -n "${_OLD_VIRTUAL_PS1:-}" ] ; then
+         PS1="${_OLD_VIRTUAL_PS1:-}"
+         export PS1
+         unset _OLD_VIRTUAL_PS1
+     fi
+
+     unset VIRTUAL_ENV
+     unset VIRTUAL_ENV_PROMPT
+     if [ ! "${1:-}" = "nondestructive" ] ; then
+         # Self destruct!
+         unset -f deactivate
+     fi
+ }
+
+ # unset irrelevant variables
+ deactivate nondestructive
+
+ VIRTUAL_ENV=/mnt/llm-data/users/xieshuai/codes/hf_model/omni/deepseek_40b/20260211-dpo-0210-0208-v2-dpoaddid-965-mtp-qiangzhifeisikao/fp8_model/venv
+ export VIRTUAL_ENV
+
+ _OLD_VIRTUAL_PATH="$PATH"
+ PATH="$VIRTUAL_ENV/"bin":$PATH"
+ export PATH
+
+ # unset PYTHONHOME if set
+ # this will fail if PYTHONHOME is set to the empty string (which is bad anyway)
+ # could use `if (set -u; : $PYTHONHOME) ;` in bash
+ if [ -n "${PYTHONHOME:-}" ] ; then
+     _OLD_VIRTUAL_PYTHONHOME="${PYTHONHOME:-}"
+     unset PYTHONHOME
+ fi
+
+ if [ -z "${VIRTUAL_ENV_DISABLE_PROMPT:-}" ] ; then
+     _OLD_VIRTUAL_PS1="${PS1:-}"
+     PS1='(venv) '"${PS1:-}"
+     export PS1
+     VIRTUAL_ENV_PROMPT='(venv) '
+     export VIRTUAL_ENV_PROMPT
+ fi
+
+ # This should detect bash and zsh, which have a hash command that must
+ # be called to get it to forget past commands.  Without forgetting
+ # past commands the $PATH changes we made may not be respected
+ if [ -n "${BASH:-}" -o -n "${ZSH_VERSION:-}" ] ; then
+     hash -r 2> /dev/null
+ fi
venv/bin/activate.csh ADDED
@@ -0,0 +1,26 @@
+ # This file must be used with "source bin/activate.csh" *from csh*.
+ # You cannot run it directly.
+ # Created by Davide Di Blasi <davidedb@gmail.com>.
+ # Ported to Python 3.3 venv by Andrew Svetlov <andrew.svetlov@gmail.com>
+
+ alias deactivate 'test $?_OLD_VIRTUAL_PATH != 0 && setenv PATH "$_OLD_VIRTUAL_PATH" && unset _OLD_VIRTUAL_PATH; rehash; test $?_OLD_VIRTUAL_PROMPT != 0 && set prompt="$_OLD_VIRTUAL_PROMPT" && unset _OLD_VIRTUAL_PROMPT; unsetenv VIRTUAL_ENV; unsetenv VIRTUAL_ENV_PROMPT; test "\!:*" != "nondestructive" && unalias deactivate'
+
+ # Unset irrelevant variables.
+ deactivate nondestructive
+
+ setenv VIRTUAL_ENV /mnt/llm-data/users/xieshuai/codes/hf_model/omni/deepseek_40b/20260211-dpo-0210-0208-v2-dpoaddid-965-mtp-qiangzhifeisikao/fp8_model/venv
+
+ set _OLD_VIRTUAL_PATH="$PATH"
+ setenv PATH "$VIRTUAL_ENV/"bin":$PATH"
+
+
+ set _OLD_VIRTUAL_PROMPT="$prompt"
+
+ if (! "$?VIRTUAL_ENV_DISABLE_PROMPT") then
+     set prompt = '(venv) '"$prompt"
+     setenv VIRTUAL_ENV_PROMPT '(venv) '
+ endif
+
+ alias pydoc python -m pydoc
+
+ rehash
venv/bin/activate.fish ADDED
@@ -0,0 +1,69 @@
+ # This file must be used with "source <venv>/bin/activate.fish" *from fish*
+ # (https://fishshell.com/); you cannot run it directly.
+
+ function deactivate -d "Exit virtual environment and return to normal shell environment"
+     # reset old environment variables
+     if test -n "$_OLD_VIRTUAL_PATH"
+         set -gx PATH $_OLD_VIRTUAL_PATH
+         set -e _OLD_VIRTUAL_PATH
+     end
+     if test -n "$_OLD_VIRTUAL_PYTHONHOME"
+         set -gx PYTHONHOME $_OLD_VIRTUAL_PYTHONHOME
+         set -e _OLD_VIRTUAL_PYTHONHOME
+     end
+
+     if test -n "$_OLD_FISH_PROMPT_OVERRIDE"
+         set -e _OLD_FISH_PROMPT_OVERRIDE
+         # prevents error when using nested fish instances (Issue #93858)
+         if functions -q _old_fish_prompt
+             functions -e fish_prompt
+             functions -c _old_fish_prompt fish_prompt
+             functions -e _old_fish_prompt
+         end
+     end
+
+     set -e VIRTUAL_ENV
+     set -e VIRTUAL_ENV_PROMPT
+     if test "$argv[1]" != "nondestructive"
+         # Self-destruct!
+         functions -e deactivate
+     end
+ end
+
+ # Unset irrelevant variables.
+ deactivate nondestructive
+
+ set -gx VIRTUAL_ENV /mnt/llm-data/users/xieshuai/codes/hf_model/omni/deepseek_40b/20260211-dpo-0210-0208-v2-dpoaddid-965-mtp-qiangzhifeisikao/fp8_model/venv
+
+ set -gx _OLD_VIRTUAL_PATH $PATH
+ set -gx PATH "$VIRTUAL_ENV/"bin $PATH
+
+ # Unset PYTHONHOME if set.
+ if set -q PYTHONHOME
+     set -gx _OLD_VIRTUAL_PYTHONHOME $PYTHONHOME
+     set -e PYTHONHOME
+ end
+
+ if test -z "$VIRTUAL_ENV_DISABLE_PROMPT"
+     # fish uses a function instead of an env var to generate the prompt.
+
+     # Save the current fish_prompt function as the function _old_fish_prompt.
+     functions -c fish_prompt _old_fish_prompt
+
+     # With the original prompt function renamed, we can override with our own.
+     function fish_prompt
+         # Save the return status of the last command.
+         set -l old_status $status
+
+         # Output the venv prompt; color taken from the blue of the Python logo.
+         printf "%s%s%s" (set_color 4B8BBE) '(venv) ' (set_color normal)
+
+         # Restore the return status of the previous command.
+         echo "exit $old_status" | .
+         # Output the original/"old" prompt.
+         _old_fish_prompt
+     end
+
+     set -gx _OLD_FISH_PROMPT_OVERRIDE "$VIRTUAL_ENV"
+     set -gx VIRTUAL_ENV_PROMPT '(venv) '
+ end
venv/bin/hf ADDED
@@ -0,0 +1,10 @@
+ #!/bin/sh
+ '''exec' /mnt/llm-data/users/xieshuai/codes/hf_model/omni/deepseek_40b/20260211-dpo-0210-0208-v2-dpoaddid-965-mtp-qiangzhifeisikao/fp8_model/venv/bin/python3 "$0" "$@"
+ ' '''
+ # -*- coding: utf-8 -*-
+ import re
+ import sys
+ from huggingface_hub.cli.hf import main
+ if __name__ == '__main__':
+     sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
+     sys.exit(main())
venv/bin/httpx ADDED
@@ -0,0 +1,10 @@
+ #!/bin/sh
+ '''exec' /mnt/llm-data/users/xieshuai/codes/hf_model/omni/deepseek_40b/20260211-dpo-0210-0208-v2-dpoaddid-965-mtp-qiangzhifeisikao/fp8_model/venv/bin/python3 "$0" "$@"
+ ' '''
+ # -*- coding: utf-8 -*-
+ import re
+ import sys
+ from httpx import main
+ if __name__ == '__main__':
+     sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
+     sys.exit(main())
venv/bin/markdown-it ADDED
@@ -0,0 +1,10 @@
+ #!/bin/sh
+ '''exec' /mnt/llm-data/users/xieshuai/codes/hf_model/omni/deepseek_40b/20260211-dpo-0210-0208-v2-dpoaddid-965-mtp-qiangzhifeisikao/fp8_model/venv/bin/python3 "$0" "$@"
+ ' '''
+ # -*- coding: utf-8 -*-
+ import re
+ import sys
+ from markdown_it.cli.parse import main
+ if __name__ == '__main__':
+     sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
+     sys.exit(main())
venv/bin/pip ADDED
@@ -0,0 +1,10 @@
+ #!/bin/sh
+ '''exec' /mnt/llm-data/users/xieshuai/codes/hf_model/omni/deepseek_40b/20260211-dpo-0210-0208-v2-dpoaddid-965-mtp-qiangzhifeisikao/fp8_model/venv/bin/python3 "$0" "$@"
+ ' '''
+ # -*- coding: utf-8 -*-
+ import re
+ import sys
+ from pip._internal.cli.main import main
+ if __name__ == '__main__':
+     sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
+     sys.exit(main())
venv/bin/pip3 ADDED
@@ -0,0 +1,10 @@
+ #!/bin/sh
+ '''exec' /mnt/llm-data/users/xieshuai/codes/hf_model/omni/deepseek_40b/20260211-dpo-0210-0208-v2-dpoaddid-965-mtp-qiangzhifeisikao/fp8_model/venv/bin/python3 "$0" "$@"
+ ' '''
+ # -*- coding: utf-8 -*-
+ import re
+ import sys
+ from pip._internal.cli.main import main
+ if __name__ == '__main__':
+     sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
+     sys.exit(main())
venv/bin/pip3.10 ADDED
@@ -0,0 +1,10 @@
+ #!/bin/sh
+ '''exec' /mnt/llm-data/users/xieshuai/codes/hf_model/omni/deepseek_40b/20260211-dpo-0210-0208-v2-dpoaddid-965-mtp-qiangzhifeisikao/fp8_model/venv/bin/python3 "$0" "$@"
+ ' '''
+ # -*- coding: utf-8 -*-
+ import re
+ import sys
+ from pip._internal.cli.main import main
+ if __name__ == '__main__':
+     sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
+     sys.exit(main())
venv/bin/pygmentize ADDED
@@ -0,0 +1,10 @@
+ #!/bin/sh
+ '''exec' /mnt/llm-data/users/xieshuai/codes/hf_model/omni/deepseek_40b/20260211-dpo-0210-0208-v2-dpoaddid-965-mtp-qiangzhifeisikao/fp8_model/venv/bin/python3 "$0" "$@"
+ ' '''
+ # -*- coding: utf-8 -*-
+ import re
+ import sys
+ from pygments.cmdline import main
+ if __name__ == '__main__':
+     sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
+     sys.exit(main())
venv/bin/tiny-agents ADDED
@@ -0,0 +1,10 @@
+ #!/bin/sh
+ '''exec' /mnt/llm-data/users/xieshuai/codes/hf_model/omni/deepseek_40b/20260211-dpo-0210-0208-v2-dpoaddid-965-mtp-qiangzhifeisikao/fp8_model/venv/bin/python3 "$0" "$@"
+ ' '''
+ # -*- coding: utf-8 -*-
+ import re
+ import sys
+ from huggingface_hub.inference._mcp.cli import app
+ if __name__ == '__main__':
+     sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
+     sys.exit(app())
venv/bin/tqdm ADDED
@@ -0,0 +1,10 @@
+ #!/bin/sh
+ '''exec' /mnt/llm-data/users/xieshuai/codes/hf_model/omni/deepseek_40b/20260211-dpo-0210-0208-v2-dpoaddid-965-mtp-qiangzhifeisikao/fp8_model/venv/bin/python3 "$0" "$@"
+ ' '''
+ # -*- coding: utf-8 -*-
+ import re
+ import sys
+ from tqdm.cli import main
+ if __name__ == '__main__':
+     sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
+     sys.exit(main())
venv/bin/typer ADDED
@@ -0,0 +1,10 @@
+ #!/bin/sh
+ '''exec' /mnt/llm-data/users/xieshuai/codes/hf_model/omni/deepseek_40b/20260211-dpo-0210-0208-v2-dpoaddid-965-mtp-qiangzhifeisikao/fp8_model/venv/bin/python3 "$0" "$@"
+ ' '''
+ # -*- coding: utf-8 -*-
+ import re
+ import sys
+ from typer.cli import main
+ if __name__ == '__main__':
+     sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
+     sys.exit(main())
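All of the console-script wrappers above share the same boilerplate: an sh/Python polyglot header that re-executes the file under the venv's interpreter, and a `re.sub` that cleans `sys.argv[0]` before dispatching to the tool's entry point. As a small illustration (a hypothetical helper name, not part of any file in this commit), the argv[0] cleanup works like this:

```python
import re

# The wrappers rewrite sys.argv[0] with this regex before calling main(),
# stripping the Windows launcher suffixes "-script.pyw" and ".exe" so the
# tool reports a clean program name in help and error messages.
def clean_argv0(argv0: str) -> str:
    return re.sub(r'(-script\.pyw|\.exe)?$', '', argv0)
```

On POSIX systems the name is usually already clean, in which case the substitution is a no-op.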
venv/lib/python3.10/site-packages/_distutils_hack/__init__.py ADDED
@@ -0,0 +1,132 @@
+ import sys
+ import os
+ import re
+ import importlib
+ import warnings
+
+
+ is_pypy = '__pypy__' in sys.builtin_module_names
+
+
+ warnings.filterwarnings('ignore',
+                         r'.+ distutils\b.+ deprecated',
+                         DeprecationWarning)
+
+
+ def warn_distutils_present():
+     if 'distutils' not in sys.modules:
+         return
+     if is_pypy and sys.version_info < (3, 7):
+         # PyPy for 3.6 unconditionally imports distutils, so bypass the warning
+         # https://foss.heptapod.net/pypy/pypy/-/blob/be829135bc0d758997b3566062999ee8b23872b4/lib-python/3/site.py#L250
+         return
+     warnings.warn(
+         "Distutils was imported before Setuptools, but importing Setuptools "
+         "also replaces the `distutils` module in `sys.modules`. This may lead "
+         "to undesirable behaviors or errors. To avoid these issues, avoid "
+         "using distutils directly, ensure that setuptools is installed in the "
+         "traditional way (e.g. not an editable install), and/or make sure "
+         "that setuptools is always imported before distutils.")
+
+
+ def clear_distutils():
+     if 'distutils' not in sys.modules:
+         return
+     warnings.warn("Setuptools is replacing distutils.")
+     mods = [name for name in sys.modules if re.match(r'distutils\b', name)]
+     for name in mods:
+         del sys.modules[name]
+
+
+ def enabled():
+     """
+     Allow selection of distutils by environment variable.
+     """
+     which = os.environ.get('SETUPTOOLS_USE_DISTUTILS', 'stdlib')
+     return which == 'local'
+
+
+ def ensure_local_distutils():
+     clear_distutils()
+
+     # With the DistutilsMetaFinder in place,
+     # perform an import to cause distutils to be
+     # loaded from setuptools._distutils. Ref #2906.
+     add_shim()
+     importlib.import_module('distutils')
+     remove_shim()
+
+     # check that submodules load as expected
+     core = importlib.import_module('distutils.core')
+     assert '_distutils' in core.__file__, core.__file__
+
+
+ def do_override():
+     """
+     Ensure that the local copy of distutils is preferred over stdlib.
+
+     See https://github.com/pypa/setuptools/issues/417#issuecomment-392298401
+     for more motivation.
+     """
+     if enabled():
+         warn_distutils_present()
+         ensure_local_distutils()
+
+
+ class DistutilsMetaFinder:
+     def find_spec(self, fullname, path, target=None):
+         if path is not None:
+             return
+
+         method_name = 'spec_for_{fullname}'.format(**locals())
+         method = getattr(self, method_name, lambda: None)
+         return method()
+
+     def spec_for_distutils(self):
+         import importlib.abc
+         import importlib.util
+
+         class DistutilsLoader(importlib.abc.Loader):
+
+             def create_module(self, spec):
+                 return importlib.import_module('setuptools._distutils')
+
+             def exec_module(self, module):
+                 pass
+
+         return importlib.util.spec_from_loader('distutils', DistutilsLoader())
+
+     def spec_for_pip(self):
+         """
+         Ensure stdlib distutils when running under pip.
+         See pypa/pip#8761 for rationale.
+         """
+         if self.pip_imported_during_build():
+             return
+         clear_distutils()
+         self.spec_for_distutils = lambda: None
+
+     @staticmethod
+     def pip_imported_during_build():
+         """
+         Detect if pip is being imported in a build script. Ref #2355.
+         """
+         import traceback
+         return any(
+             frame.f_globals['__file__'].endswith('setup.py')
+             for frame, line in traceback.walk_stack(None)
+         )
+
+
+ DISTUTILS_FINDER = DistutilsMetaFinder()
+
+
+ def add_shim():
+     sys.meta_path.insert(0, DISTUTILS_FINDER)
+
+
+ def remove_shim():
+     try:
+         sys.meta_path.remove(DISTUTILS_FINDER)
+     except ValueError:
+         pass
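`_distutils_hack` works by installing a finder on `sys.meta_path` that intercepts `import distutils` and serves `setuptools._distutils` instead. A self-contained sketch of that meta-path technique (using a made-up module name, `demo_alias`, rather than `distutils`) looks like this:

```python
import importlib.abc
import importlib.util
import sys
import types

# Minimal version of the sys.meta_path shim used by _distutils_hack: a
# finder that answers only for one top-level name ("demo_alias", a
# hypothetical module) and builds a substitute module for it.
class AliasFinder(importlib.abc.MetaPathFinder, importlib.abc.Loader):
    def find_spec(self, fullname, path, target=None):
        if path is not None or fullname != "demo_alias":
            return None  # only handle top-level imports of "demo_alias"
        return importlib.util.spec_from_loader(fullname, self)

    def create_module(self, spec):
        # Like DistutilsLoader, hand back a pre-built substitute module.
        module = types.ModuleType(spec.name)
        module.answer = 42
        return module

    def exec_module(self, module):
        pass  # module was fully built in create_module

finder = AliasFinder()
sys.meta_path.insert(0, finder)  # mirrors add_shim()
try:
    import demo_alias
    print(demo_alias.answer)
finally:
    sys.meta_path.remove(finder)  # mirrors remove_shim()
```

The real shim adds a second wrinkle, `spec_for_pip`, which disables the substitution entirely when the import is coming from pip's build machinery.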
venv/lib/python3.10/site-packages/_distutils_hack/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (5.21 kB).
 
venv/lib/python3.10/site-packages/_distutils_hack/__pycache__/override.cpython-310.pyc ADDED
Binary file (337 Bytes).
 
venv/lib/python3.10/site-packages/_distutils_hack/override.py ADDED
@@ -0,0 +1 @@
+ __import__('_distutils_hack').do_override()
venv/lib/python3.10/site-packages/_yaml/__init__.py ADDED
@@ -0,0 +1,33 @@
+ # This is a stub package designed to roughly emulate the _yaml
+ # extension module, which previously existed as a standalone module
+ # and has been moved into the `yaml` package namespace.
+ # It does not perfectly mimic its old counterpart, but should get
+ # close enough for anyone who's relying on it even when they shouldn't.
+ import yaml
+
+ # in some circumstances, the yaml module we imported may be from a different version, so we need
+ # to tread carefully when poking at it here (it may not have the attributes we expect)
+ if not getattr(yaml, '__with_libyaml__', False):
+     from sys import version_info
+
+     exc = ModuleNotFoundError if version_info >= (3, 6) else ImportError
+     raise exc("No module named '_yaml'")
+ else:
+     from yaml._yaml import *
+     import warnings
+     warnings.warn(
+         'The _yaml extension module is now located at yaml._yaml'
+         ' and its location is subject to change. To use the'
+         ' LibYAML-based parser and emitter, import from `yaml`:'
+         ' `from yaml import CLoader as Loader, CDumper as Dumper`.',
+         DeprecationWarning
+     )
+     del warnings
+     # Don't `del yaml` here because yaml is actually an existing
+     # namespace member of _yaml.
+
+ __name__ = '_yaml'
+ # If the module is top-level (i.e. not a part of any specific package)
+ # then the attribute should be set to ''.
+ # https://docs.python.org/3.8/library/types.html
+ __package__ = ''
venv/lib/python3.10/site-packages/_yaml/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (838 Bytes).
 
venv/lib/python3.10/site-packages/annotated_doc-0.0.4.dist-info/INSTALLER ADDED
@@ -0,0 +1 @@
+ pip
venv/lib/python3.10/site-packages/annotated_doc-0.0.4.dist-info/METADATA ADDED
@@ -0,0 +1,145 @@
+ Metadata-Version: 2.4
+ Name: annotated-doc
+ Version: 0.0.4
+ Summary: Document parameters, class attributes, return types, and variables inline, with Annotated.
+ Author-Email: =?utf-8?q?Sebasti=C3=A1n_Ram=C3=ADrez?= <tiangolo@gmail.com>
+ License-Expression: MIT
+ License-File: LICENSE
+ Classifier: Intended Audience :: Information Technology
+ Classifier: Intended Audience :: System Administrators
+ Classifier: Operating System :: OS Independent
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python
+ Classifier: Topic :: Internet
+ Classifier: Topic :: Software Development :: Libraries :: Application Frameworks
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
+ Classifier: Topic :: Software Development :: Libraries
+ Classifier: Topic :: Software Development
+ Classifier: Typing :: Typed
+ Classifier: Development Status :: 4 - Beta
+ Classifier: Intended Audience :: Developers
+ Classifier: Programming Language :: Python :: 3 :: Only
+ Classifier: Programming Language :: Python :: 3.8
+ Classifier: Programming Language :: Python :: 3.9
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Programming Language :: Python :: 3.13
+ Classifier: Programming Language :: Python :: 3.14
+ Project-URL: Homepage, https://github.com/fastapi/annotated-doc
+ Project-URL: Documentation, https://github.com/fastapi/annotated-doc
+ Project-URL: Repository, https://github.com/fastapi/annotated-doc
+ Project-URL: Issues, https://github.com/fastapi/annotated-doc/issues
+ Project-URL: Changelog, https://github.com/fastapi/annotated-doc/release-notes.md
+ Requires-Python: >=3.8
+ Description-Content-Type: text/markdown
+
+ # Annotated Doc
+
+ Document parameters, class attributes, return types, and variables inline, with `Annotated`.
+
+ <a href="https://github.com/fastapi/annotated-doc/actions?query=workflow%3ATest+event%3Apush+branch%3Amain" target="_blank">
+     <img src="https://github.com/fastapi/annotated-doc/actions/workflows/test.yml/badge.svg?event=push&branch=main" alt="Test">
+ </a>
+ <a href="https://coverage-badge.samuelcolvin.workers.dev/redirect/fastapi/annotated-doc" target="_blank">
+     <img src="https://coverage-badge.samuelcolvin.workers.dev/fastapi/annotated-doc.svg" alt="Coverage">
+ </a>
+ <a href="https://pypi.org/project/annotated-doc" target="_blank">
+     <img src="https://img.shields.io/pypi/v/annotated-doc?color=%2334D058&label=pypi%20package" alt="Package version">
+ </a>
+ <a href="https://pypi.org/project/annotated-doc" target="_blank">
+     <img src="https://img.shields.io/pypi/pyversions/annotated-doc.svg?color=%2334D058" alt="Supported Python versions">
+ </a>
+
+ ## Installation
+
+ ```bash
+ pip install annotated-doc
+ ```
+
+ Or with `uv`:
+
+ ```bash
+ uv add annotated-doc
+ ```
+
+ ## Usage
+
+ Import `Doc` and pass a single literal string with the documentation for the specific parameter, class attribute, return type, or variable.
+
+ For example, to document a parameter `name` in a function `hi` you could do:
+
+ ```Python
+ from typing import Annotated
+
+ from annotated_doc import Doc
+
+ def hi(name: Annotated[str, Doc("Who to say hi to")]) -> None:
+     print(f"Hi, {name}!")
+ ```
+
+ You can also use it to document class attributes:
+
+ ```Python
+ from typing import Annotated
+
+ from annotated_doc import Doc
+
+ class User:
+     name: Annotated[str, Doc("The user's name")]
+     age: Annotated[int, Doc("The user's age")]
+ ```
+
+ The same way, you could document return types and variables, or anything that could have a type annotation with `Annotated`.
+
+ ## Who Uses This
+
+ `annotated-doc` was made for:
+
+ * [FastAPI](https://fastapi.tiangolo.com/)
+ * [Typer](https://typer.tiangolo.com/)
+ * [SQLModel](https://sqlmodel.tiangolo.com/)
+ * [Asyncer](https://asyncer.tiangolo.com/)
+
+ `annotated-doc` is supported by [griffe-typingdoc](https://github.com/mkdocstrings/griffe-typingdoc), which powers reference documentation like the one in the [FastAPI Reference](https://fastapi.tiangolo.com/reference/).
+
+ ## Reasons not to use `annotated-doc`
+
+ You are already comfortable with one of the existing docstring formats, like:
+
+ * Sphinx
+ * numpydoc
+ * Google
+ * Keras
+
+ Your team is already comfortable using them.
+
+ You prefer having the documentation about parameters all together in a docstring, separated from the code defining them.
+
+ You care about a specific set of users, using one specific editor, and that editor already has support for the specific docstring format you use.
+
+ ## Reasons to use `annotated-doc`
+
+ * No micro-syntax to learn for newcomers, it's **just Python** syntax.
+ * **Editing** would be already fully supported by default by any editor (current or future) supporting Python syntax, including syntax errors, syntax highlighting, etc.
+ * **Rendering** would be relatively straightforward to implement by static tools (tools that don't need runtime execution), as the information can be extracted from the AST they normally already create.
+ * **Deduplication of information**: the name of a parameter would be defined in a single place, not duplicated inside of a docstring.
+ * **Elimination** of the possibility of having **inconsistencies** when removing a parameter or class variable and **forgetting to remove** its documentation.
+ * **Minimization** of the probability of adding a new parameter or class variable and **forgetting to add its documentation**.
+ * **Elimination** of the possibility of having **inconsistencies** between the **name** of a parameter in the **signature** and the name in the docstring when it is renamed.
+ * **Access** to the documentation string for each symbol at **runtime**, including existing (older) Python versions.
+ * A more formalized way to document other symbols, like type aliases, that could use `Annotated`.
+ * **Support** for apps using FastAPI, Typer and others.
+ * **AI Accessibility**: AI tools will have an easier time understanding each parameter, as the documentation sits much closer to it.
+
+ ## History
+
+ I ([@tiangolo](https://github.com/tiangolo)) originally wanted this to be part of the Python standard library (in [PEP 727](https://peps.python.org/pep-0727/)), but the proposal was withdrawn as there was a fair amount of negative feedback and opposition.
+
+ The conclusion was that this was better done as an external effort, in a third-party library.
+
+ So, here it is, with a simpler approach, as a third-party library, in a way that can be used by others, starting with FastAPI and friends.
+
+ ## License
+
+ This project is licensed under the terms of the MIT license.
venv/lib/python3.10/site-packages/annotated_doc-0.0.4.dist-info/RECORD ADDED
@@ -0,0 +1,11 @@
+ annotated_doc-0.0.4.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
+ annotated_doc-0.0.4.dist-info/METADATA,sha256=Irm5KJua33dY2qKKAjJ-OhKaVBVIfwFGej_dSe3Z1TU,6566
+ annotated_doc-0.0.4.dist-info/RECORD,,
+ annotated_doc-0.0.4.dist-info/WHEEL,sha256=9P2ygRxDrTJz3gsagc0Z96ukrxjr-LFBGOgv3AuKlCA,90
+ annotated_doc-0.0.4.dist-info/entry_points.txt,sha256=6OYgBcLyFCUgeqLgnvMyOJxPCWzgy7se4rLPKtNonMs,34
+ annotated_doc-0.0.4.dist-info/licenses/LICENSE,sha256=__Fwd5pqy_ZavbQFwIfxzuF4ZpHkqWpANFF-SlBKDN8,1086
+ annotated_doc/__init__.py,sha256=VuyxxUe80kfEyWnOrCx_Bk8hybo3aKo6RYBlkBBYW8k,52
+ annotated_doc/__pycache__/__init__.cpython-310.pyc,,
+ annotated_doc/__pycache__/main.cpython-310.pyc,,
+ annotated_doc/main.py,sha256=5Zfvxv80SwwLqpRW73AZyZyiM4bWma9QWRbp_cgD20s,1075
+ annotated_doc/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
venv/lib/python3.10/site-packages/annotated_doc-0.0.4.dist-info/WHEEL ADDED
@@ -0,0 +1,4 @@
+ Wheel-Version: 1.0
+ Generator: pdm-backend (2.4.5)
+ Root-Is-Purelib: true
+ Tag: py3-none-any
venv/lib/python3.10/site-packages/annotated_doc-0.0.4.dist-info/entry_points.txt ADDED
@@ -0,0 +1,4 @@
+ [console_scripts]
+
+ [gui_scripts]
+
venv/lib/python3.10/site-packages/annotated_doc-0.0.4.dist-info/licenses/LICENSE ADDED
@@ -0,0 +1,21 @@
+ The MIT License (MIT)
+
+ Copyright (c) 2025 Sebastián Ramírez
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in
+ all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ THE SOFTWARE.
venv/lib/python3.10/site-packages/annotated_doc/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .main import Doc as Doc
+
+ __version__ = "0.0.4"
venv/lib/python3.10/site-packages/annotated_doc/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (342 Bytes).
 
venv/lib/python3.10/site-packages/annotated_doc/__pycache__/main.cpython-310.pyc ADDED
Binary file (1.72 kB).
 
venv/lib/python3.10/site-packages/annotated_doc/main.py ADDED
@@ -0,0 +1,36 @@
+ class Doc:
+     """Define the documentation of a type annotation using `Annotated`, to be
+     used in class attributes, function and method parameters, return values,
+     and variables.
+
+     The value should be a positional-only string literal to allow static tools
+     like editors and documentation generators to use it.
+
+     This complements docstrings.
+
+     The string value passed is available in the attribute `documentation`.
+
+     Example:
+
+     ```Python
+     from typing import Annotated
+     from annotated_doc import Doc
+
+     def hi(name: Annotated[str, Doc("Who to say hi to")]) -> None:
+         print(f"Hi, {name}!")
+     ```
+     """
+
+     def __init__(self, documentation: str, /) -> None:
+         self.documentation = documentation
+
+     def __repr__(self) -> str:
+         return f"Doc({self.documentation!r})"
+
+     def __hash__(self) -> int:
+         return hash(self.documentation)
+
+     def __eq__(self, other: object) -> bool:
+         if not isinstance(other, Doc):
+             return NotImplemented
+         return self.documentation == other.documentation
venv/lib/python3.10/site-packages/annotated_doc/py.typed ADDED
File without changes
venv/lib/python3.10/site-packages/anyio/__init__.py ADDED
@@ -0,0 +1,111 @@
+ from __future__ import annotations
+
+ from ._core._contextmanagers import AsyncContextManagerMixin as AsyncContextManagerMixin
+ from ._core._contextmanagers import ContextManagerMixin as ContextManagerMixin
+ from ._core._eventloop import current_time as current_time
+ from ._core._eventloop import get_all_backends as get_all_backends
+ from ._core._eventloop import get_available_backends as get_available_backends
+ from ._core._eventloop import get_cancelled_exc_class as get_cancelled_exc_class
+ from ._core._eventloop import run as run
+ from ._core._eventloop import sleep as sleep
+ from ._core._eventloop import sleep_forever as sleep_forever
+ from ._core._eventloop import sleep_until as sleep_until
+ from ._core._exceptions import BrokenResourceError as BrokenResourceError
+ from ._core._exceptions import BrokenWorkerInterpreter as BrokenWorkerInterpreter
+ from ._core._exceptions import BrokenWorkerProcess as BrokenWorkerProcess
+ from ._core._exceptions import BusyResourceError as BusyResourceError
+ from ._core._exceptions import ClosedResourceError as ClosedResourceError
+ from ._core._exceptions import ConnectionFailed as ConnectionFailed
+ from ._core._exceptions import DelimiterNotFound as DelimiterNotFound
+ from ._core._exceptions import EndOfStream as EndOfStream
+ from ._core._exceptions import IncompleteRead as IncompleteRead
+ from ._core._exceptions import NoEventLoopError as NoEventLoopError
+ from ._core._exceptions import RunFinishedError as RunFinishedError
+ from ._core._exceptions import TypedAttributeLookupError as TypedAttributeLookupError
+ from ._core._exceptions import WouldBlock as WouldBlock
+ from ._core._fileio import AsyncFile as AsyncFile
+ from ._core._fileio import Path as Path
+ from ._core._fileio import open_file as open_file
+ from ._core._fileio import wrap_file as wrap_file
+ from ._core._resources import aclose_forcefully as aclose_forcefully
+ from ._core._signals import open_signal_receiver as open_signal_receiver
+ from ._core._sockets import TCPConnectable as TCPConnectable
+ from ._core._sockets import UNIXConnectable as UNIXConnectable
+ from ._core._sockets import as_connectable as as_connectable
+ from ._core._sockets import connect_tcp as connect_tcp
+ from ._core._sockets import connect_unix as connect_unix
+ from ._core._sockets import create_connected_udp_socket as create_connected_udp_socket
+ from ._core._sockets import (
+     create_connected_unix_datagram_socket as create_connected_unix_datagram_socket,
+ )
+ from ._core._sockets import create_tcp_listener as create_tcp_listener
+ from ._core._sockets import create_udp_socket as create_udp_socket
+ from ._core._sockets import create_unix_datagram_socket as create_unix_datagram_socket
+ from ._core._sockets import create_unix_listener as create_unix_listener
+ from ._core._sockets import getaddrinfo as getaddrinfo
+ from ._core._sockets import getnameinfo as getnameinfo
+ from ._core._sockets import notify_closing as notify_closing
+ from ._core._sockets import wait_readable as wait_readable
+ from ._core._sockets import wait_socket_readable as wait_socket_readable
+ from ._core._sockets import wait_socket_writable as wait_socket_writable
+ from ._core._sockets import wait_writable as wait_writable
+ from ._core._streams import create_memory_object_stream as create_memory_object_stream
+ from ._core._subprocesses import open_process as open_process
+ from ._core._subprocesses import run_process as run_process
+ from ._core._synchronization import CapacityLimiter as CapacityLimiter
+ from ._core._synchronization import (
+     CapacityLimiterStatistics as CapacityLimiterStatistics,
+ )
+ from ._core._synchronization import Condition as Condition
+ from ._core._synchronization import ConditionStatistics as ConditionStatistics
+ from ._core._synchronization import Event as Event
+ from ._core._synchronization import EventStatistics as EventStatistics
+ from ._core._synchronization import Lock as Lock
+ from ._core._synchronization import LockStatistics as LockStatistics
+ from ._core._synchronization import ResourceGuard as ResourceGuard
+ from ._core._synchronization import Semaphore as Semaphore
+ from ._core._synchronization import SemaphoreStatistics as SemaphoreStatistics
+ from ._core._tasks import TASK_STATUS_IGNORED as TASK_STATUS_IGNORED
+ from ._core._tasks import CancelScope as CancelScope
+ from ._core._tasks import create_task_group as create_task_group
+ from ._core._tasks import current_effective_deadline as current_effective_deadline
+ from ._core._tasks import fail_after as fail_after
+ from ._core._tasks import move_on_after as move_on_after
+ from ._core._tempfile import NamedTemporaryFile as NamedTemporaryFile
+ from ._core._tempfile import SpooledTemporaryFile as SpooledTemporaryFile
+ from ._core._tempfile import TemporaryDirectory as TemporaryDirectory
+ from ._core._tempfile import TemporaryFile as TemporaryFile
+ from ._core._tempfile import gettempdir as gettempdir
+ from ._core._tempfile import gettempdirb as gettempdirb
+ from ._core._tempfile import mkdtemp as mkdtemp
+ from ._core._tempfile import mkstemp as mkstemp
+ from ._core._testing import TaskInfo as TaskInfo
+ from ._core._testing import get_current_task as get_current_task
+ from ._core._testing import get_running_tasks as get_running_tasks
+ from ._core._testing import wait_all_tasks_blocked as wait_all_tasks_blocked
+ from ._core._typedattr import TypedAttributeProvider as TypedAttributeProvider
+ from ._core._typedattr import TypedAttributeSet as TypedAttributeSet
+ from ._core._typedattr import typed_attribute as typed_attribute
+
+ # Re-export imports so they look like they live directly in this package
+ for __value in list(locals().values()):
+     if getattr(__value, "__module__", "").startswith("anyio."):
+         __value.__module__ = __name__
+
+
+ del __value
+
+
+ def __getattr__(attr: str) -> type[BrokenWorkerInterpreter]:
+     """Support deprecated aliases."""
+     if attr == "BrokenWorkerIntepreter":
+         import warnings
+
+         warnings.warn(
+             "The 'BrokenWorkerIntepreter' alias is deprecated, use 'BrokenWorkerInterpreter' instead.",
+             DeprecationWarning,
+             stacklevel=2,
+         )
+         return BrokenWorkerInterpreter
+
+     raise AttributeError(f"module {__name__!r} has no attribute {attr!r}")
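The deprecated-alias hook at the end of this `__init__.py` relies on PEP 562 module-level `__getattr__`. The same pattern can be sketched in isolation with a throwaway module (the module name `_alias_demo` and the class names are illustrative, not part of anyio):

```python
import sys
import types
import warnings

# Build a throwaway module (hypothetical name) whose misspelled OldName
# attribute resolves through a deprecation shim, mirroring the module-level
# __getattr__ hook in anyio/__init__.py above.
mod = types.ModuleType("_alias_demo")


class NewName:
    pass


def _module_getattr(attr: str) -> type:
    if attr == "OldName":
        warnings.warn(
            "'OldName' is deprecated, use 'NewName' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return NewName
    raise AttributeError(f"module '_alias_demo' has no attribute {attr!r}")


mod.NewName = NewName
mod.__getattr__ = _module_getattr  # PEP 562: consulted for missing attributes
sys.modules["_alias_demo"] = mod

import _alias_demo

with warnings.catch_warnings():
    warnings.simplefilter("ignore", DeprecationWarning)
    print(_alias_demo.OldName is _alias_demo.NewName)  # → True
```

Because `__getattr__` only runs for names missing from the module's `__dict__`, the shim costs nothing on normal attribute access.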
venv/lib/python3.10/site-packages/anyio/from_thread.py ADDED
@@ -0,0 +1,578 @@
+ from __future__ import annotations
+
+ __all__ = (
+     "BlockingPortal",
+     "BlockingPortalProvider",
+     "check_cancelled",
+     "run",
+     "run_sync",
+     "start_blocking_portal",
+ )
+
+ import sys
+ from collections.abc import Awaitable, Callable, Generator
+ from concurrent.futures import Future
+ from contextlib import (
+     AbstractAsyncContextManager,
+     AbstractContextManager,
+     contextmanager,
+ )
+ from dataclasses import dataclass, field
+ from functools import partial
+ from inspect import isawaitable
+ from threading import Lock, Thread, current_thread, get_ident
+ from types import TracebackType
+ from typing import (
+     Any,
+     Generic,
+     TypeVar,
+     cast,
+     overload,
+ )
+
+ from ._core._eventloop import (
+     get_cancelled_exc_class,
+     threadlocals,
+ )
+ from ._core._eventloop import run as run_eventloop
+ from ._core._exceptions import NoEventLoopError
+ from ._core._synchronization import Event
+ from ._core._tasks import CancelScope, create_task_group
+ from .abc._tasks import TaskStatus
+ from .lowlevel import EventLoopToken, current_token
+
+ if sys.version_info >= (3, 11):
+     from typing import TypeVarTuple, Unpack
+ else:
+     from typing_extensions import TypeVarTuple, Unpack
+
+ T_Retval = TypeVar("T_Retval")
+ T_co = TypeVar("T_co", covariant=True)
+ PosArgsT = TypeVarTuple("PosArgsT")
+
+
+ def _token_or_error(token: EventLoopToken | None) -> EventLoopToken:
+     if token is not None:
+         return token
+
+     try:
+         return threadlocals.current_token
+     except AttributeError:
+         raise NoEventLoopError(
+             "Not running inside an AnyIO worker thread, and no event loop token was "
+             "provided"
+         ) from None
+
+
+ def run(
+     func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
+     *args: Unpack[PosArgsT],
+     token: EventLoopToken | None = None,
+ ) -> T_Retval:
+     """
+     Call a coroutine function from a worker thread.
+
+     :param func: a coroutine function
+     :param args: positional arguments for the callable
+     :param token: an event loop token to use to get back to the event loop thread
+         (required if calling this function from outside an AnyIO worker thread)
+     :return: the return value of the coroutine function
+     :raises MissingTokenError: if no token was provided and called from outside an
+         AnyIO worker thread
+     :raises RunFinishedError: if the event loop tied to ``token`` is no longer running
+
+     .. versionchanged:: 4.11.0
+         Added the ``token`` parameter.
+
+     """
+     explicit_token = token is not None
+     token = _token_or_error(token)
+     return token.backend_class.run_async_from_thread(
+         func, args, token=token.native_token if explicit_token else None
+     )
+
+
+ def run_sync(
+     func: Callable[[Unpack[PosArgsT]], T_Retval],
+     *args: Unpack[PosArgsT],
+     token: EventLoopToken | None = None,
+ ) -> T_Retval:
+     """
+     Call a function in the event loop thread from a worker thread.
+
+     :param func: a callable
+     :param args: positional arguments for the callable
+     :param token: an event loop token to use to get back to the event loop thread
+         (required if calling this function from outside an AnyIO worker thread)
+     :return: the return value of the callable
+     :raises MissingTokenError: if no token was provided and called from outside an
+         AnyIO worker thread
+     :raises RunFinishedError: if the event loop tied to ``token`` is no longer running
+
+     .. versionchanged:: 4.11.0
+         Added the ``token`` parameter.
+
+     """
+     explicit_token = token is not None
+     token = _token_or_error(token)
+     return token.backend_class.run_sync_from_thread(
+         func, args, token=token.native_token if explicit_token else None
+     )
+
+
+ class _BlockingAsyncContextManager(Generic[T_co], AbstractContextManager):
+     _enter_future: Future[T_co]
+     _exit_future: Future[bool | None]
+     _exit_event: Event
+     _exit_exc_info: tuple[
+         type[BaseException] | None, BaseException | None, TracebackType | None
+     ] = (None, None, None)
+
+     def __init__(
+         self, async_cm: AbstractAsyncContextManager[T_co], portal: BlockingPortal
+     ):
+         self._async_cm = async_cm
+         self._portal = portal
+
+     async def run_async_cm(self) -> bool | None:
+         try:
+             self._exit_event = Event()
+             value = await self._async_cm.__aenter__()
+         except BaseException as exc:
+             self._enter_future.set_exception(exc)
+             raise
+         else:
+             self._enter_future.set_result(value)
+
+         try:
+             # Wait for the sync context manager to exit.
+             # This next statement can raise `get_cancelled_exc_class()` if
+             # something went wrong in a task group in this async context
+             # manager.
+             await self._exit_event.wait()
+         finally:
+             # In case of cancellation, it could be that we end up here before
+             # `_BlockingAsyncContextManager.__exit__` is called, and an
+             # `_exit_exc_info` has been set.
+             result = await self._async_cm.__aexit__(*self._exit_exc_info)
+
+         return result
+
+     def __enter__(self) -> T_co:
+         self._enter_future = Future()
+         self._exit_future = self._portal.start_task_soon(self.run_async_cm)
+         return self._enter_future.result()
+
+     def __exit__(
+         self,
+         __exc_type: type[BaseException] | None,
+         __exc_value: BaseException | None,
+         __traceback: TracebackType | None,
+     ) -> bool | None:
+         self._exit_exc_info = __exc_type, __exc_value, __traceback
+         self._portal.call(self._exit_event.set)
+         return self._exit_future.result()
+
+
+ class _BlockingPortalTaskStatus(TaskStatus):
+     def __init__(self, future: Future):
+         self._future = future
+
+     def started(self, value: object = None) -> None:
+         self._future.set_result(value)
+
+
+ class BlockingPortal:
+     """
+     An object that lets external threads run code in an asynchronous event loop.
+
+     :raises NoEventLoopError: if no supported asynchronous event loop is running in the
+         current thread
+     """
+
+     def __init__(self) -> None:
+         self._token = current_token()
+         self._event_loop_thread_id: int | None = get_ident()
+         self._stop_event = Event()
+         self._task_group = create_task_group()
+
+     async def __aenter__(self) -> BlockingPortal:
+         await self._task_group.__aenter__()
+         return self
+
+     async def __aexit__(
+         self,
+         exc_type: type[BaseException] | None,
+         exc_val: BaseException | None,
+         exc_tb: TracebackType | None,
+     ) -> bool:
+         await self.stop()
+         return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
+
+     def _check_running(self) -> None:
+         if self._event_loop_thread_id is None:
+             raise RuntimeError("This portal is not running")
+         if self._event_loop_thread_id == get_ident():
+             raise RuntimeError(
+                 "This method cannot be called from the event loop thread"
+             )
+
+     async def sleep_until_stopped(self) -> None:
+         """Sleep until :meth:`stop` is called."""
+         await self._stop_event.wait()
+
+     async def stop(self, cancel_remaining: bool = False) -> None:
+         """
+         Signal the portal to shut down.
+
+         This marks the portal as no longer accepting new calls and exits from
+         :meth:`sleep_until_stopped`.
+
+         :param cancel_remaining: ``True`` to cancel all the remaining tasks, ``False``
+             to let them finish before returning
+
+         """
+         self._event_loop_thread_id = None
+         self._stop_event.set()
+         if cancel_remaining:
+             self._task_group.cancel_scope.cancel("the blocking portal is shutting down")
+
+     async def _call_func(
+         self,
+         func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
+         args: tuple[Unpack[PosArgsT]],
+         kwargs: dict[str, Any],
+         future: Future[T_Retval],
+     ) -> None:
+         def callback(f: Future[T_Retval]) -> None:
+             if f.cancelled():
+                 if self._event_loop_thread_id == get_ident():
+                     scope.cancel("the future was cancelled")
+                 elif self._event_loop_thread_id is not None:
+                     self.call(scope.cancel, "the future was cancelled")
+
+         try:
+             retval_or_awaitable = func(*args, **kwargs)
+             if isawaitable(retval_or_awaitable):
+                 with CancelScope() as scope:
+                     future.add_done_callback(callback)
+                     retval = await retval_or_awaitable
+             else:
+                 retval = retval_or_awaitable
+         except get_cancelled_exc_class():
+             future.cancel()
+             future.set_running_or_notify_cancel()
+         except BaseException as exc:
+             if not future.cancelled():
+                 future.set_exception(exc)
+
+             # Let base exceptions fall through
+             if not isinstance(exc, Exception):
+                 raise
+         else:
+             if not future.cancelled():
+                 future.set_result(retval)
+         finally:
+             scope = None  # type: ignore[assignment]
+
+     def _spawn_task_from_thread(
+         self,
+         func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
+         args: tuple[Unpack[PosArgsT]],
+         kwargs: dict[str, Any],
+         name: object,
+         future: Future[T_Retval],
+     ) -> None:
+         """
+         Spawn a new task using the given callable.
+
+         :param func: a callable
+         :param args: positional arguments to be passed to the callable
+         :param kwargs: keyword arguments to be passed to the callable
+         :param name: name of the task (will be coerced to a string if not ``None``)
+         :param future: a future that will resolve to the return value of the callable,
+             or the exception raised during its execution
+
+         """
+         run_sync(
+             partial(self._task_group.start_soon, name=name),
+             self._call_func,
+             func,
+             args,
+             kwargs,
+             future,
+             token=self._token,
+         )
+
+     @overload
+     def call(
+         self,
+         func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
+         *args: Unpack[PosArgsT],
+     ) -> T_Retval: ...
+
+     @overload
+     def call(
+         self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
+     ) -> T_Retval: ...
+
+     def call(
+         self,
+         func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
+         *args: Unpack[PosArgsT],
+     ) -> T_Retval:
+         """
+         Call the given function in the event loop thread.
+
+         If the callable returns a coroutine object, it is awaited on.
+
+         :param func: any callable
+         :raises RuntimeError: if the portal is not running or if this method is called
+             from within the event loop thread
+
+         """
+         return cast(T_Retval, self.start_task_soon(func, *args).result())
+
+     @overload
+     def start_task_soon(
+         self,
+         func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
+         *args: Unpack[PosArgsT],
+         name: object = None,
+     ) -> Future[T_Retval]: ...
+
+     @overload
+     def start_task_soon(
+         self,
+         func: Callable[[Unpack[PosArgsT]], T_Retval],
+         *args: Unpack[PosArgsT],
+         name: object = None,
+     ) -> Future[T_Retval]: ...
+
+     def start_task_soon(
+         self,
+         func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
+         *args: Unpack[PosArgsT],
+         name: object = None,
+     ) -> Future[T_Retval]:
+         """
+         Start a task in the portal's task group.
+
+         The task will be run inside a cancel scope which can be cancelled by cancelling
+         the returned future.
+
+         :param func: the target function
+         :param args: positional arguments passed to ``func``
+         :param name: name of the task (will be coerced to a string if not ``None``)
+         :return: a future that resolves with the return value of the callable if the
+             task completes successfully, or with the exception raised in the task
+         :raises RuntimeError: if the portal is not running or if this method is called
+             from within the event loop thread
+         :rtype: concurrent.futures.Future[T_Retval]
+
+         .. versionadded:: 3.0
+
+         """
+         self._check_running()
+         f: Future[T_Retval] = Future()
+         self._spawn_task_from_thread(func, args, {}, name, f)
+         return f
+
+     def start_task(
+         self,
+         func: Callable[..., Awaitable[T_Retval]],
+         *args: object,
+         name: object = None,
+     ) -> tuple[Future[T_Retval], Any]:
+         """
+         Start a task in the portal's task group and wait until it signals for readiness.
+
+         This method works the same way as :meth:`.abc.TaskGroup.start`.
+
+         :param func: the target function
+         :param args: positional arguments passed to ``func``
+         :param name: name of the task (will be coerced to a string if not ``None``)
+         :return: a tuple of (future, task_status_value) where the ``task_status_value``
+             is the value passed to ``task_status.started()`` from within the target
+             function
+         :rtype: tuple[concurrent.futures.Future[T_Retval], Any]
+
+         .. versionadded:: 3.0
+
+         """
+
+         def task_done(future: Future[T_Retval]) -> None:
+             if not task_status_future.done():
+                 if future.cancelled():
+                     task_status_future.cancel()
+                 elif future.exception():
+                     task_status_future.set_exception(future.exception())
+                 else:
+                     exc = RuntimeError(
+                         "Task exited without calling task_status.started()"
+                     )
+                     task_status_future.set_exception(exc)
+
+         self._check_running()
+         task_status_future: Future = Future()
+         task_status = _BlockingPortalTaskStatus(task_status_future)
+         f: Future = Future()
+         f.add_done_callback(task_done)
+         self._spawn_task_from_thread(func, args, {"task_status": task_status}, name, f)
+         return f, task_status_future.result()
+
+     def wrap_async_context_manager(
+         self, cm: AbstractAsyncContextManager[T_co]
+     ) -> AbstractContextManager[T_co]:
+         """
+         Wrap an async context manager as a synchronous context manager via this portal.
+
+         Spawns a task that will call both ``__aenter__()`` and ``__aexit__()``, stopping
+         in the middle until the synchronous context manager exits.
+
+         :param cm: an asynchronous context manager
+         :return: a synchronous context manager
+
+         .. versionadded:: 2.1
+
+         """
+         return _BlockingAsyncContextManager(cm, self)
+
+
+ @dataclass
+ class BlockingPortalProvider:
+     """
+     A manager for a blocking portal. Used as a context manager. The first thread to
+     enter this context manager causes a blocking portal to be started with the specific
+     parameters, and the last thread to exit causes the portal to be shut down. Thus,
+     there will be exactly one blocking portal running in this context as long as at
+     least one thread has entered this context manager.
+
+     The parameters are the same as for :func:`~anyio.run`.
+
+     :param backend: name of the backend
+     :param backend_options: backend options
+
+     .. versionadded:: 4.4
+     """
+
+     backend: str = "asyncio"
+     backend_options: dict[str, Any] | None = None
+     _lock: Lock = field(init=False, default_factory=Lock)
+     _leases: int = field(init=False, default=0)
+     _portal: BlockingPortal = field(init=False)
+     _portal_cm: AbstractContextManager[BlockingPortal] | None = field(
+         init=False, default=None
+     )
+
+     def __enter__(self) -> BlockingPortal:
+         with self._lock:
+             if self._portal_cm is None:
+                 self._portal_cm = start_blocking_portal(
+                     self.backend, self.backend_options
+                 )
+                 self._portal = self._portal_cm.__enter__()
+
+             self._leases += 1
+             return self._portal
+
+     def __exit__(
+         self,
+         exc_type: type[BaseException] | None,
+         exc_val: BaseException | None,
+         exc_tb: TracebackType | None,
+     ) -> None:
+         portal_cm: AbstractContextManager[BlockingPortal] | None = None
+         with self._lock:
+             assert self._portal_cm
+             assert self._leases > 0
+             self._leases -= 1
+             if not self._leases:
+                 portal_cm = self._portal_cm
+                 self._portal_cm = None
+                 del self._portal
+
+         if portal_cm:
+             portal_cm.__exit__(None, None, None)
+
+
+ @contextmanager
+ def start_blocking_portal(
+     backend: str = "asyncio",
+     backend_options: dict[str, Any] | None = None,
+     *,
+     name: str | None = None,
+ ) -> Generator[BlockingPortal, Any, None]:
+     """
+     Start a new event loop in a new thread and run a blocking portal in its main task.
+
+     The parameters are the same as for :func:`~anyio.run`.
+
+     :param backend: name of the backend
+     :param backend_options: backend options
+     :param name: name of the thread
+     :return: a context manager that yields a blocking portal
+
+     .. versionchanged:: 3.0
+         Usage as a context manager is now required.
+
+     """
+
+     async def run_portal() -> None:
+         async with BlockingPortal() as portal_:
+             if name is None:
+                 current_thread().name = f"{backend}-portal-{id(portal_):x}"
+
+             future.set_result(portal_)
+             await portal_.sleep_until_stopped()
+
+     def run_blocking_portal() -> None:
+         if future.set_running_or_notify_cancel():
+             try:
+                 run_eventloop(
+                     run_portal, backend=backend, backend_options=backend_options
+                 )
+             except BaseException as exc:
+                 if not future.done():
+                     future.set_exception(exc)
+
+     future: Future[BlockingPortal] = Future()
+     thread = Thread(target=run_blocking_portal, daemon=True, name=name)
+     thread.start()
+     try:
+         cancel_remaining_tasks = False
+         portal = future.result()
+         try:
+             yield portal
+         except BaseException:
+             cancel_remaining_tasks = True
+             raise
+         finally:
+             try:
+                 portal.call(portal.stop, cancel_remaining_tasks)
+             except RuntimeError:
+                 pass
+     finally:
+         thread.join()
+
+
+ def check_cancelled() -> None:
+     """
+     Check if the cancel scope of the host task's running the current worker thread has
+     been cancelled.
+
+     If the host task's current cancel scope has indeed been cancelled, the
+     backend-specific cancellation exception will be raised.
+
+     :raises RuntimeError: if the current thread was not spawned by
+         :func:`.to_thread.run_sync`
+
+     """
+     try:
+         token: EventLoopToken = threadlocals.current_token
+     except AttributeError:
+         raise NoEventLoopError(
+             "This function can only be called inside an AnyIO worker thread"
+         ) from None
+
+     token.backend_class.check_cancelled()
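The core move in `BlockingPortal` — a worker thread submitting work to an event loop running in another thread, then blocking on a `concurrent.futures.Future` — can be sketched with the stdlib alone. This is an illustration of the pattern, not anyio's implementation; it uses plain `asyncio` instead of anyio's backend abstraction:

```python
import asyncio
import threading
from concurrent.futures import Future

# Start an event loop in a dedicated thread, much like start_blocking_portal().
loop_ready: Future = Future()


def run_loop() -> None:
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop_ready.set_result(loop)
    loop.run_forever()


thread = threading.Thread(target=run_loop, daemon=True)
thread.start()
loop = loop_ready.result()


async def greet(name: str) -> str:
    await asyncio.sleep(0)  # yield once, proving we run on the event loop
    return f"Hi, {name}!"


# The "portal call": schedule the coroutine on the loop thread and block on the
# returned concurrent.futures.Future, as BlockingPortal.call() does.
result = asyncio.run_coroutine_threadsafe(greet("world"), loop).result()
print(result)  # → Hi, world!

# Shut the loop down, mirroring portal.stop() plus thread.join().
loop.call_soon_threadsafe(loop.stop)
thread.join()
```

anyio's version layers cancel scopes, task groups, and backend dispatch on top of this same handoff.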
venv/lib/python3.10/site-packages/anyio/functools.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ __all__ = (
4
+ "AsyncCacheInfo",
5
+ "AsyncCacheParameters",
6
+ "AsyncLRUCacheWrapper",
7
+ "cache",
8
+ "lru_cache",
9
+ "reduce",
10
+ )
11
+
12
+ import functools
13
+ import sys
14
+ from collections import OrderedDict
15
+ from collections.abc import (
16
+ AsyncIterable,
17
+ Awaitable,
18
+ Callable,
19
+ Coroutine,
20
+ Hashable,
21
+ Iterable,
22
+ )
23
+ from functools import update_wrapper
24
+ from inspect import iscoroutinefunction
25
+ from typing import (
26
+ Any,
27
+ Generic,
28
+ NamedTuple,
29
+ TypedDict,
30
+ TypeVar,
31
+ cast,
32
+ final,
33
+ overload,
34
+ )
35
+ from weakref import WeakKeyDictionary
36
+
37
+ from ._core._synchronization import Lock
38
+ from .lowlevel import RunVar, checkpoint
39
+
40
+ if sys.version_info >= (3, 11):
41
+ from typing import ParamSpec
42
+ else:
43
+ from typing_extensions import ParamSpec
44
+
45
+ T = TypeVar("T")
46
+ S = TypeVar("S")
47
+ P = ParamSpec("P")
48
+ lru_cache_items: RunVar[
49
+ WeakKeyDictionary[
50
+ AsyncLRUCacheWrapper[Any, Any],
51
+ OrderedDict[Hashable, tuple[_InitialMissingType, Lock] | tuple[Any, None]],
52
+ ]
53
+ ] = RunVar("lru_cache_items")
54
+
55
+
56
+ class _InitialMissingType:
57
+ pass
58
+
59
+
60
+ initial_missing: _InitialMissingType = _InitialMissingType()
61
+
62
+
63
+ class AsyncCacheInfo(NamedTuple):
+     hits: int
+     misses: int
+     maxsize: int | None
+     currsize: int
+
+
+ class AsyncCacheParameters(TypedDict):
+     maxsize: int | None
+     typed: bool
+     always_checkpoint: bool
+
+
+ class _LRUMethodWrapper(Generic[T]):
+     def __init__(self, wrapper: AsyncLRUCacheWrapper[..., T], instance: object):
+         self.__wrapper = wrapper
+         self.__instance = instance
+
+     def cache_info(self) -> AsyncCacheInfo:
+         return self.__wrapper.cache_info()
+
+     def cache_parameters(self) -> AsyncCacheParameters:
+         return self.__wrapper.cache_parameters()
+
+     def cache_clear(self) -> None:
+         self.__wrapper.cache_clear()
+
+     async def __call__(self, *args: Any, **kwargs: Any) -> T:
+         if self.__instance is None:
+             return await self.__wrapper(*args, **kwargs)
+
+         return await self.__wrapper(self.__instance, *args, **kwargs)
+
+
+ @final
+ class AsyncLRUCacheWrapper(Generic[P, T]):
+     def __init__(
+         self,
+         func: Callable[P, Awaitable[T]],
+         maxsize: int | None,
+         typed: bool,
+         always_checkpoint: bool,
+     ):
+         self.__wrapped__ = func
+         self._hits: int = 0
+         self._misses: int = 0
+         self._maxsize = max(maxsize, 0) if maxsize is not None else None
+         self._currsize: int = 0
+         self._typed = typed
+         self._always_checkpoint = always_checkpoint
+         update_wrapper(self, func)
+
+     def cache_info(self) -> AsyncCacheInfo:
+         return AsyncCacheInfo(self._hits, self._misses, self._maxsize, self._currsize)
+
+     def cache_parameters(self) -> AsyncCacheParameters:
+         return {
+             "maxsize": self._maxsize,
+             "typed": self._typed,
+             "always_checkpoint": self._always_checkpoint,
+         }
+
+     def cache_clear(self) -> None:
+         if cache := lru_cache_items.get(None):
+             cache.pop(self, None)
+
+         self._hits = self._misses = self._currsize = 0
+
+     async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
+         # Easy case first: if maxsize == 0, no caching is done
+         if self._maxsize == 0:
+             value = await self.__wrapped__(*args, **kwargs)
+             self._misses += 1
+             return value
+
+         # The key is constructed as a flat tuple to avoid memory overhead
+         key: tuple[Any, ...] = args
+         if kwargs:
+             # initial_missing is used as a separator
+             key += (initial_missing,) + sum(kwargs.items(), ())
+
+         if self._typed:
+             key += tuple(type(arg) for arg in args)
+             if kwargs:
+                 key += (initial_missing,) + tuple(type(val) for val in kwargs.values())
+
+         try:
+             cache = lru_cache_items.get()
+         except LookupError:
+             cache = WeakKeyDictionary()
+             lru_cache_items.set(cache)
+
+         try:
+             cache_entry = cache[self]
+         except KeyError:
+             cache_entry = cache[self] = OrderedDict()
+
+         cached_value: T | _InitialMissingType
+         try:
+             cached_value, lock = cache_entry[key]
+         except KeyError:
+             # We're the first task to call this function
+             cached_value, lock = (
+                 initial_missing,
+                 Lock(fast_acquire=not self._always_checkpoint),
+             )
+             cache_entry[key] = cached_value, lock
+
+         if lock is None:
+             # The value was already cached
+             self._hits += 1
+             cache_entry.move_to_end(key)
+             if self._always_checkpoint:
+                 await checkpoint()
+
+             return cast(T, cached_value)
+
+         async with lock:
+             # Check if another task filled the cache while we acquired the lock
+             if (cached_value := cache_entry[key][0]) is initial_missing:
+                 self._misses += 1
+                 if self._maxsize is not None and self._currsize >= self._maxsize:
+                     cache_entry.popitem(last=False)
+                 else:
+                     self._currsize += 1
+
+                 value = await self.__wrapped__(*args, **kwargs)
+                 cache_entry[key] = value, None
+             else:
+                 # Another task filled the cache while we were waiting for the lock
+                 self._hits += 1
+                 cache_entry.move_to_end(key)
+                 value = cast(T, cached_value)
+
+         return value
+
+     def __get__(
+         self, instance: object, owner: type | None = None
+     ) -> _LRUMethodWrapper[T]:
+         wrapper = _LRUMethodWrapper(self, instance)
+         update_wrapper(wrapper, self.__wrapped__)
+         return wrapper
+
+
+ class _LRUCacheWrapper(Generic[T]):
+     def __init__(self, maxsize: int | None, typed: bool, always_checkpoint: bool):
+         self._maxsize = maxsize
+         self._typed = typed
+         self._always_checkpoint = always_checkpoint
+
+     @overload
+     def __call__(  # type: ignore[overload-overlap]
+         self, func: Callable[P, Coroutine[Any, Any, T]], /
+     ) -> AsyncLRUCacheWrapper[P, T]: ...
+
+     @overload
+     def __call__(
+         self, func: Callable[..., T], /
+     ) -> functools._lru_cache_wrapper[T]: ...
+
+     def __call__(
+         self, f: Callable[P, Coroutine[Any, Any, T]] | Callable[..., T], /
+     ) -> AsyncLRUCacheWrapper[P, T] | functools._lru_cache_wrapper[T]:
+         if iscoroutinefunction(f):
+             return AsyncLRUCacheWrapper(
+                 f, self._maxsize, self._typed, self._always_checkpoint
+             )
+
+         return functools.lru_cache(maxsize=self._maxsize, typed=self._typed)(f)  # type: ignore[arg-type]
+
+
+ @overload
+ def cache(  # type: ignore[overload-overlap]
+     func: Callable[P, Coroutine[Any, Any, T]], /
+ ) -> AsyncLRUCacheWrapper[P, T]: ...
+
+
+ @overload
+ def cache(func: Callable[..., T], /) -> functools._lru_cache_wrapper[T]: ...
+
+
+ def cache(
+     func: Callable[..., T] | Callable[P, Coroutine[Any, Any, T]], /
+ ) -> AsyncLRUCacheWrapper[P, T] | functools._lru_cache_wrapper[T]:
+     """
+     A convenient shortcut for :func:`lru_cache` with ``maxsize=None``.
+
+     This is the asynchronous equivalent to :func:`functools.cache`.
+
+     """
+     return lru_cache(maxsize=None)(func)
+
+
+ @overload
+ def lru_cache(
+     *, maxsize: int | None = ..., typed: bool = ..., always_checkpoint: bool = ...
+ ) -> _LRUCacheWrapper[Any]: ...
+
+
+ @overload
+ def lru_cache(  # type: ignore[overload-overlap]
+     func: Callable[P, Coroutine[Any, Any, T]], /
+ ) -> AsyncLRUCacheWrapper[P, T]: ...
+
+
+ @overload
+ def lru_cache(func: Callable[..., T], /) -> functools._lru_cache_wrapper[T]: ...
+
+
+ def lru_cache(
+     func: Callable[P, Coroutine[Any, Any, T]] | Callable[..., T] | None = None,
+     /,
+     *,
+     maxsize: int | None = 128,
+     typed: bool = False,
+     always_checkpoint: bool = False,
+ ) -> (
+     AsyncLRUCacheWrapper[P, T] | functools._lru_cache_wrapper[T] | _LRUCacheWrapper[Any]
+ ):
+     """
+     An asynchronous version of :func:`functools.lru_cache`.
+
+     If a synchronous function is passed, the standard library
+     :func:`functools.lru_cache` is applied instead.
+
+     :param always_checkpoint: if ``True``, every call to the cached function will be
+         guaranteed to yield control to the event loop at least once
+
+     .. note:: Caches and locks are managed on a per-event loop basis.
+
+     """
+     if func is None:
+         return _LRUCacheWrapper[Any](maxsize, typed, always_checkpoint)
+
+     if not callable(func):
+         raise TypeError("the first argument must be callable")
+
+     return _LRUCacheWrapper[T](maxsize, typed, always_checkpoint)(func)
+
+
+ @overload
+ async def reduce(
+     function: Callable[[T, S], Awaitable[T]],
+     iterable: Iterable[S] | AsyncIterable[S],
+     /,
+     initial: T,
+ ) -> T: ...
+
+
+ @overload
+ async def reduce(
+     function: Callable[[T, T], Awaitable[T]],
+     iterable: Iterable[T] | AsyncIterable[T],
+     /,
+ ) -> T: ...
+
+
+ async def reduce(  # type: ignore[misc]
+     function: Callable[[T, T], Awaitable[T]] | Callable[[T, S], Awaitable[T]],
+     iterable: Iterable[T] | Iterable[S] | AsyncIterable[T] | AsyncIterable[S],
+     /,
+     initial: T | _InitialMissingType = initial_missing,
+ ) -> T:
+     """
+     Asynchronous version of :func:`functools.reduce`.
+
+     :param function: a coroutine function that takes two arguments: the accumulated
+         value and the next element from the iterable
+     :param iterable: an iterable or async iterable
+     :param initial: the initial value (if missing, the first element of the iterable
+         is used as the initial value)
+
+     """
+     element: Any
+     function_called = False
+     if isinstance(iterable, AsyncIterable):
+         async_it = iterable.__aiter__()
+         if initial is initial_missing:
+             try:
+                 value = cast(T, await async_it.__anext__())
+             except StopAsyncIteration:
+                 raise TypeError(
+                     "reduce() of empty sequence with no initial value"
+                 ) from None
+         else:
+             value = cast(T, initial)
+
+         async for element in async_it:
+             value = await function(value, element)
+             function_called = True
+     elif isinstance(iterable, Iterable):
+         it = iter(iterable)
+         if initial is initial_missing:
+             try:
+                 value = cast(T, next(it))
+             except StopIteration:
+                 raise TypeError(
+                     "reduce() of empty sequence with no initial value"
+                 ) from None
+         else:
+             value = cast(T, initial)
+
+         for element in it:
+             value = await function(value, element)
+             function_called = True
+     else:
+         raise TypeError("reduce() argument 2 must be an iterable or async iterable")
+
+     # Make sure there is at least one checkpoint, even if an empty iterable and an
+     # initial value were given
+     if not function_called:
+         await checkpoint()
+
+     return value
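The `AsyncLRUCacheWrapper` above deliberately mirrors the bookkeeping contract of the synchronous original. For comparison, the standard library's :func:`functools.lru_cache` exposes the same `cache_info()` / `cache_clear()` behavior that the async wrapper reimplements (this snippet uses only the stdlib, not the module shown above):

```python
from functools import lru_cache

@lru_cache(maxsize=2)
def square(x: int) -> int:
    return x * x

# First call misses, repeat call hits, a new argument misses again
square(2)
square(2)
square(3)
info = square.cache_info()
assert (info.hits, info.misses, info.currsize) == (1, 2, 2)

# cache_clear() empties the cache and resets all counters
square.cache_clear()
assert square.cache_info().currsize == 0
```

The async wrapper adds one thing the stdlib cannot: a per-key `Lock`, so that concurrent tasks calling with the same arguments await a single underlying call instead of racing.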
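The `reduce` above is a left fold over an (async) iterable. A minimal stdlib-only sketch of the same semantics, written against plain `asyncio` so it runs without this module (the names `areduce`, `_MISSING`, `add`, and `numbers` are illustrative, not part of the API shown above):

```python
import asyncio
from typing import Any, AsyncIterable, Awaitable, Callable

_MISSING = object()  # sentinel standing in for this module's `initial_missing`


async def areduce(
    function: Callable[[Any, Any], Awaitable[Any]],
    iterable: AsyncIterable[Any],
    initial: Any = _MISSING,
) -> Any:
    """Fold an async iterable left-to-right, awaiting each combining step."""
    it = iterable.__aiter__()
    if initial is _MISSING:
        try:
            value = await it.__anext__()  # first element becomes the seed
        except StopAsyncIteration:
            raise TypeError(
                "reduce() of empty sequence with no initial value"
            ) from None
    else:
        value = initial

    async for element in it:
        value = await function(value, element)

    return value


async def main() -> int:
    async def numbers():
        for n in (1, 2, 3, 4):
            yield n

    async def add(a: int, b: int) -> int:
        return a + b

    return await areduce(add, numbers())


print(asyncio.run(main()))  # 10
```

The real function additionally accepts plain synchronous iterables and guarantees at least one checkpoint even when the fold body never runs.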
venv/lib/python3.10/site-packages/anyio/lowlevel.py ADDED
@@ -0,0 +1,196 @@
+ from __future__ import annotations
+
+ __all__ = (
+     "EventLoopToken",
+     "RunvarToken",
+     "RunVar",
+     "checkpoint",
+     "checkpoint_if_cancelled",
+     "cancel_shielded_checkpoint",
+     "current_token",
+ )
+
+ import enum
+ from dataclasses import dataclass
+ from types import TracebackType
+ from typing import Any, Generic, Literal, TypeVar, final, overload
+ from weakref import WeakKeyDictionary
+
+ from ._core._eventloop import get_async_backend
+ from .abc import AsyncBackend
+
+ T = TypeVar("T")
+ D = TypeVar("D")
+
+
+ async def checkpoint() -> None:
+     """
+     Check for cancellation and allow the scheduler to switch to another task.
+
+     Equivalent to (but more efficient than)::
+
+         await checkpoint_if_cancelled()
+         await cancel_shielded_checkpoint()
+
+     .. versionadded:: 3.0
+
+     """
+     await get_async_backend().checkpoint()
+
+
+ async def checkpoint_if_cancelled() -> None:
+     """
+     Enter a checkpoint if the enclosing cancel scope has been cancelled.
+
+     This does not allow the scheduler to switch to a different task.
+
+     .. versionadded:: 3.0
+
+     """
+     await get_async_backend().checkpoint_if_cancelled()
+
+
+ async def cancel_shielded_checkpoint() -> None:
+     """
+     Allow the scheduler to switch to another task, but without checking for
+     cancellation.
+
+     Equivalent to (but potentially more efficient than)::
+
+         with CancelScope(shield=True):
+             await checkpoint()
+
+     .. versionadded:: 3.0
+
+     """
+     await get_async_backend().cancel_shielded_checkpoint()
+
+
+ @final
+ @dataclass(frozen=True, repr=False)
+ class EventLoopToken:
+     """
+     An opaque object that holds a reference to an event loop.
+
+     .. versionadded:: 4.11.0
+     """
+
+     backend_class: type[AsyncBackend]
+     native_token: object
+
+
+ def current_token() -> EventLoopToken:
+     """
+     Return a token object that can be used to call code in the current event loop
+     from another thread.
+
+     :raises NoEventLoopError: if no supported asynchronous event loop is running in
+         the current thread
+
+     .. versionadded:: 4.11.0
+
+     """
+     backend_class = get_async_backend()
+     raw_token = backend_class.current_token()
+     return EventLoopToken(backend_class, raw_token)
+
+
+ _run_vars: WeakKeyDictionary[object, dict[RunVar[Any], Any]] = WeakKeyDictionary()
+
+
+ class _NoValueSet(enum.Enum):
+     NO_VALUE_SET = enum.auto()
+
+
+ class RunvarToken(Generic[T]):
+     __slots__ = "_var", "_value", "_redeemed"
+
+     def __init__(self, var: RunVar[T], value: T | Literal[_NoValueSet.NO_VALUE_SET]):
+         self._var = var
+         self._value: T | Literal[_NoValueSet.NO_VALUE_SET] = value
+         self._redeemed = False
+
+     def __enter__(self) -> RunvarToken[T]:
+         return self
+
+     def __exit__(
+         self,
+         exc_type: type[BaseException] | None,
+         exc_val: BaseException | None,
+         exc_tb: TracebackType | None,
+     ) -> None:
+         self._var.reset(self)
+
+
+ class RunVar(Generic[T]):
+     """
+     Like a :class:`~contextvars.ContextVar`, except scoped to the running event loop.
+
+     Can be used as a context manager that, just like
+     :class:`~contextvars.ContextVar`, will reset the variable to its previous value
+     when the context block is exited.
+     """
+
+     __slots__ = "_name", "_default"
+
+     NO_VALUE_SET: Literal[_NoValueSet.NO_VALUE_SET] = _NoValueSet.NO_VALUE_SET
+
+     def __init__(
+         self, name: str, default: T | Literal[_NoValueSet.NO_VALUE_SET] = NO_VALUE_SET
+     ):
+         self._name = name
+         self._default = default
+
+     @property
+     def _current_vars(self) -> dict[RunVar[T], T]:
+         native_token = current_token().native_token
+         try:
+             return _run_vars[native_token]
+         except KeyError:
+             run_vars = _run_vars[native_token] = {}
+             return run_vars
+
+     @overload
+     def get(self, default: D) -> T | D: ...
+
+     @overload
+     def get(self) -> T: ...
+
+     def get(
+         self, default: D | Literal[_NoValueSet.NO_VALUE_SET] = NO_VALUE_SET
+     ) -> T | D:
+         try:
+             return self._current_vars[self]
+         except KeyError:
+             if default is not RunVar.NO_VALUE_SET:
+                 return default
+             elif self._default is not RunVar.NO_VALUE_SET:
+                 return self._default
+
+             raise LookupError(
+                 f'Run variable "{self._name}" has no value and no default set'
+             )
+
+     def set(self, value: T) -> RunvarToken[T]:
+         current_vars = self._current_vars
+         token = RunvarToken(self, current_vars.get(self, RunVar.NO_VALUE_SET))
+         current_vars[self] = value
+         return token
+
+     def reset(self, token: RunvarToken[T]) -> None:
+         if token._var is not self:
+             raise ValueError("This token does not belong to this RunVar")
+
+         if token._redeemed:
+             raise ValueError("This token has already been used")
+
+         if token._value is _NoValueSet.NO_VALUE_SET:
+             try:
+                 del self._current_vars[self]
+             except KeyError:
+                 pass
+         else:
+             self._current_vars[self] = token._value
+
+         token._redeemed = True
+
+     def __repr__(self) -> str:
+         return f"<RunVar name={self._name!r}>"
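The heart of `RunVar` is the token-based set/reset protocol: `set()` snapshots the previous value into a token, and `reset()` restores it (deleting the entry if nothing was set before). A toy, stdlib-only sketch of that protocol — scoped to a plain dict instead of a running event loop, with `ToyVar` and `_NO_VALUE` being illustrative names, not anyio API:

```python
from typing import Any

_NO_VALUE = object()  # sentinel, like _NoValueSet.NO_VALUE_SET above


class ToyVar:
    def __init__(self, name: str, default: Any = _NO_VALUE) -> None:
        self._name = name
        self._default = default
        self._store: dict[ToyVar, Any] = {}  # stand-in for the per-loop dict

    def get(self, default: Any = _NO_VALUE) -> Any:
        if self in self._store:
            return self._store[self]
        if default is not _NO_VALUE:
            return default
        if self._default is not _NO_VALUE:
            return self._default
        raise LookupError(f"{self._name!r} has no value and no default set")

    def set(self, value: Any) -> tuple["ToyVar", Any]:
        # Snapshot the old value into the token before overwriting
        token = (self, self._store.get(self, _NO_VALUE))
        self._store[self] = value
        return token

    def reset(self, token: tuple["ToyVar", Any]) -> None:
        var, old = token
        if var is not self:
            raise ValueError("token does not belong to this variable")
        if old is _NO_VALUE:
            self._store.pop(self, None)  # there was no prior value: delete
        else:
            self._store[self] = old      # otherwise restore the snapshot


v = ToyVar("flag", default="off")
token = v.set("on")
assert v.get() == "on"
v.reset(token)  # back to the default, since nothing was set before
assert v.get() == "off"
```

The real class adds the pieces the toy omits: keying the store by the event loop's native token (via a `WeakKeyDictionary`, so loop teardown drops the values) and rejecting already-redeemed tokens.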
venv/lib/python3.10/site-packages/anyio/py.typed ADDED
File without changes
venv/lib/python3.10/site-packages/anyio/pytest_plugin.py ADDED
@@ -0,0 +1,302 @@
+ from __future__ import annotations
+
+ import socket
+ import sys
+ from collections.abc import Callable, Generator, Iterator
+ from contextlib import ExitStack, contextmanager
+ from inspect import isasyncgenfunction, iscoroutinefunction, ismethod
+ from typing import Any, cast
+
+ import pytest
+ from _pytest.fixtures import SubRequest
+ from _pytest.outcomes import Exit
+
+ from . import get_available_backends
+ from ._core._eventloop import (
+     current_async_library,
+     get_async_backend,
+     reset_current_async_library,
+     set_current_async_library,
+ )
+ from ._core._exceptions import iterate_exceptions
+ from .abc import TestRunner
+
+ if sys.version_info < (3, 11):
+     from exceptiongroup import ExceptionGroup
+
+ _current_runner: TestRunner | None = None
+ _runner_stack: ExitStack | None = None
+ _runner_leases = 0
+
+
+ def extract_backend_and_options(backend: object) -> tuple[str, dict[str, Any]]:
+     if isinstance(backend, str):
+         return backend, {}
+     elif isinstance(backend, tuple) and len(backend) == 2:
+         if isinstance(backend[0], str) and isinstance(backend[1], dict):
+             return cast(tuple[str, dict[str, Any]], backend)
+
+     raise TypeError("anyio_backend must be either a string or tuple of (string, dict)")
+
+
+ @contextmanager
+ def get_runner(
+     backend_name: str, backend_options: dict[str, Any]
+ ) -> Iterator[TestRunner]:
+     global _current_runner, _runner_leases, _runner_stack
+     if _current_runner is None:
+         asynclib = get_async_backend(backend_name)
+         _runner_stack = ExitStack()
+         if current_async_library() is None:
+             # Since we're in control of the event loop, we can cache the name of the
+             # async library
+             token = set_current_async_library(backend_name)
+             _runner_stack.callback(reset_current_async_library, token)
+
+         backend_options = backend_options or {}
+         _current_runner = _runner_stack.enter_context(
+             asynclib.create_test_runner(backend_options)
+         )
+
+     _runner_leases += 1
+     try:
+         yield _current_runner
+     finally:
+         _runner_leases -= 1
+         if not _runner_leases:
+             assert _runner_stack is not None
+             _runner_stack.close()
+             _runner_stack = _current_runner = None
+
+
+ def pytest_addoption(parser: pytest.Parser) -> None:
+     parser.addini(
+         "anyio_mode",
+         default="strict",
+         help='AnyIO plugin mode (either "strict" or "auto")',
+     )
+
+
+ def pytest_configure(config: pytest.Config) -> None:
+     config.addinivalue_line(
+         "markers",
+         "anyio: mark the (coroutine function) test to be run asynchronously via anyio.",
+     )
+     if (
+         config.getini("anyio_mode") == "auto"
+         and config.pluginmanager.has_plugin("asyncio")
+         and config.getini("asyncio_mode") == "auto"
+     ):
+         config.issue_config_time_warning(
+             pytest.PytestConfigWarning(
+                 "AnyIO auto mode has been enabled together with pytest-asyncio auto "
+                 "mode. This may cause unexpected behavior."
+             ),
+             1,
+         )
+
+
+ @pytest.hookimpl(hookwrapper=True)
+ def pytest_fixture_setup(fixturedef: Any, request: Any) -> Generator[Any]:
+     def wrapper(anyio_backend: Any, request: SubRequest, **kwargs: Any) -> Any:
+         # Rebind any fixture methods to the request instance
+         if (
+             request.instance
+             and ismethod(func)
+             and type(func.__self__) is type(request.instance)
+         ):
+             local_func = func.__func__.__get__(request.instance)
+         else:
+             local_func = func
+
+         backend_name, backend_options = extract_backend_and_options(anyio_backend)
+         if has_backend_arg:
+             kwargs["anyio_backend"] = anyio_backend
+
+         if has_request_arg:
+             kwargs["request"] = request
+
+         with get_runner(backend_name, backend_options) as runner:
+             if isasyncgenfunction(local_func):
+                 yield from runner.run_asyncgen_fixture(local_func, kwargs)
+             else:
+                 yield runner.run_fixture(local_func, kwargs)
+
+     # Only apply this to coroutine functions and async generator functions in
+     # requests that involve the anyio_backend fixture
+     func = fixturedef.func
+     if isasyncgenfunction(func) or iscoroutinefunction(func):
+         if "anyio_backend" in request.fixturenames:
+             fixturedef.func = wrapper
+             original_argname = fixturedef.argnames
+
+             if not (has_backend_arg := "anyio_backend" in fixturedef.argnames):
+                 fixturedef.argnames += ("anyio_backend",)
+
+             if not (has_request_arg := "request" in fixturedef.argnames):
+                 fixturedef.argnames += ("request",)
+
+             try:
+                 return (yield)
+             finally:
+                 fixturedef.func = func
+                 fixturedef.argnames = original_argname
+
+     return (yield)
+
+
+ @pytest.hookimpl(tryfirst=True)
+ def pytest_pycollect_makeitem(
+     collector: pytest.Module | pytest.Class, name: str, obj: object
+ ) -> None:
+     if collector.istestfunction(obj, name):
+         inner_func = obj.hypothesis.inner_test if hasattr(obj, "hypothesis") else obj
+         if iscoroutinefunction(inner_func):
+             anyio_auto_mode = collector.config.getini("anyio_mode") == "auto"
+             marker = collector.get_closest_marker("anyio")
+             own_markers = getattr(obj, "pytestmark", ())
+             if (
+                 anyio_auto_mode
+                 or marker
+                 or any(marker.name == "anyio" for marker in own_markers)
+             ):
+                 pytest.mark.usefixtures("anyio_backend")(obj)
+
+
+ @pytest.hookimpl(tryfirst=True)
+ def pytest_pyfunc_call(pyfuncitem: Any) -> bool | None:
+     def run_with_hypothesis(**kwargs: Any) -> None:
+         with get_runner(backend_name, backend_options) as runner:
+             runner.run_test(original_func, kwargs)
+
+     backend = pyfuncitem.funcargs.get("anyio_backend")
+     if backend:
+         backend_name, backend_options = extract_backend_and_options(backend)
+
+         if hasattr(pyfuncitem.obj, "hypothesis"):
+             # Wrap the inner test function unless it's already wrapped
+             original_func = pyfuncitem.obj.hypothesis.inner_test
+             if original_func.__qualname__ != run_with_hypothesis.__qualname__:
+                 if iscoroutinefunction(original_func):
+                     pyfuncitem.obj.hypothesis.inner_test = run_with_hypothesis
+
+             return None
+
+         if iscoroutinefunction(pyfuncitem.obj):
+             funcargs = pyfuncitem.funcargs
+             testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
+             with get_runner(backend_name, backend_options) as runner:
+                 try:
+                     runner.run_test(pyfuncitem.obj, testargs)
+                 except ExceptionGroup as excgrp:
+                     for exc in iterate_exceptions(excgrp):
+                         if isinstance(exc, (Exit, KeyboardInterrupt, SystemExit)):
+                             raise exc from excgrp
+
+                     raise
+
+             return True
+
+     return None
+
+
+ @pytest.fixture(scope="module", params=get_available_backends())
+ def anyio_backend(request: Any) -> Any:
+     return request.param
+
+
+ @pytest.fixture
+ def anyio_backend_name(anyio_backend: Any) -> str:
+     if isinstance(anyio_backend, str):
+         return anyio_backend
+     else:
+         return anyio_backend[0]
+
+
+ @pytest.fixture
+ def anyio_backend_options(anyio_backend: Any) -> dict[str, Any]:
+     if isinstance(anyio_backend, str):
+         return {}
+     else:
+         return anyio_backend[1]
+
+
+ class FreePortFactory:
+     """
+     Manages port generation based on specified socket kind, ensuring no duplicate
+     ports are generated.
+
+     This class provides functionality for generating available free ports on the
+     system. It is initialized with a specific socket kind and can generate ports
+     for given address families while avoiding reuse of previously generated ports.
+
+     Users should not instantiate this class directly, but use the
+     ``free_tcp_port_factory`` and ``free_udp_port_factory`` fixtures instead. For
+     simple use cases, ``free_tcp_port`` and ``free_udp_port`` can be used instead.
+     """
+
+     def __init__(self, kind: socket.SocketKind) -> None:
+         self._kind = kind
+         self._generated = set[int]()
+
+     @property
+     def kind(self) -> socket.SocketKind:
+         """
+         The type of socket connection (e.g., :data:`~socket.SOCK_STREAM` or
+         :data:`~socket.SOCK_DGRAM`) used to bind for checking port availability
+
+         """
+         return self._kind
+
+     def __call__(self, family: socket.AddressFamily | None = None) -> int:
+         """
+         Return an unbound port for the given address family.
+
+         :param family: if omitted, both IPv4 and IPv6 addresses will be tried
+         :return: a port number
+
+         """
+         if family is not None:
+             families = [family]
+         else:
+             families = [socket.AF_INET]
+             if socket.has_ipv6:
+                 families.append(socket.AF_INET6)
+
+         while True:
+             port = 0
+             with ExitStack() as stack:
+                 for family in families:
+                     sock = stack.enter_context(socket.socket(family, self._kind))
+                     addr = "::1" if family == socket.AF_INET6 else "127.0.0.1"
+                     try:
+                         sock.bind((addr, port))
+                     except OSError:
+                         break
+
+                     if not port:
+                         port = sock.getsockname()[1]
+                 else:
+                     if port not in self._generated:
+                         self._generated.add(port)
+                         return port
+
+
+ @pytest.fixture(scope="session")
+ def free_tcp_port_factory() -> FreePortFactory:
+     return FreePortFactory(socket.SOCK_STREAM)
+
+
+ @pytest.fixture(scope="session")
+ def free_udp_port_factory() -> FreePortFactory:
+     return FreePortFactory(socket.SOCK_DGRAM)
+
+
+ @pytest.fixture
+ def free_tcp_port(free_tcp_port_factory: Callable[[], int]) -> int:
+     return free_tcp_port_factory()
+
+
+ @pytest.fixture
+ def free_udp_port(free_udp_port_factory: Callable[[], int]) -> int:
+     return free_udp_port_factory()
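`FreePortFactory.__call__` above relies on a standard OS trick: binding a socket to port 0 asks the kernel to assign any free port, which `getsockname()` then reports back. Stripped of the IPv4/IPv6 double bind and the duplicate-port bookkeeping, the core of it looks like this (the helper name `free_tcp_port` here is illustrative, not the fixture itself):

```python
import socket


def free_tcp_port() -> int:
    # Bind to port 0 and let the OS pick a free port, then read the
    # assigned port back from the socket before closing it.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("127.0.0.1", 0))
        return sock.getsockname()[1]


port = free_tcp_port()
assert 1024 <= port <= 65535
```

Note the inherent race: the port is free at the moment of the check but could be taken by another process before the test binds it, which is why the factory also tracks ports it has already handed out.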
venv/lib/python3.10/site-packages/anyio/to_interpreter.py ADDED
@@ -0,0 +1,246 @@
+ from __future__ import annotations
+
+ __all__ = (
+     "run_sync",
+     "current_default_interpreter_limiter",
+ )
+
+ import atexit
+ import os
+ import sys
+ from collections import deque
+ from collections.abc import Callable
+ from typing import Any, Final, TypeVar
+
+ from . import current_time, to_thread
+ from ._core._exceptions import BrokenWorkerInterpreter
+ from ._core._synchronization import CapacityLimiter
+ from .lowlevel import RunVar
+
+ if sys.version_info >= (3, 11):
+     from typing import TypeVarTuple, Unpack
+ else:
+     from typing_extensions import TypeVarTuple, Unpack
+
+ if sys.version_info >= (3, 14):
+     from concurrent.interpreters import ExecutionFailed, create
+
+     def _interp_call(
+         func: Callable[..., Any], args: tuple[Any, ...]
+     ) -> tuple[Any, bool]:
+         try:
+             retval = func(*args)
+         except BaseException as exc:
+             return exc, True
+         else:
+             return retval, False
+
+     class _Worker:
+         last_used: float = 0
+
+         def __init__(self) -> None:
+             self._interpreter = create()
+
+         def destroy(self) -> None:
+             self._interpreter.close()
+
+         def call(
+             self,
+             func: Callable[..., T_Retval],
+             args: tuple[Any, ...],
+         ) -> T_Retval:
+             try:
+                 res, is_exception = self._interpreter.call(_interp_call, func, args)
+             except ExecutionFailed as exc:
+                 raise BrokenWorkerInterpreter(exc.excinfo) from exc
+
+             if is_exception:
+                 raise res
+
+             return res
+ elif sys.version_info >= (3, 13):
+     import _interpqueues
+     import _interpreters
+
+     UNBOUND: Final = 2  # I have no clue how this works, but it was used in the stdlib
+     FMT_UNPICKLED: Final = 0
+     FMT_PICKLED: Final = 1
+     QUEUE_PICKLE_ARGS: Final = (FMT_PICKLED, UNBOUND)
+     QUEUE_UNPICKLE_ARGS: Final = (FMT_UNPICKLED, UNBOUND)
+
+     _run_func = compile(
+         """
+ import _interpqueues
+ from _interpreters import NotShareableError
+ from pickle import loads, dumps, HIGHEST_PROTOCOL
+
+ QUEUE_PICKLE_ARGS = (1, 2)
+ QUEUE_UNPICKLE_ARGS = (0, 2)
+
+ item = _interpqueues.get(queue_id)[0]
+ try:
+     func, args = loads(item)
+     retval = func(*args)
+ except BaseException as exc:
+     is_exception = True
+     retval = exc
+ else:
+     is_exception = False
+
+ try:
+     _interpqueues.put(queue_id, (retval, is_exception), *QUEUE_UNPICKLE_ARGS)
+ except NotShareableError:
+     retval = dumps(retval, HIGHEST_PROTOCOL)
+     _interpqueues.put(queue_id, (retval, is_exception), *QUEUE_PICKLE_ARGS)
+ """,
+         "<string>",
+         "exec",
+     )
+
+     class _Worker:
+         last_used: float = 0
+
+         def __init__(self) -> None:
+             self._interpreter_id = _interpreters.create()
+             self._queue_id = _interpqueues.create(1, *QUEUE_UNPICKLE_ARGS)
+             _interpreters.set___main___attrs(
+                 self._interpreter_id, {"queue_id": self._queue_id}
+             )
+
+         def destroy(self) -> None:
+             _interpqueues.destroy(self._queue_id)
+             _interpreters.destroy(self._interpreter_id)
+
+         def call(
+             self,
+             func: Callable[..., T_Retval],
+             args: tuple[Any, ...],
+         ) -> T_Retval:
+             import pickle
+
+             item = pickle.dumps((func, args), pickle.HIGHEST_PROTOCOL)
+             _interpqueues.put(self._queue_id, item, *QUEUE_PICKLE_ARGS)
+             exc_info = _interpreters.exec(self._interpreter_id, _run_func)
+             if exc_info:
+                 raise BrokenWorkerInterpreter(exc_info)
+
+             res = _interpqueues.get(self._queue_id)
+             (res, is_exception), fmt = res[:2]
+             if fmt == FMT_PICKLED:
+                 res = pickle.loads(res)
+
+             if is_exception:
+                 raise res
+
+             return res
+ else:
+
+     class _Worker:
+         last_used: float = 0
+
+         def __init__(self) -> None:
+             raise RuntimeError("subinterpreters require at least Python 3.13")
+
+         def call(
+             self,
+             func: Callable[..., T_Retval],
+             args: tuple[Any, ...],
+         ) -> T_Retval:
+             raise NotImplementedError
+
+         def destroy(self) -> None:
+             pass
+
+
+ DEFAULT_CPU_COUNT: Final = 8  # this is just an arbitrarily selected value
+ MAX_WORKER_IDLE_TIME = (
+     30  # seconds a subinterpreter can be idle before becoming eligible for pruning
+ )
+
+ T_Retval = TypeVar("T_Retval")
+ PosArgsT = TypeVarTuple("PosArgsT")
+
+ _idle_workers = RunVar[deque[_Worker]]("_available_workers")
+ _default_interpreter_limiter = RunVar[CapacityLimiter]("_default_interpreter_limiter")
+
+
+ def _stop_workers(workers: deque[_Worker]) -> None:
+     for worker in workers:
+         worker.destroy()
+
+     workers.clear()
+
+
+ async def run_sync(
+     func: Callable[[Unpack[PosArgsT]], T_Retval],
+     *args: Unpack[PosArgsT],
+     limiter: CapacityLimiter | None = None,
+ ) -> T_Retval:
+     """
+     Call the given function with the given arguments in a subinterpreter.
+
+     .. warning:: On Python 3.13, the :mod:`concurrent.interpreters` module was not
+         yet available, so the code path for that Python version relies on an
+         undocumented, private API. As such, it is recommended to not rely on this
+         function for anything mission-critical on Python 3.13.
+
+     :param func: a callable
+     :param args: the positional arguments for the callable
+     :param limiter: capacity limiter to use to limit the total number of
+         subinterpreters running (if omitted, the default limiter is used)
+     :return: the result of the call
+     :raises BrokenWorkerInterpreter: if there's an internal error in a subinterpreter
+
+     """
+     if limiter is None:
+         limiter = current_default_interpreter_limiter()
+
+     try:
+         idle_workers = _idle_workers.get()
+     except LookupError:
+         idle_workers = deque()
+         _idle_workers.set(idle_workers)
+         atexit.register(_stop_workers, idle_workers)
+
+     async with limiter:
+         try:
+             worker = idle_workers.pop()
+         except IndexError:
+             worker = _Worker()
+
+         try:
+             return await to_thread.run_sync(
+                 worker.call,
+                 func,
+                 args,
+                 limiter=limiter,
+             )
+         finally:
+             # Prune workers that have been idle for too long
+             now = current_time()
+             while idle_workers:
+                 if now - idle_workers[0].last_used <= MAX_WORKER_IDLE_TIME:
+                     break
+
+                 await to_thread.run_sync(
+                     idle_workers.popleft().destroy, limiter=limiter
+                 )
+
+             worker.last_used = current_time()
+             idle_workers.append(worker)
+
+
+ def current_default_interpreter_limiter() -> CapacityLimiter:
+     """
+     Return the capacity limiter used by default to limit the number of concurrently
+     running subinterpreters.
+
+     Defaults to the number of CPU cores.
+
+     :return: a capacity limiter object
+
+     """
+     try:
+         return _default_interpreter_limiter.get()
+     except LookupError:
+         limiter = CapacityLimiter(os.cpu_count() or DEFAULT_CPU_COUNT)
+         _default_interpreter_limiter.set(limiter)
+         return limiter
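Both worker code paths above use the same wire protocol: the caller sends a pickled `(func, args)` tuple, and the worker replies with a `(retval, is_exception)` pair so exceptions can be re-raised on the caller's side. That round-trip can be sketched in isolation, with no subinterpreters involved (`run_pickled` is an illustrative stand-in for the worker side, not anyio API):

```python
import pickle


def run_pickled(payload: bytes) -> bytes:
    # Worker side: unpickle (func, args), call it, and reply with
    # (retval, is_exception) so the caller can re-raise failures.
    func, args = pickle.loads(payload)
    try:
        retval = func(*args)
    except BaseException as exc:
        result = (exc, True)
    else:
        result = (retval, False)
    return pickle.dumps(result, pickle.HIGHEST_PROTOCOL)


# Caller side: pickle the call, "send" it, then unpack the reply.
reply = run_pickled(pickle.dumps((divmod, (7, 3)), pickle.HIGHEST_PROTOCOL))
retval, is_exception = pickle.loads(reply)
assert not is_exception and retval == (2, 1)
```

The real 3.13 path adds one refinement: it first tries to put the result on the interpreter queue unpickled, and only falls back to pickling when the object is not shareable across interpreters.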
venv/lib/python3.10/site-packages/typing_extensions.py ADDED
The diff for this file is too large to render. See raw diff
 
venv/pyvenv.cfg ADDED
@@ -0,0 +1,3 @@
+ home = /usr/bin
+ include-system-site-packages = false
+ version = 3.10.12