Add Flux2 Klein compiled caches
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +14 -0
- meta.json +44 -0
- torchinductor/2h/a581feca05a976cd76073f2f954a7641097b9c5775b12cf6831b3149d528a8b4.best_config +1 -0
- torchinductor/2h/c2hij3hmloumxdmhuezsyhkmnqgnfa5ivre27uosymam3dr7a5xb.py +70 -0
- torchinductor/2o/c2oduffhka4c52657rppatcdtgtnibm42qywfo2spmul2dpsj6jj.py +297 -0
- torchinductor/3i/c3imyfibcq3zwrc5gvfscvpsaxdzjijymrnt2lfahtrjmrbtlmhe.py +45 -0
- torchinductor/3i/cf1587a2fd240ce39177274973308f6fd100d746bf6716a8d96ed4fd12c89d55.best_config +1 -0
- torchinductor/3v/4a00da1b5d4ce251d2cb392c24118fc2e6c3818f25b8457665f0d53e12234277.best_config +1 -0
- torchinductor/3v/c3vjilvcy7sdqcaxfspffadbsry3sqm4j7qp6u3tusr6p34sbiar.py +28 -0
- torchinductor/4y/c4ykjyk6fv6enet6mgkj5bsan42tc6rsdfs7aaskpjgv5rzw7tbr.py +357 -0
- torchinductor/6k/abd9e26dfce6bf628201c09f1f90f4340fdaab3cc2dd99f7186afe82fe013d1a.best_config +1 -0
- torchinductor/6k/c6kat5g7n3uukkfwgdxfwtmxblcmqzhbifo5g3rbaz3djxii35gi.py +30 -0
- torchinductor/6w/4fb0f9adeff50e9452e8fd238a1808052c095c59a0b2f1d9f3f7d7106bd1ede5.best_config +1 -0
- torchinductor/6w/c6w6sg4v3bcighwokzq6i43tl5xk5vz7zevuk7jhvlqyryyuhueo.py +33 -0
- torchinductor/7f/be95397d0c18f43f4314e0cac66d456d9d3e2b12116963a4bf988016e97f7a5e.best_config +1 -0
- torchinductor/7f/c7ff4ib6652ojllutm4c7mkzzpybond3pagu3glspw3sztkfe2za.py +30 -0
- torchinductor/a3/94dc88253134d772dc28ed260760d9a0059b054d472700be3c22dd06b228f22f.best_config +1 -0
- torchinductor/a3/ca3menlfuldthgmncfpjk452xkros7idrmil6pcoeigraymcg4e6.py +78 -0
- torchinductor/aotautograd/a27rkqg32yfaub3aygtms2gl3oet2qxfcnp4zxa3zy5h6c3risxz/aw5eda3h36wpnnltujgkb4mvobznersd4fuvo2p7vy2quujasos +0 -0
- torchinductor/aotautograd/a3443o3ywoehrda4trn5q47mauudwcinftvd52hitdnfmakyhqc4/lw6yvpbd45y77sg6fh5v4otbinchkwbf7b56u3rh3wgq3x2wkhq +0 -0
- torchinductor/aotautograd/a3554ihbxq57jan4ib74iqo5mnaqevqume4yzewukzkm6ehpsilz/eubahghkef62rmchvnle5v6h3ddip4av5qqjxomdlm7ura45qve +0 -0
- torchinductor/aotautograd/a3hojixb5fzn7f7jfco3ddoohdsuggk4qbop3lcg7rjy3e7fkgfz/o7wvolbgborwtoofbovayor23y4ubooymfcvv6jeqm2wbx3n2cs +0 -0
- torchinductor/aotautograd/a54twb2qknddjxnxtmkoagy3umo5y3ptsesm2pdhy7nkefklf6wx/emxzj524wmpvifsxw4dsnnkzemqpzfgkenbo5obwmksvlhsr354 +0 -0
- torchinductor/aotautograd/a5ksywxhfabbequvxwstheyyj5w3sinuubxcrypqjwbqsyw5la3l/ew4fxjyfoflznyuws2w2ylu4p7owpjuqoshsef75w43w2vvwejd +0 -0
- torchinductor/aotautograd/a7ptufzlocphh5n5o5u63gfzkf74tjb3l5is45u5hqjspv32qda6/an4kgppgf4vt5yfvvghrnmho6jc3qnj4l6c75zrsiotr5d4u5gv +0 -0
- torchinductor/aotautograd/aal6kceyfi7eazavxzpgcec5hzt32bkwo7p4doeyc56ubzlwuvx4/nkoni3ckgbheucucq64bmrta4lhz7x237lalaqcrejvdc3supg4 +0 -0
- torchinductor/aotautograd/aan5kpy6i54rnpeu5vlzbx6i6blimsvhducl7futzdjr4xciy472/a35s4usnkzmh6ybhedo3b6zehfepmwdv2gxscayjeeuucr3zat7 +0 -0
- torchinductor/aotautograd/aesonb7djseswkbtu2qzhvg6ikd5rewxnqlt6pwuytadpxxmjcod/lap2sypphhofd6d5rhojruk2vfyvw2olc7gtulmom4i5y7ix2cp +0 -0
- torchinductor/aotautograd/age65c4dyk2rxcqufpxd6bsafzao7tacrsvejbf3pjbsngnoashv/upzttal3jaj233iyzyps7mjpq75jt6qi6rzramvgyyewfg76h6s +0 -0
- torchinductor/aotautograd/ahkpwjcp2qqyj6wu2ckjqlrit2pbb3ig3ddi75hgbkgngvvipwyq/ha76p7wv3nimmrgvx6kdiqikd6adbw7nlnaiars5ey4anx46mwn +0 -0
- torchinductor/aotautograd/aiojzczi5txclvaydkrk5g3qlf33pdkkhxtefkhfphkpc3o6rr4p/w3n37k3qhqfhuewneurnairyblp3h7nrak6oyp2p3um7uwnfcz5 +0 -0
- torchinductor/aotautograd/ajdkg3gacw25klanvqotc3mkab3mi23jtjpagxrosdmqv3d4yg7v/ejzrqbsrchqzxfppkzo4ep7edhv7lrjjbcdxkxvodbk4vvk3b62 +0 -0
- torchinductor/aotautograd/amb262dx57ptj6gg2ch6skr372w6arsr3i7i4ed5pljhiycuxduw/fntav2w4z5lvr443jxseqalau2vuzp7x7ljd3hanoqubtutjkvp +0 -0
- torchinductor/aotautograd/amjjivi2p6firai3idkjgfxyy6z4prevujsjdno2uuchwvd7xqll/enc6ruqcyggs4mnt54tjdd2lvexcvipd5vhhamxwcj77g5fpyof +0 -0
- torchinductor/aotautograd/apfaqlwe555qd2zoz575w5mvoxoiasmcomkv76mhz5zvnm5jok66/epmli5r46rzrqf73pqrnb5tratdg3mbbwdf5vyzqr6ejyhnooye +0 -0
- torchinductor/aotautograd/asjbg7f735jw54kcldmvv5uost22wzpy3hkxgaihos4rllvagheu/lwqpsnp52rszp2nlwkgi33embno5st2u5bxfm4rpyoy6fql5aor +0 -0
- torchinductor/aotautograd/atc2ggqhejcse5aydwh2wjakijsc2dyhqjxwdqrwpra3mgjwe4st/xwy7lzraqocjillvk4s2yc2qhpkx43s2nbkxmeb2wpph3sgyc7n +0 -0
- torchinductor/aotautograd/atsevoi6zqdcnehuxassvjosi3j5vrk54uisibylfgspeewp6vyx/4sfzv7d6ch2yoi6nnr5ym3i6yibku3vfveyrr6sx6dqbmavxo32 +0 -0
- torchinductor/aotautograd/ax7bbwqbruobasu7vagn2oj2owh5vgosxbjelta324rvf4tkesd4/ipnutob47ydixp2zetluyw4apg7fe5sfkkiianwaawh6yq3uang +0 -0
- torchinductor/aotautograd/ay26zyuzpll2prvy7zzoeydo7r47lrr6s6jcmzi2zmytjxzebmnz/nzx7lukg3r25p6sjlwtqmkf6gmgzuq7iwagwki2x4kvhw5ducr5 +0 -0
- torchinductor/aotautograd/ay65riayezoo7bqggl72pzrzdi6lvy5mp23ajx4f453ylzpmve3s/p7clvcke3bsgsaumutstrxc7bkq4tq6yoia7nwigana3n3unini +0 -0
- torchinductor/aotautograd/azyih32olvhzuay5zpfypzhk2cdlosvaqxdhcnjzlwfs6k3a2ne6/5sz2kjdze7ixdny7hz24p4uma7uup7chdcpiumqznifqn4mpmqb +0 -0
- torchinductor/av/cavoaz6e7kbk5wq2n7vz6rxhcrwdu2trazexubdq5qwyv2ajmbkz.py +73 -0
- torchinductor/av/d186a24d3c8af5514b42dea48fc981efd3f5afb7bba6c30406e42c75862888b1.best_config +1 -0
- torchinductor/ay/cayicsdjyjxzpcmkvjbneubnqkuhs3y37qiwy5qlel3z2loa4qav.py +69 -0
- torchinductor/bv/7969eba2eb589b95d2894ee75ee67ba01cd2bee09cd64d315c70c0950888c19e.best_config +1 -0
- torchinductor/bv/cbvqhjtyg7fvxzwtbtt4vrdkbnb6n32fnrijjpl3vv4cfqd4mznr.py +162 -0
- torchinductor/cr/ccr2gijy4jp6vvdbewmzgaogxbf5as7ytxtou4zo2yelawomrjjg.py +131 -0
- torchinductor/cz/bb6645c6be31f426023ec47eef09e354ad9fa8b2d59e6e45ab49b803eb34d44e.best_config +1 -0
- torchinductor/cz/cczg7tpituprwgqpuajzy2nylfk43mdozd5vwo77muq3kospnf7b.py +25 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,17 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
torchinductor/fxgraph/22/f22et4hxfdbezzlil53lu2pcyq6hgd3sgpjtfwxqwvo3nhcfcvx3/iqgyrwyxmlbu22glkzz24rsbjiieww3sllxux2ttk4c6gtoiuba filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
torchinductor/fxgraph/6v/f6v2dym5xl4l4b2xlv35ic4ajld4mxcvhsvdsiwx2uug77q36cad/nhugjvtrt6cm53zizhnw673hj52m5m3kegaqvkjzkf4qh6d6rhm filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
torchinductor/fxgraph/ah/fahbtdmoejcqs352pnbnedqns63nbnu6hdbrwzvf6chptnsannjh/fhpbiokcxxh7ksbfgiljcvh7erywotuv4ddvlfcb4fk2ef7dd5c filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
torchinductor/fxgraph/bn/fbnlruhvmagcngqd5is2xjbucjaq7uf3sgsbdahfi6ovtehbhzyo/62gizmmmqz43cclymnr7ftyo5qt7ux4og6bc2xazrwstkjpsy2e filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
torchinductor/fxgraph/e2/fe2tjoiexjbavh5sakfaxvga43vsvwn5ev5bzhfjg76jvmtjqtbn/ejg3u4qymaxsvvl2vdequli7pwsrdjf5zdgqjkgbrxdsvgfv3h4 filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
torchinductor/fxgraph/fr/ffrx7clryowwzulnhruopihutvaxlycymqopsyoha6yecifyw2m2/g3wh462wylribiwz4th3gnlt5rtnrcb5bkad6w3yucxodv3q5ks filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
torchinductor/fxgraph/k6/fk6cfyjfeiu7xe6ebkapsnixuplqczgfc5534mitqsfkssbzjyak/4xid4w6sg2yg7xaseouf2vwhp2fyff56a2t6z6ownb3yw3g25rk filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
torchinductor/fxgraph/kj/fkjh2kykxecmnv6oe3zzwtjpek77nmrm35vgv2daxfgkim6xfk4u/gigbnwpmixz5epksvvrh4mtg3nxlpy724eojb3ajzvtepgvx7y4 filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
torchinductor/fxgraph/n6/fn6x7m44e35jdmh6iqj3eqiyrz7tbhzd3rqartt67myyrnickjmp/um3sgirsxogup4murdiaoy7dxu4ogolqsa343kh56kq24zd53fb filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
torchinductor/fxgraph/te/fte5y7bccssideiluepvpscj6srf7orxnfgql6to32ni27zf2uv2/vmrpdvw3meiqsf22oras6imorrybogkgp6jjr3ddtreaiuutais filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
torchinductor/fxgraph/tz/ftzd5ordyehsowurkwjjpkso24gayhyplcd6wz7xdv53fad276l6/36luuy7klcmb7554z63umrezysn6xbat5wxfaadxh4clxhcn2j7 filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
torchinductor/fxgraph/u4/fu47dchf76mmiajgnawm3xgek4ysnnmqaavupgy4cddyxygid6iq/725pjrxppjygb6kbca6zklwmn7iunv65thw23b5s4im6zi27j3i filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
torchinductor/fxgraph/uy/fuygegwmldon4qz3wvjs3cld4hnjz6yxh6aa2cmfsal4u3xxws43/l4w2mroymid3qdffvzt4wffavpm5it6rzi6lmbvxzzezfkbavuo filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
torchinductor/fxgraph/w5/fw5vzdkweh3kv3fm3mnal4wu63gxhw2anwx2pzuved4acfz4fdzm/n5dfreyro3slkufuydw2d54nm7bfiskxjjasoyg2z5yept5c3rf filter=lfs diff=lfs merge=lfs -text
|
meta.json
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cache_layout_version": 1,
|
| 3 |
+
"created_at": "2026-01-23T07:13:39Z",
|
| 4 |
+
"model_path": "/root/.cache/huggingface/hub/models--black-forest-labs--FLUX.2-klein-9B/snapshots/cd1bba5810fe2aba6666d9cf7352e25436426039",
|
| 5 |
+
"compile_command": [
|
| 6 |
+
"/usr/bin/python",
|
| 7 |
+
"/app/tensorrt_llm/visual_gen/examples/flux2_klein_9b.py",
|
| 8 |
+
"--model_path",
|
| 9 |
+
"/root/.cache/huggingface/hub/models--black-forest-labs--FLUX.2-klein-9B/snapshots/cd1bba5810fe2aba6666d9cf7352e25436426039",
|
| 10 |
+
"--height",
|
| 11 |
+
"512",
|
| 12 |
+
"--width",
|
| 13 |
+
"1024",
|
| 14 |
+
"--num_inference_steps",
|
| 15 |
+
"4",
|
| 16 |
+
"--num_images",
|
| 17 |
+
"6",
|
| 18 |
+
"--linear_type",
|
| 19 |
+
"te-fp8-per-tensor",
|
| 20 |
+
"--fallback_linear_type",
|
| 21 |
+
"default",
|
| 22 |
+
"--torch_compile_mode",
|
| 23 |
+
"default",
|
| 24 |
+
"--offload_text_encoder"
|
| 25 |
+
],
|
| 26 |
+
"height": 512,
|
| 27 |
+
"width": 1024,
|
| 28 |
+
"num_inference_steps": 4,
|
| 29 |
+
"num_images": 6,
|
| 30 |
+
"linear_type": "te-fp8-per-tensor",
|
| 31 |
+
"fallback_linear_type": "default",
|
| 32 |
+
"torch_compile_mode": "default",
|
| 33 |
+
"offload_text_encoder": true,
|
| 34 |
+
"offload_vae": false,
|
| 35 |
+
"disable_cuda_graph": false,
|
| 36 |
+
"disable_teacache": false,
|
| 37 |
+
"torch_version": "2.10.0a0+b4e4ee81d3.nv25.12",
|
| 38 |
+
"cuda_version": "13.1",
|
| 39 |
+
"device_name": "NVIDIA GeForce RTX 4090",
|
| 40 |
+
"device_capability": [
|
| 41 |
+
8,
|
| 42 |
+
9
|
| 43 |
+
]
|
| 44 |
+
}
|
torchinductor/2h/a581feca05a976cd76073f2f954a7641097b9c5775b12cf6831b3149d528a8b4.best_config
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"XBLOCK": 64, "YBLOCK": 64, "num_warps": 8, "num_stages": 1, "configs_hash": "1ce421918d79ed0f7edb09d0ee64f016daf650a007a21866fe52d592be55380c", "found_by_coordesc": false, "time_taken_ms": 143, "triton_cache_hash": "RNNMPWWZPRYLZDDP3QNL7R5SV7EYTG7WXIUJKWKAEGE4BUI424IA"}
|
torchinductor/2h/c2hij3hmloumxdmhuezsyhkmnqgnfa5ivre27uosymam3dr7a5xb.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
|
| 5 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 6 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 7 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 8 |
+
triton_helpers.set_driver_to_gpu()
|
| 9 |
+
|
| 10 |
+
@triton_heuristics.pointwise(
|
| 11 |
+
size_hints={'y': 131072, 'x': 128}, tile_hint=TileHint.DEFAULT,
|
| 12 |
+
filename=__file__,
|
| 13 |
+
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*fp32', 'in_ptr5': '*bf16', 'out_ptr0': '*bf16', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=128, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
|
| 14 |
+
inductor_meta={'grid_type': 'Grid2DWithYZOverflow', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__fused_rms_norm_cat_view_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 6, 'num_store': 1, 'num_reduction': 0, 'backend_hash': '139C22A3A3C364569C9941DE9469DCB674B7A631E094782CBD415193800462F6', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 589824, 'x': 75497984}},
|
| 15 |
+
min_elem_per_thread=0
|
| 16 |
+
)
|
| 17 |
+
@triton.jit
|
| 18 |
+
def triton_poi_fused__fused_rms_norm_cat_view_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
|
| 19 |
+
ynumel = 73728
|
| 20 |
+
xnumel = 128
|
| 21 |
+
yoffset = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)) * YBLOCK
|
| 22 |
+
yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
|
| 23 |
+
ymask = yindex < ynumel
|
| 24 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 25 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
|
| 26 |
+
xmask = xindex < xnumel
|
| 27 |
+
y1 = yindex // 32
|
| 28 |
+
x2 = xindex
|
| 29 |
+
y0 = (yindex % 32)
|
| 30 |
+
y3 = yindex
|
| 31 |
+
tmp0 = y1
|
| 32 |
+
tmp1 = tl.full([1, 1], 0, tl.int64)
|
| 33 |
+
tmp2 = tmp0 >= tmp1
|
| 34 |
+
tmp3 = tl.full([1, 1], 256, tl.int64)
|
| 35 |
+
tmp4 = tmp0 < tmp3
|
| 36 |
+
tmp5 = tl.load(in_ptr0 + (x2 + 128*y0 + 12288*(y1)), tmp4 & xmask & ymask, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 37 |
+
tmp6 = tmp5.to(tl.float32)
|
| 38 |
+
tmp7 = tl.load(in_ptr1 + (tl.broadcast_to(y0 + 32*(y1), [YBLOCK, XBLOCK])), tmp4 & xmask & ymask, eviction_policy='evict_last', other=0.0)
|
| 39 |
+
tmp8 = 128.0
|
| 40 |
+
tmp9 = (tmp7 / tmp8)
|
| 41 |
+
tmp10 = 1e-06
|
| 42 |
+
tmp11 = tmp9 + tmp10
|
| 43 |
+
tmp12 = libdevice.rsqrt(tmp11)
|
| 44 |
+
tmp13 = tmp6 * tmp12
|
| 45 |
+
tmp14 = tl.load(in_ptr2 + (tl.broadcast_to(x2, [YBLOCK, XBLOCK])), tmp4 & xmask & ymask, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 46 |
+
tmp15 = tmp14.to(tl.float32)
|
| 47 |
+
tmp16 = tmp13 * tmp15
|
| 48 |
+
tmp17 = tmp16.to(tl.float32)
|
| 49 |
+
tmp18 = tl.full(tmp17.shape, 0.0, tmp17.dtype)
|
| 50 |
+
tmp19 = tl.where(tmp4, tmp17, tmp18)
|
| 51 |
+
tmp20 = tmp0 >= tmp3
|
| 52 |
+
tmp21 = tl.full([1, 1], 2304, tl.int64)
|
| 53 |
+
tmp22 = tmp0 < tmp21
|
| 54 |
+
tmp23 = tl.load(in_ptr3 + (x2 + 128*y0 + 12288*((-256) + y1)), tmp20 & xmask & ymask, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 55 |
+
tmp24 = tmp23.to(tl.float32)
|
| 56 |
+
tmp25 = tl.load(in_ptr4 + (tl.broadcast_to(y0 + 32*((-256) + y1), [YBLOCK, XBLOCK])), tmp20 & xmask & ymask, eviction_policy='evict_last', other=0.0)
|
| 57 |
+
tmp26 = 128.0
|
| 58 |
+
tmp27 = (tmp25 / tmp26)
|
| 59 |
+
tmp28 = 1e-06
|
| 60 |
+
tmp29 = tmp27 + tmp28
|
| 61 |
+
tmp30 = libdevice.rsqrt(tmp29)
|
| 62 |
+
tmp31 = tmp24 * tmp30
|
| 63 |
+
tmp32 = tl.load(in_ptr5 + (tl.broadcast_to(x2, [YBLOCK, XBLOCK])), tmp20 & xmask & ymask, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 64 |
+
tmp33 = tmp32.to(tl.float32)
|
| 65 |
+
tmp34 = tmp31 * tmp33
|
| 66 |
+
tmp35 = tmp34.to(tl.float32)
|
| 67 |
+
tmp36 = tl.full(tmp35.shape, 0.0, tmp35.dtype)
|
| 68 |
+
tmp37 = tl.where(tmp20, tmp35, tmp36)
|
| 69 |
+
tmp38 = tl.where(tmp4, tmp19, tmp37)
|
| 70 |
+
tl.store(out_ptr0 + (x2 + 128*y3), tmp38, xmask & ymask)
|
torchinductor/2o/c2oduffhka4c52657rppatcdtgtnibm42qywfo2spmul2dpsj6jj.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# AOT ID: ['0_inference']
|
| 2 |
+
from ctypes import c_void_p, c_long, c_int
|
| 3 |
+
import torch
|
| 4 |
+
import math
|
| 5 |
+
import random
|
| 6 |
+
import os
|
| 7 |
+
import tempfile
|
| 8 |
+
from math import inf, nan
|
| 9 |
+
from cmath import nanj
|
| 10 |
+
from torch._inductor.hooks import run_intermediate_hooks
|
| 11 |
+
from torch._inductor.utils import maybe_profile
|
| 12 |
+
from torch._inductor.codegen.memory_planning import _align as align
|
| 13 |
+
from torch import device, empty_strided
|
| 14 |
+
from torch._inductor.async_compile import AsyncCompile
|
| 15 |
+
from torch._inductor.select_algorithm import extern_kernels
|
| 16 |
+
import triton
|
| 17 |
+
import triton.language as tl
|
| 18 |
+
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
|
| 19 |
+
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
|
| 20 |
+
|
| 21 |
+
aten = torch.ops.aten
|
| 22 |
+
inductor_ops = torch.ops.inductor
|
| 23 |
+
_quantized = torch.ops._quantized
|
| 24 |
+
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
|
| 25 |
+
assert_alignment = torch._C._dynamo.guards.assert_alignment
|
| 26 |
+
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
|
| 27 |
+
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
|
| 28 |
+
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
|
| 29 |
+
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
|
| 30 |
+
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
|
| 31 |
+
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
|
| 32 |
+
alloc_from_pool = torch.ops.inductor._alloc_from_pool
|
| 33 |
+
async_compile = AsyncCompile()
|
| 34 |
+
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# kernel path: /app/tensorrt_llm/visual_gen/compiled_cache/flux2_klein_9b_NVIDIA_GeForce_RTX_4090_sm89_torch2.10.0a0_b4e4ee81d3.nv25.12_cuda13_1/torchinductor/ww/cwwizzjwmd4ajlubxpvxidjiy3ldv5eflwludgizcahvsp4i75s2.py
|
| 38 |
+
# Topologically Sorted Source Nodes: [norm_hidden_states, add, mul, norm_hidden_states_1], Original ATen: [aten.native_layer_norm, aten.add, aten.mul]
|
| 39 |
+
# Source node to ATen node mapping:
|
| 40 |
+
# add => add_1
|
| 41 |
+
# mul => mul_1
|
| 42 |
+
# norm_hidden_states => add, convert_element_type, convert_element_type_1, mul, rsqrt, sub, var_mean
|
| 43 |
+
# norm_hidden_states_1 => add_2
|
| 44 |
+
# Graph fragment:
|
| 45 |
+
# %arg0_1 : Tensor "bf16[1, 2048, 4096][8388608, 4096, 1]cuda:0" = PlaceHolder[target=arg0_1]
|
| 46 |
+
# %arg1_1 : Tensor "bf16[1, 1, 4096][24576, 24576, 1]cuda:0" = PlaceHolder[target=arg1_1]
|
| 47 |
+
# %getitem_1 : Tensor "f32[1, 2048, 1][2048, 1, 2048]cuda:0" = PlaceHolder[target=getitem_1]
|
| 48 |
+
# %buf1 : Tensor "f32[1, 2048, 1][2048, 1, 2048]cuda:0" = PlaceHolder[target=buf1]
|
| 49 |
+
# %arg2_1 : Tensor "bf16[1, 1, 4096][24576, 24576, 1]cuda:0" = PlaceHolder[target=arg2_1]
|
| 50 |
+
# %convert_element_type : Tensor "f32[1, 2048, 4096][8388608, 4096, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg0_1, torch.float32), kwargs = {})
|
| 51 |
+
# %var_mean : [num_users=2] = call_function[target=torch.ops.aten.var_mean.correction](args = (%convert_element_type, [2]), kwargs = {correction: 0, keepdim: True})
|
| 52 |
+
# %add_1 : Tensor "bf16[1, 1, 4096][4096, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg1_1, 1), kwargs = {})
|
| 53 |
+
# %sub : Tensor "f32[1, 2048, 4096][8388608, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type, %getitem_1), kwargs = {})
|
| 54 |
+
# %add : Tensor "f32[1, 2048, 1][2048, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%getitem, 1e-06), kwargs = {})
|
| 55 |
+
# %rsqrt : Tensor "f32[1, 2048, 1][2048, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add,), kwargs = {})
|
| 56 |
+
# %mul : Tensor "f32[1, 2048, 4096][8388608, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub, %rsqrt), kwargs = {})
|
| 57 |
+
# %convert_element_type_1 : Tensor "bf16[1, 2048, 4096][8388608, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {})
|
| 58 |
+
# %mul_1 : Tensor "bf16[1, 2048, 4096][8388608, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_1, %convert_element_type_1), kwargs = {})
|
| 59 |
+
# %add_2 : Tensor "bf16[1, 2048, 4096][8388608, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_1, %arg2_1), kwargs = {})
|
| 60 |
+
# return %getitem_1,%buf1,%add_2
|
| 61 |
+
triton_red_fused_add_mul_native_layer_norm_0 = async_compile.triton('triton_red_fused_add_mul_native_layer_norm_0', '''
|
| 62 |
+
import triton
|
| 63 |
+
import triton.language as tl
|
| 64 |
+
|
| 65 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 66 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 67 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 68 |
+
triton_helpers.set_driver_to_gpu()
|
| 69 |
+
|
| 70 |
+
@triton_heuristics.reduction(
|
| 71 |
+
size_hints={'x': 2048, 'r0_': 4096},
|
| 72 |
+
reduction_hint=ReductionHint.INNER,
|
| 73 |
+
filename=__file__,
|
| 74 |
+
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr2': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=128, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
|
| 75 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_add_mul_native_layer_norm_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 4, 'num_store': 1, 'num_reduction': 2, 'backend_hash': '139C22A3A3C364569C9941DE9469DCB674B7A631E094782CBD415193800462F6', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'add_persistent_rblock': True, 'tiling_scores': {'x': 0, 'r0_': 50348032}}
|
| 76 |
+
)
|
| 77 |
+
@triton.jit
|
| 78 |
+
def triton_red_fused_add_mul_native_layer_norm_0(in_ptr0, in_ptr1, in_ptr2, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
|
| 79 |
+
xnumel = 2048
|
| 80 |
+
r0_numel = 4096
|
| 81 |
+
rnumel = r0_numel
|
| 82 |
+
RBLOCK: tl.constexpr = R0_BLOCK
|
| 83 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 84 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
|
| 85 |
+
xmask = xindex < xnumel
|
| 86 |
+
r0_base = tl.arange(0, R0_BLOCK)[None, :]
|
| 87 |
+
rbase = r0_base
|
| 88 |
+
x0 = xindex
|
| 89 |
+
tmp3_mean = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
|
| 90 |
+
tmp3_m2 = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
|
| 91 |
+
tmp3_weight = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
|
| 92 |
+
for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
|
| 93 |
+
r0_index = r0_offset + r0_base
|
| 94 |
+
r0_mask = r0_index < r0_numel
|
| 95 |
+
roffset = r0_offset
|
| 96 |
+
rindex = r0_index
|
| 97 |
+
r0_1 = r0_index
|
| 98 |
+
tmp0 = tl.load(in_ptr0 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 99 |
+
tmp1 = tmp0.to(tl.float32)
|
| 100 |
+
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
|
| 101 |
+
tmp3_mean_next, tmp3_m2_next, tmp3_weight_next = triton_helpers.welford_reduce(
|
| 102 |
+
tmp2, tmp3_mean, tmp3_m2, tmp3_weight, roffset == 0
|
| 103 |
+
)
|
| 104 |
+
tmp3_mean = tl.where(r0_mask & xmask, tmp3_mean_next, tmp3_mean)
|
| 105 |
+
tmp3_m2 = tl.where(r0_mask & xmask, tmp3_m2_next, tmp3_m2)
|
| 106 |
+
tmp3_weight = tl.where(r0_mask & xmask, tmp3_weight_next, tmp3_weight)
|
| 107 |
+
tmp4, tmp5, tmp6 = triton_helpers.welford(tmp3_mean, tmp3_m2, tmp3_weight, 1)
|
| 108 |
+
tmp3 = tmp4[:, None]
|
| 109 |
+
tmp7 = tmp5[:, None]
|
| 110 |
+
tmp8 = tmp6[:, None]
|
| 111 |
+
for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
|
| 112 |
+
r0_index = r0_offset + r0_base
|
| 113 |
+
r0_mask = r0_index < r0_numel
|
| 114 |
+
roffset = r0_offset
|
| 115 |
+
rindex = r0_index
|
| 116 |
+
r0_1 = r0_index
|
| 117 |
+
tmp9 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 118 |
+
tmp12 = tl.load(in_ptr0 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
|
| 119 |
+
tmp23 = tl.load(in_ptr2 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 120 |
+
tmp10 = 1.0
|
| 121 |
+
tmp11 = tmp9 + tmp10
|
| 122 |
+
tmp13 = tmp12.to(tl.float32)
|
| 123 |
+
tmp14 = tmp13 - tmp3
|
| 124 |
+
tmp15 = 4096.0
|
| 125 |
+
tmp16 = (tmp7 / tmp15)
|
| 126 |
+
tmp17 = 1e-06
|
| 127 |
+
tmp18 = tmp16 + tmp17
|
| 128 |
+
tmp19 = libdevice.rsqrt(tmp18)
|
| 129 |
+
tmp20 = tmp14 * tmp19
|
| 130 |
+
tmp21 = tmp20.to(tl.float32)
|
| 131 |
+
tmp22 = tmp11 * tmp21
|
| 132 |
+
tmp24 = tmp22 + tmp23
|
| 133 |
+
tl.store(out_ptr2 + (r0_1 + 4096*x0), tmp24, r0_mask & xmask)
|
| 134 |
+
''', device_str='cuda')
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
# kernel path: /app/tensorrt_llm/visual_gen/compiled_cache/flux2_klein_9b_NVIDIA_GeForce_RTX_4090_sm89_torch2.10.0a0_b4e4ee81d3.nv25.12_cuda13_1/torchinductor/av/cavoaz6e7kbk5wq2n7vz6rxhcrwdu2trazexubdq5qwyv2ajmbkz.py
|
| 138 |
+
# Topologically Sorted Source Nodes: [norm_encoder_hidden_states, add_2, mul_1, norm_encoder_hidden_states_1], Original ATen: [aten.native_layer_norm, aten.add, aten.mul]
|
| 139 |
+
# Source node to ATen node mapping:
|
| 140 |
+
# add_2 => add_4
|
| 141 |
+
# mul_1 => mul_3
|
| 142 |
+
# norm_encoder_hidden_states => add_3, convert_element_type_2, convert_element_type_3, mul_2, rsqrt_1, sub_1, var_mean_1
|
| 143 |
+
# norm_encoder_hidden_states_1 => add_5
|
| 144 |
+
# Graph fragment:
|
| 145 |
+
# %arg3_1 : Tensor "bf16[1, 256, 4096][1048576, 4096, 1]cuda:0" = PlaceHolder[target=arg3_1]
|
| 146 |
+
# %arg4_1 : Tensor "bf16[1, 1, 4096][24576, 24576, 1]cuda:0" = PlaceHolder[target=arg4_1]
|
| 147 |
+
# %getitem_3 : Tensor "f32[1, 256, 1][256, 1, 256]cuda:0" = PlaceHolder[target=getitem_3]
|
| 148 |
+
# %buf4 : Tensor "f32[1, 256, 1][256, 1, 256]cuda:0" = PlaceHolder[target=buf4]
|
| 149 |
+
# %arg5_1 : Tensor "bf16[1, 1, 4096][24576, 24576, 1]cuda:0" = PlaceHolder[target=arg5_1]
|
| 150 |
+
# %convert_element_type_2 : Tensor "f32[1, 256, 4096][1048576, 4096, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg3_1, torch.float32), kwargs = {})
|
| 151 |
+
# %var_mean_1 : [num_users=2] = call_function[target=torch.ops.aten.var_mean.correction](args = (%convert_element_type_2, [2]), kwargs = {correction: 0, keepdim: True})
|
| 152 |
+
# %add_4 : Tensor "bf16[1, 1, 4096][4096, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg4_1, 1), kwargs = {})
|
| 153 |
+
# %sub_1 : Tensor "f32[1, 256, 4096][1048576, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type_2, %getitem_3), kwargs = {})
|
| 154 |
+
# %add_3 : Tensor "f32[1, 256, 1][256, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%getitem_2, 1e-06), kwargs = {})
|
| 155 |
+
# %rsqrt_1 : Tensor "f32[1, 256, 1][256, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_3,), kwargs = {})
|
| 156 |
+
# %mul_2 : Tensor "f32[1, 256, 4096][1048576, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_1, %rsqrt_1), kwargs = {})
|
| 157 |
+
# %convert_element_type_3 : Tensor "bf16[1, 256, 4096][1048576, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_2, torch.bfloat16), kwargs = {})
|
| 158 |
+
# %mul_3 : Tensor "bf16[1, 256, 4096][1048576, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_4, %convert_element_type_3), kwargs = {})
|
| 159 |
+
# %add_5 : Tensor "bf16[1, 256, 4096][1048576, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_3, %arg5_1), kwargs = {})
|
| 160 |
+
# return %getitem_3,%buf4,%add_5
|
| 161 |
+
triton_red_fused_add_mul_native_layer_norm_1 = async_compile.triton('triton_red_fused_add_mul_native_layer_norm_1', '''
|
| 162 |
+
import triton
|
| 163 |
+
import triton.language as tl
|
| 164 |
+
|
| 165 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 166 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 167 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 168 |
+
triton_helpers.set_driver_to_gpu()
|
| 169 |
+
|
| 170 |
+
@triton_heuristics.reduction(
|
| 171 |
+
size_hints={'x': 256, 'r0_': 4096},
|
| 172 |
+
reduction_hint=ReductionHint.INNER,
|
| 173 |
+
filename=__file__,
|
| 174 |
+
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr2': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=128, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
|
| 175 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_add_mul_native_layer_norm_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 4, 'num_store': 1, 'num_reduction': 2, 'backend_hash': '139C22A3A3C364569C9941DE9469DCB674B7A631E094782CBD415193800462F6', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'add_persistent_rblock': True, 'tiling_scores': {'x': 0, 'r0_': 6307840}}
|
| 176 |
+
)
|
| 177 |
+
@triton.jit
|
| 178 |
+
def triton_red_fused_add_mul_native_layer_norm_1(in_ptr0, in_ptr1, in_ptr2, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
|
| 179 |
+
xnumel = 256
|
| 180 |
+
r0_numel = 4096
|
| 181 |
+
rnumel = r0_numel
|
| 182 |
+
RBLOCK: tl.constexpr = R0_BLOCK
|
| 183 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 184 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
|
| 185 |
+
xmask = xindex < xnumel
|
| 186 |
+
r0_base = tl.arange(0, R0_BLOCK)[None, :]
|
| 187 |
+
rbase = r0_base
|
| 188 |
+
x0 = xindex
|
| 189 |
+
tmp3_mean = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
|
| 190 |
+
tmp3_m2 = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
|
| 191 |
+
tmp3_weight = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
|
| 192 |
+
for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
|
| 193 |
+
r0_index = r0_offset + r0_base
|
| 194 |
+
r0_mask = r0_index < r0_numel
|
| 195 |
+
roffset = r0_offset
|
| 196 |
+
rindex = r0_index
|
| 197 |
+
r0_1 = r0_index
|
| 198 |
+
tmp0 = tl.load(in_ptr0 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 199 |
+
tmp1 = tmp0.to(tl.float32)
|
| 200 |
+
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
|
| 201 |
+
tmp3_mean_next, tmp3_m2_next, tmp3_weight_next = triton_helpers.welford_reduce(
|
| 202 |
+
tmp2, tmp3_mean, tmp3_m2, tmp3_weight, roffset == 0
|
| 203 |
+
)
|
| 204 |
+
tmp3_mean = tl.where(r0_mask & xmask, tmp3_mean_next, tmp3_mean)
|
| 205 |
+
tmp3_m2 = tl.where(r0_mask & xmask, tmp3_m2_next, tmp3_m2)
|
| 206 |
+
tmp3_weight = tl.where(r0_mask & xmask, tmp3_weight_next, tmp3_weight)
|
| 207 |
+
tmp4, tmp5, tmp6 = triton_helpers.welford(tmp3_mean, tmp3_m2, tmp3_weight, 1)
|
| 208 |
+
tmp3 = tmp4[:, None]
|
| 209 |
+
tmp7 = tmp5[:, None]
|
| 210 |
+
tmp8 = tmp6[:, None]
|
| 211 |
+
for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
|
| 212 |
+
r0_index = r0_offset + r0_base
|
| 213 |
+
r0_mask = r0_index < r0_numel
|
| 214 |
+
roffset = r0_offset
|
| 215 |
+
rindex = r0_index
|
| 216 |
+
r0_1 = r0_index
|
| 217 |
+
tmp9 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 218 |
+
tmp12 = tl.load(in_ptr0 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
|
| 219 |
+
tmp23 = tl.load(in_ptr2 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 220 |
+
tmp10 = 1.0
|
| 221 |
+
tmp11 = tmp9 + tmp10
|
| 222 |
+
tmp13 = tmp12.to(tl.float32)
|
| 223 |
+
tmp14 = tmp13 - tmp3
|
| 224 |
+
tmp15 = 4096.0
|
| 225 |
+
tmp16 = (tmp7 / tmp15)
|
| 226 |
+
tmp17 = 1e-06
|
| 227 |
+
tmp18 = tmp16 + tmp17
|
| 228 |
+
tmp19 = libdevice.rsqrt(tmp18)
|
| 229 |
+
tmp20 = tmp14 * tmp19
|
| 230 |
+
tmp21 = tmp20.to(tl.float32)
|
| 231 |
+
tmp22 = tmp11 * tmp21
|
| 232 |
+
tmp24 = tmp22 + tmp23
|
| 233 |
+
tl.store(out_ptr2 + (r0_1 + 4096*x0), tmp24, r0_mask & xmask)
|
| 234 |
+
''', device_str='cuda')
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
async_compile.wait(globals())
|
| 238 |
+
del async_compile
|
| 239 |
+
|
| 240 |
+
class Runner:
|
| 241 |
+
def __init__(self, partitions):
|
| 242 |
+
self.partitions = partitions
|
| 243 |
+
|
| 244 |
+
def recursively_apply_fns(self, fns):
|
| 245 |
+
new_callables = []
|
| 246 |
+
for fn, c in zip(fns, self.partitions):
|
| 247 |
+
new_callables.append(fn(c))
|
| 248 |
+
self.partitions = new_callables
|
| 249 |
+
|
| 250 |
+
def call(self, args):
|
| 251 |
+
arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1 = args
|
| 252 |
+
args.clear()
|
| 253 |
+
assert_size_stride(arg0_1, (1, 2048, 4096), (8388608, 4096, 1))
|
| 254 |
+
assert_size_stride(arg1_1, (1, 1, 4096), (24576, 24576, 1))
|
| 255 |
+
assert_size_stride(arg2_1, (1, 1, 4096), (24576, 24576, 1))
|
| 256 |
+
assert_size_stride(arg3_1, (1, 256, 4096), (1048576, 4096, 1))
|
| 257 |
+
assert_size_stride(arg4_1, (1, 1, 4096), (24576, 24576, 1))
|
| 258 |
+
assert_size_stride(arg5_1, (1, 1, 4096), (24576, 24576, 1))
|
| 259 |
+
with torch.cuda._DeviceGuard(0):
|
| 260 |
+
torch.cuda.set_device(0)
|
| 261 |
+
buf6 = empty_strided_cuda((1, 2048, 4096), (8388608, 4096, 1), torch.bfloat16)
|
| 262 |
+
# Topologically Sorted Source Nodes: [norm_hidden_states, add, mul, norm_hidden_states_1], Original ATen: [aten.native_layer_norm, aten.add, aten.mul]
|
| 263 |
+
stream0 = get_raw_stream(0)
|
| 264 |
+
triton_red_fused_add_mul_native_layer_norm_0.run(arg0_1, arg1_1, arg2_1, buf6, 2048, 4096, stream=stream0)
|
| 265 |
+
del arg0_1
|
| 266 |
+
del arg1_1
|
| 267 |
+
del arg2_1
|
| 268 |
+
buf7 = empty_strided_cuda((1, 256, 4096), (1048576, 4096, 1), torch.bfloat16)
|
| 269 |
+
# Topologically Sorted Source Nodes: [norm_encoder_hidden_states, add_2, mul_1, norm_encoder_hidden_states_1], Original ATen: [aten.native_layer_norm, aten.add, aten.mul]
|
| 270 |
+
stream0 = get_raw_stream(0)
|
| 271 |
+
triton_red_fused_add_mul_native_layer_norm_1.run(arg3_1, arg4_1, arg5_1, buf7, 256, 4096, stream=stream0)
|
| 272 |
+
del arg3_1
|
| 273 |
+
del arg4_1
|
| 274 |
+
del arg5_1
|
| 275 |
+
return (buf6, buf7, )
|
| 276 |
+
|
| 277 |
+
runner = Runner(partitions=[])
|
| 278 |
+
call = runner.call
|
| 279 |
+
recursively_apply_fns = runner.recursively_apply_fns
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def benchmark_compiled_module(times=10, repeat=10):
|
| 283 |
+
from torch._dynamo.testing import rand_strided
|
| 284 |
+
from torch._inductor.utils import print_performance
|
| 285 |
+
arg0_1 = rand_strided((1, 2048, 4096), (8388608, 4096, 1), device='cuda:0', dtype=torch.bfloat16)
|
| 286 |
+
arg1_1 = rand_strided((1, 1, 4096), (24576, 24576, 1), device='cuda:0', dtype=torch.bfloat16)
|
| 287 |
+
arg2_1 = rand_strided((1, 1, 4096), (24576, 24576, 1), device='cuda:0', dtype=torch.bfloat16)
|
| 288 |
+
arg3_1 = rand_strided((1, 256, 4096), (1048576, 4096, 1), device='cuda:0', dtype=torch.bfloat16)
|
| 289 |
+
arg4_1 = rand_strided((1, 1, 4096), (24576, 24576, 1), device='cuda:0', dtype=torch.bfloat16)
|
| 290 |
+
arg5_1 = rand_strided((1, 1, 4096), (24576, 24576, 1), device='cuda:0', dtype=torch.bfloat16)
|
| 291 |
+
fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1])
|
| 292 |
+
return print_performance(fn, times=times, repeat=repeat)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
if __name__ == "__main__":
|
| 296 |
+
from torch._inductor.wrapper_benchmark import compiled_module_main
|
| 297 |
+
compiled_module_main('None', benchmark_compiled_module)
|
torchinductor/3i/c3imyfibcq3zwrc5gvfscvpsaxdzjijymrnt2lfahtrjmrbtlmhe.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
|
| 5 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 6 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 7 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 8 |
+
triton_helpers.set_driver_to_gpu()
|
| 9 |
+
|
| 10 |
+
@triton_heuristics.pointwise(
|
| 11 |
+
size_hints={'x': 67108864},
|
| 12 |
+
filename=__file__,
|
| 13 |
+
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=128, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
|
| 14 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_cat_mul_silu_split_view_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': '139C22A3A3C364569C9941DE9469DCB674B7A631E094782CBD415193800462F6', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 377487360}},
|
| 15 |
+
min_elem_per_thread=0
|
| 16 |
+
)
|
| 17 |
+
@triton.jit
|
| 18 |
+
def triton_poi_fused_cat_mul_silu_split_view_0(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
|
| 19 |
+
xnumel = 37748736
|
| 20 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 21 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:]
|
| 22 |
+
xmask = tl.full([XBLOCK], True, tl.int1)
|
| 23 |
+
x0 = (xindex % 16384)
|
| 24 |
+
x1 = xindex // 16384
|
| 25 |
+
x2 = xindex
|
| 26 |
+
tmp0 = x0
|
| 27 |
+
tmp1 = tl.full([1], 0, tl.int64)
|
| 28 |
+
tmp2 = tmp0 >= tmp1
|
| 29 |
+
tmp3 = tl.full([1], 4096, tl.int64)
|
| 30 |
+
tmp4 = tmp0 < tmp3
|
| 31 |
+
tmp5 = tl.load(in_ptr0 + (4096*x1 + (x0)), tmp4, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 32 |
+
tmp6 = tmp0 >= tmp3
|
| 33 |
+
tmp7 = tl.full([1], 16384, tl.int64)
|
| 34 |
+
tmp8 = tmp0 < tmp7
|
| 35 |
+
tmp9 = tl.load(in_ptr1 + (36864*x1 + ((-4096) + x0)), tmp6, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 36 |
+
tmp10 = tmp9.to(tl.float32)
|
| 37 |
+
tmp11 = tl.sigmoid(tmp10)
|
| 38 |
+
tmp12 = tmp10 * tmp11
|
| 39 |
+
tmp13 = tmp12.to(tl.float32)
|
| 40 |
+
tmp14 = tl.load(in_ptr1 + (12288 + 36864*x1 + ((-4096) + x0)), tmp6, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 41 |
+
tmp15 = tmp13 * tmp14
|
| 42 |
+
tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype)
|
| 43 |
+
tmp17 = tl.where(tmp6, tmp15, tmp16)
|
| 44 |
+
tmp18 = tl.where(tmp4, tmp5, tmp17)
|
| 45 |
+
tl.store(out_ptr0 + (x2), tmp18, None)
|
torchinductor/3i/cf1587a2fd240ce39177274973308f6fd100d746bf6716a8d96ed4fd12c89d55.best_config
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 81, "triton_cache_hash": "PPN4SVQW2UFKVPWUB7HCOIHQMJON3EA6PX7FI3IPMCGAPBBOTNMQ"}
|
torchinductor/3v/4a00da1b5d4ce251d2cb392c24118fc2e6c3818f25b8457665f0d53e12234277.best_config
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"XBLOCK": 1024, "num_warps": 4, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 43, "triton_cache_hash": "SJ2F5NEEPBSFTTPVSLW22OOIZQR5FPT5YWSURMFRPHLWAFZ5VB7A"}
|
torchinductor/3v/c3vjilvcy7sdqcaxfspffadbsry3sqm4j7qp6u3tusr6p34sbiar.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
|
| 5 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 6 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 7 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 8 |
+
triton_helpers.set_driver_to_gpu()
|
| 9 |
+
|
| 10 |
+
@triton_heuristics.pointwise(
|
| 11 |
+
size_hints={'x': 16777216},
|
| 12 |
+
filename=__file__,
|
| 13 |
+
triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=128, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
|
| 14 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__scaled_dot_product_cudnn_attention_clone_permute_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': '139C22A3A3C364569C9941DE9469DCB674B7A631E094782CBD415193800462F6', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 37748736}},
|
| 15 |
+
min_elem_per_thread=0
|
| 16 |
+
)
|
| 17 |
+
@triton.jit
|
| 18 |
+
def triton_poi_fused__scaled_dot_product_cudnn_attention_clone_permute_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
|
| 19 |
+
xnumel = 9437184
|
| 20 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 21 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:]
|
| 22 |
+
xmask = tl.full([XBLOCK], True, tl.int1)
|
| 23 |
+
x0 = (xindex % 128)
|
| 24 |
+
x1 = ((xindex // 128) % 2304)
|
| 25 |
+
x2 = xindex // 294912
|
| 26 |
+
x3 = xindex
|
| 27 |
+
tmp0 = tl.load(in_ptr0 + (x0 + 128*x2 + ks0*x1), None).to(tl.float32)
|
| 28 |
+
tl.store(out_ptr0 + (x3), tmp0, None)
|
torchinductor/4y/c4ykjyk6fv6enet6mgkj5bsan42tc6rsdfs7aaskpjgv5rzw7tbr.py
ADDED
|
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# AOT ID: ['25_inference']
|
| 2 |
+
from ctypes import c_void_p, c_long, c_int
|
| 3 |
+
import torch
|
| 4 |
+
import math
|
| 5 |
+
import random
|
| 6 |
+
import os
|
| 7 |
+
import tempfile
|
| 8 |
+
from math import inf, nan
|
| 9 |
+
from cmath import nanj
|
| 10 |
+
from torch._inductor.hooks import run_intermediate_hooks
|
| 11 |
+
from torch._inductor.utils import maybe_profile
|
| 12 |
+
from torch._inductor.codegen.memory_planning import _align as align
|
| 13 |
+
from torch import device, empty_strided
|
| 14 |
+
from torch._inductor.async_compile import AsyncCompile
|
| 15 |
+
from torch._inductor.select_algorithm import extern_kernels
|
| 16 |
+
import triton
|
| 17 |
+
import triton.language as tl
|
| 18 |
+
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
|
| 19 |
+
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
|
| 20 |
+
|
| 21 |
+
aten = torch.ops.aten
|
| 22 |
+
inductor_ops = torch.ops.inductor
|
| 23 |
+
_quantized = torch.ops._quantized
|
| 24 |
+
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
|
| 25 |
+
assert_alignment = torch._C._dynamo.guards.assert_alignment
|
| 26 |
+
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
|
| 27 |
+
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
|
| 28 |
+
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
|
| 29 |
+
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
|
| 30 |
+
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
|
| 31 |
+
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
|
| 32 |
+
alloc_from_pool = torch.ops.inductor._alloc_from_pool
|
| 33 |
+
async_compile = AsyncCompile()
|
| 34 |
+
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# kernel path: /app/tensorrt_llm/visual_gen/compiled_cache/flux2_klein_9b_NVIDIA_GeForce_RTX_4090_sm89_torch2.10.0a0_b4e4ee81d3.nv25.12_cuda13_1/torchinductor/bv/cbvqhjtyg7fvxzwtbtt4vrdkbnb6n32fnrijjpl3vv4cfqd4mznr.py
|
| 38 |
+
# Topologically Sorted Source Nodes: [split, chunk, query_1, query_2, reshape, unbind, key_1, key_2, reshape_1, unbind_1, float_1, cos, mul, neg, stack, x_rotated, float_2, sin, mul_1, add, out, float_3, cos_2, mul_2, neg_1, stack_1, x_rotated_1, float_4, sin_2, mul_3, add_1, out_1], Original ATen: [aten.split_with_sizes, aten.split, aten.view, aten._fused_rms_norm, aten.unbind, aten._to_copy, aten.unsqueeze, aten.mul, aten.neg, aten.stack, aten.add]
|
| 39 |
+
# Source node to ATen node mapping:
|
| 40 |
+
# add => add_2
|
| 41 |
+
# add_1 => add_3
|
| 42 |
+
# chunk => split
|
| 43 |
+
# cos => unsqueeze, unsqueeze_1
|
| 44 |
+
# cos_2 => unsqueeze_6, unsqueeze_7
|
| 45 |
+
# float_1 => convert_element_type_4
|
| 46 |
+
# float_2 => convert_element_type_5
|
| 47 |
+
# float_3 => convert_element_type_7
|
| 48 |
+
# float_4 => convert_element_type_8
|
| 49 |
+
# key_1 => view_1
|
| 50 |
+
# key_2 => add_1, convert_element_type_2, convert_element_type_3, mean_1, mul_2, mul_3, pow_2, rsqrt_1
|
| 51 |
+
# mul => mul_4
|
| 52 |
+
# mul_1 => mul_5
|
| 53 |
+
# mul_2 => mul_6
|
| 54 |
+
# mul_3 => mul_7
|
| 55 |
+
# neg => neg
|
| 56 |
+
# neg_1 => neg_1
|
| 57 |
+
# out => convert_element_type_6
|
| 58 |
+
# out_1 => convert_element_type_9
|
| 59 |
+
# query_1 => view
|
| 60 |
+
# query_2 => add, convert_element_type, convert_element_type_1, mean, mul, mul_1, pow_1, rsqrt
|
| 61 |
+
# reshape => view_3
|
| 62 |
+
# reshape_1 => view_5
|
| 63 |
+
# sin => unsqueeze_2, unsqueeze_3
|
| 64 |
+
# sin_2 => unsqueeze_8, unsqueeze_9
|
| 65 |
+
# split => split_with_sizes
|
| 66 |
+
# stack => cat, unsqueeze_4, unsqueeze_5
|
| 67 |
+
# stack_1 => cat_1, unsqueeze_10, unsqueeze_11
|
| 68 |
+
# unbind => unbind
|
| 69 |
+
# unbind_1 => unbind_1
|
| 70 |
+
# x_rotated => view_4
|
| 71 |
+
# x_rotated_1 => view_6
|
| 72 |
+
# Graph fragment:
|
| 73 |
+
# %arg0_1 : Tensor "bf16[1, 2304, 36864][84934656, 36864, 1]cuda:0" = PlaceHolder[target=arg0_1]
|
| 74 |
+
# %buf0 : Tensor "f32[1, 2304, 32, 1][73728, 32, 1, 73728]cuda:0" = PlaceHolder[target=buf0]
|
| 75 |
+
# %arg1_1 : Tensor "bf16[128][1]cuda:0" = PlaceHolder[target=arg1_1]
|
| 76 |
+
# %arg3_1 : Tensor "f32[2304, 128][128, 1]cuda:0" = PlaceHolder[target=arg3_1]
|
| 77 |
+
# %cat : Tensor "bf16[1, 2304, 32, 64, 2][9437184, 4096, 128, 2, 1]cuda:0" = PlaceHolder[target=cat]
|
| 78 |
+
# %arg4_1 : Tensor "f32[2304, 128][128, 1]cuda:0" = PlaceHolder[target=arg4_1]
|
| 79 |
+
# %buf1 : Tensor "f32[1, 2304, 32, 1][73728, 32, 1, 73728]cuda:0" = PlaceHolder[target=buf1]
|
| 80 |
+
# %arg2_1 : Tensor "bf16[128][1]cuda:0" = PlaceHolder[target=arg2_1]
|
| 81 |
+
# %cat_1 : Tensor "bf16[1, 2304, 32, 64, 2][9437184, 4096, 128, 2, 1]cuda:0" = PlaceHolder[target=cat_1]
|
| 82 |
+
# %split_with_sizes : [num_users=2] = call_function[target=torch.ops.aten.split_with_sizes.default](args = (%arg0_1, [12288, 24576], -1), kwargs = {})
|
| 83 |
+
# %split : [num_users=3] = call_function[target=torch.ops.aten.split.Tensor](args = (%getitem, 4096, -1), kwargs = {})
|
| 84 |
+
# %view : Tensor "bf16[1, 2304, 32, 128][84934656, 36864, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%getitem_2, [1, 2304, 32, 128]), kwargs = {})
|
| 85 |
+
# %convert_element_type : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view, torch.float32), kwargs = {})
|
| 86 |
+
# %pow_1 : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type, 2), kwargs = {})
|
| 87 |
+
# %mean : Tensor "f32[1, 2304, 32, 1][73728, 32, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_1, [3], True), kwargs = {})
|
| 88 |
+
# %add : Tensor "f32[1, 2304, 32, 1][73728, 32, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Scalar](args = (%mean, 1e-06), kwargs = {})
|
| 89 |
+
# %rsqrt : Tensor "f32[1, 2304, 32, 1][73728, 32, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add,), kwargs = {})
|
| 90 |
+
# %mul : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %rsqrt), kwargs = {})
|
| 91 |
+
# %mul_1 : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul, %arg1_1), kwargs = {})
|
| 92 |
+
# %convert_element_type_1 : Tensor "bf16[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_1, torch.bfloat16), kwargs = {})
|
| 93 |
+
# %view_3 : Tensor "bf16[1, 2304, 32, 64, 2][9437184, 4096, 128, 2, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%convert_element_type_1, [1, 2304, 32, -1, 2]), kwargs = {})
|
| 94 |
+
# %unbind : [num_users=2] = call_function[target=torch.ops.aten.unbind.int](args = (%view_3, -1), kwargs = {})
|
| 95 |
+
# %view_1 : Tensor "bf16[1, 2304, 32, 128][84934656, 36864, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%getitem_3, [1, 2304, 32, 128]), kwargs = {})
|
| 96 |
+
# %convert_element_type_2 : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_1, torch.float32), kwargs = {})
|
| 97 |
+
# %pow_2 : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_2, 2), kwargs = {})
|
| 98 |
+
# %mean_1 : Tensor "f32[1, 2304, 32, 1][73728, 32, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_2, [3], True), kwargs = {})
|
| 99 |
+
# %add_1 : Tensor "f32[1, 2304, 32, 1][73728, 32, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Scalar](args = (%mean_1, 1e-06), kwargs = {})
|
| 100 |
+
# %rsqrt_1 : Tensor "f32[1, 2304, 32, 1][73728, 32, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_1,), kwargs = {})
|
| 101 |
+
# %mul_2 : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_2, %rsqrt_1), kwargs = {})
|
| 102 |
+
# %mul_3 : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_2, %arg2_1), kwargs = {})
|
| 103 |
+
# %convert_element_type_3 : Tensor "bf16[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_3, torch.bfloat16), kwargs = {})
|
| 104 |
+
# %view_5 : Tensor "bf16[1, 2304, 32, 64, 2][9437184, 4096, 128, 2, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%convert_element_type_3, [1, 2304, 32, -1, 2]), kwargs = {})
|
| 105 |
+
# %unbind_1 : [num_users=2] = call_function[target=torch.ops.aten.unbind.int](args = (%view_5, -1), kwargs = {})
|
| 106 |
+
# %convert_element_type_4 : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_1, torch.float32), kwargs = {})
|
| 107 |
+
# %unsqueeze : Tensor "f32[1, 2304, 128][294912, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%arg3_1, 0), kwargs = {})
|
| 108 |
+
# %unsqueeze_1 : Tensor "f32[1, 2304, 1, 128][294912, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze, 2), kwargs = {})
|
| 109 |
+
# %mul_4 : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_4, %unsqueeze_1), kwargs = {})
|
| 110 |
+
# %neg : Tensor "bf16[1, 2304, 32, 64][4718592, 2048, 64, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_6,), kwargs = {})
|
| 111 |
+
# %unsqueeze_4 : Tensor "bf16[1, 2304, 32, 64, 1][4718592, 2048, 64, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%neg, 4), kwargs = {})
|
| 112 |
+
# %unsqueeze_5 : Tensor "bf16[1, 2304, 32, 64, 1][9437184, 4096, 128, 2, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%getitem_5, 4), kwargs = {})
|
| 113 |
+
# %cat : Tensor "bf16[1, 2304, 32, 64, 2][9437184, 4096, 128, 2, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%unsqueeze_4, %unsqueeze_5], -1), kwargs = {})
|
| 114 |
+
# %view_4 : Tensor "bf16[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%cat, [1, 2304, 32, 128]), kwargs = {})
|
| 115 |
+
# %convert_element_type_5 : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_4, torch.float32), kwargs = {})
|
| 116 |
+
# %unsqueeze_2 : Tensor "f32[1, 2304, 128][294912, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%arg4_1, 0), kwargs = {})
|
| 117 |
+
# %unsqueeze_3 : Tensor "f32[1, 2304, 1, 128][294912, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_2, 2), kwargs = {})
|
| 118 |
+
# %mul_5 : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_5, %unsqueeze_3), kwargs = {})
|
| 119 |
+
# %add_2 : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_4, %mul_5), kwargs = {})
|
| 120 |
+
# %convert_element_type_6 : Tensor "bf16[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_2, torch.bfloat16), kwargs = {})
|
| 121 |
+
# %convert_element_type_7 : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_3, torch.float32), kwargs = {})
|
| 122 |
+
# %unsqueeze_6 : Tensor "f32[1, 2304, 128][294912, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%arg3_1, 0), kwargs = {})
|
| 123 |
+
# %unsqueeze_7 : Tensor "f32[1, 2304, 1, 128][294912, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_6, 2), kwargs = {})
|
| 124 |
+
# %mul_6 : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_7, %unsqueeze_7), kwargs = {})
|
| 125 |
+
# %neg_1 : Tensor "bf16[1, 2304, 32, 64][4718592, 2048, 64, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_8,), kwargs = {})
|
| 126 |
+
# %unsqueeze_10 : Tensor "bf16[1, 2304, 32, 64, 1][4718592, 2048, 64, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%neg_1, 4), kwargs = {})
|
| 127 |
+
# %unsqueeze_11 : Tensor "bf16[1, 2304, 32, 64, 1][9437184, 4096, 128, 2, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%getitem_7, 4), kwargs = {})
|
| 128 |
+
# %cat_1 : Tensor "bf16[1, 2304, 32, 64, 2][9437184, 4096, 128, 2, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%unsqueeze_10, %unsqueeze_11], -1), kwargs = {})
|
| 129 |
+
# %view_6 : Tensor "bf16[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%cat_1, [1, 2304, 32, 128]), kwargs = {})
|
| 130 |
+
# %convert_element_type_8 : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_6, torch.float32), kwargs = {})
|
| 131 |
+
# %unsqueeze_8 : Tensor "f32[1, 2304, 128][294912, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%arg4_1, 0), kwargs = {})
|
| 132 |
+
# %unsqueeze_9 : Tensor "f32[1, 2304, 1, 128][294912, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_8, 2), kwargs = {})
|
| 133 |
+
# %mul_7 : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_8, %unsqueeze_9), kwargs = {})
|
| 134 |
+
# %add_3 : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_6, %mul_7), kwargs = {})
|
| 135 |
+
# %convert_element_type_9 : Tensor "bf16[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_3, torch.bfloat16), kwargs = {})
|
| 136 |
+
# return %buf1,%buf0,%cat,%convert_element_type_6,%cat_1,%convert_element_type_9
|
| 137 |
+
triton_red_fused__fused_rms_norm__to_copy_add_mul_neg_split_split_with_sizes_stack_unbind_unsqueeze_view_0 = async_compile.triton('triton_red_fused__fused_rms_norm__to_copy_add_mul_neg_split_split_with_sizes_stack_unbind_unsqueeze_view_0', '''
|
| 138 |
+
import triton
|
| 139 |
+
import triton.language as tl
|
| 140 |
+
|
| 141 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 142 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 143 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 144 |
+
triton_helpers.set_driver_to_gpu()
|
| 145 |
+
|
| 146 |
+
@triton_heuristics.reduction(
|
| 147 |
+
size_hints={'x': 131072, 'r0_': 128},
|
| 148 |
+
reduction_hint=ReductionHint.DEFAULT,
|
| 149 |
+
filename=__file__,
|
| 150 |
+
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_out_ptr1': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'in_ptr4': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=128, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
|
| 151 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__fused_rms_norm__to_copy_add_mul_neg_split_split_with_sizes_stack_unbind_unsqueeze_view_0', 'mutated_arg_names': ['in_out_ptr0', 'in_out_ptr1'], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 16, 'num_store': 2, 'num_reduction': 2, 'backend_hash': '139C22A3A3C364569C9941DE9469DCB674B7A631E094782CBD415193800462F6', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 115606016}}
|
| 152 |
+
)
|
| 153 |
+
@triton.jit
|
| 154 |
+
def triton_red_fused__fused_rms_norm__to_copy_add_mul_neg_split_split_with_sizes_stack_unbind_unsqueeze_view_0(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
|
| 155 |
+
xnumel = 73728
|
| 156 |
+
r0_numel = 128
|
| 157 |
+
rnumel = r0_numel
|
| 158 |
+
RBLOCK: tl.constexpr = R0_BLOCK
|
| 159 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 160 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
|
| 161 |
+
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
|
| 162 |
+
r0_base = tl.arange(0, R0_BLOCK)[None, :]
|
| 163 |
+
rbase = r0_base
|
| 164 |
+
x0 = (xindex % 32)
|
| 165 |
+
x1 = xindex // 32
|
| 166 |
+
_tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
|
| 167 |
+
x5 = xindex
|
| 168 |
+
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
|
| 169 |
+
for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
|
| 170 |
+
r0_index = r0_offset + r0_base
|
| 171 |
+
r0_mask = r0_index < r0_numel
|
| 172 |
+
roffset = r0_offset
|
| 173 |
+
rindex = r0_index
|
| 174 |
+
r0_2 = r0_index
|
| 175 |
+
tmp0 = tl.load(in_ptr0 + (4096 + r0_2 + 128*x0 + 36864*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 176 |
+
tmp6 = tl.load(in_ptr0 + (r0_2 + 128*x0 + 36864*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 177 |
+
tmp1 = tmp0.to(tl.float32)
|
| 178 |
+
tmp2 = tmp1 * tmp1
|
| 179 |
+
tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
|
| 180 |
+
tmp5 = _tmp4 + tmp3
|
| 181 |
+
_tmp4 = tl.where(r0_mask, tmp5, _tmp4)
|
| 182 |
+
tmp7 = tmp6.to(tl.float32)
|
| 183 |
+
tmp8 = tmp7 * tmp7
|
| 184 |
+
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
|
| 185 |
+
tmp11 = _tmp10 + tmp9
|
| 186 |
+
_tmp10 = tl.where(r0_mask, tmp11, _tmp10)
|
| 187 |
+
tmp4 = tl.sum(_tmp4, 1)[:, None]
|
| 188 |
+
tmp10 = tl.sum(_tmp10, 1)[:, None]
|
| 189 |
+
for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
|
| 190 |
+
r0_index = r0_offset + r0_base
|
| 191 |
+
r0_mask = r0_index < r0_numel
|
| 192 |
+
roffset = r0_offset
|
| 193 |
+
rindex = r0_index
|
| 194 |
+
r0_3 = (r0_index % 2)
|
| 195 |
+
r0_4 = r0_index // 2
|
| 196 |
+
r0_2 = r0_index
|
| 197 |
+
tmp50 = tl.load(in_ptr0 + (r0_2 + 128*x0 + 36864*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 198 |
+
tmp58 = tl.load(in_ptr1 + (r0_2), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 199 |
+
tmp63 = tl.load(in_ptr2 + (r0_2 + 128*x1), r0_mask, eviction_policy='evict_last', other=0.0)
|
| 200 |
+
tmp66 = tl.load(in_ptr3 + (r0_2 + 128*x1), r0_mask, eviction_policy='evict_last', other=0.0)
|
| 201 |
+
tmp96 = tl.load(in_ptr0 + (4096 + r0_2 + 128*x0 + 36864*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
|
| 202 |
+
tmp102 = tl.load(in_ptr4 + (r0_2), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 203 |
+
tmp12 = r0_3
|
| 204 |
+
tmp13 = tl.full([1, 1], 0, tl.int64)
|
| 205 |
+
tmp14 = tmp12 >= tmp13
|
| 206 |
+
tmp15 = tl.full([1, 1], 1, tl.int64)
|
| 207 |
+
tmp16 = tmp12 < tmp15
|
| 208 |
+
tmp17 = tl.load(in_ptr0 + (1 + 2*r0_4 + 128*x0 + 36864*x1), r0_mask & tmp16, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 209 |
+
tmp18 = tmp17.to(tl.float32)
|
| 210 |
+
tmp19 = 128.0
|
| 211 |
+
tmp20 = (tmp10 / tmp19)
|
| 212 |
+
tmp21 = 1e-06
|
| 213 |
+
tmp22 = tmp20 + tmp21
|
| 214 |
+
tmp23 = libdevice.rsqrt(tmp22)
|
| 215 |
+
tmp24 = tmp18 * tmp23
|
| 216 |
+
tmp25 = tl.load(in_ptr1 + (tl.broadcast_to(1 + 2*r0_4, [XBLOCK, R0_BLOCK])), r0_mask & tmp16, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 217 |
+
tmp26 = tmp25.to(tl.float32)
|
| 218 |
+
tmp27 = tmp24 * tmp26
|
| 219 |
+
tmp28 = tmp27.to(tl.float32)
|
| 220 |
+
tmp29 = -tmp28
|
| 221 |
+
tmp30 = tl.full(tmp29.shape, 0.0, tmp29.dtype)
|
| 222 |
+
tmp31 = tl.where(tmp16, tmp29, tmp30)
|
| 223 |
+
tmp32 = tmp12 >= tmp15
|
| 224 |
+
tmp33 = tl.full([1, 1], 2, tl.int64)
|
| 225 |
+
tmp34 = tmp12 < tmp33
|
| 226 |
+
tmp35 = tl.load(in_ptr0 + (2*r0_4 + 128*x0 + 36864*x1), r0_mask & tmp32, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 227 |
+
tmp36 = tmp35.to(tl.float32)
|
| 228 |
+
tmp37 = 128.0
|
| 229 |
+
tmp38 = (tmp10 / tmp37)
|
| 230 |
+
tmp39 = 1e-06
|
| 231 |
+
tmp40 = tmp38 + tmp39
|
| 232 |
+
tmp41 = libdevice.rsqrt(tmp40)
|
| 233 |
+
tmp42 = tmp36 * tmp41
|
| 234 |
+
tmp43 = tl.load(in_ptr1 + (tl.broadcast_to(2*r0_4, [XBLOCK, R0_BLOCK])), r0_mask & tmp32, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 235 |
+
tmp44 = tmp43.to(tl.float32)
|
| 236 |
+
tmp45 = tmp42 * tmp44
|
| 237 |
+
tmp46 = tmp45.to(tl.float32)
|
| 238 |
+
tmp47 = tl.full(tmp46.shape, 0.0, tmp46.dtype)
|
| 239 |
+
tmp48 = tl.where(tmp32, tmp46, tmp47)
|
| 240 |
+
tmp49 = tl.where(tmp16, tmp31, tmp48)
|
| 241 |
+
tmp51 = tmp50.to(tl.float32)
|
| 242 |
+
tmp52 = 128.0
|
| 243 |
+
tmp53 = (tmp10 / tmp52)
|
| 244 |
+
tmp54 = 1e-06
|
| 245 |
+
tmp55 = tmp53 + tmp54
|
| 246 |
+
tmp56 = libdevice.rsqrt(tmp55)
|
| 247 |
+
tmp57 = tmp51 * tmp56
|
| 248 |
+
tmp59 = tmp58.to(tl.float32)
|
| 249 |
+
tmp60 = tmp57 * tmp59
|
| 250 |
+
tmp61 = tmp60.to(tl.float32)
|
| 251 |
+
tmp62 = tmp61.to(tl.float32)
|
| 252 |
+
tmp64 = tmp62 * tmp63
|
| 253 |
+
tmp65 = tmp49.to(tl.float32)
|
| 254 |
+
tmp67 = tmp65 * tmp66
|
| 255 |
+
tmp68 = tmp64 + tmp67
|
| 256 |
+
tmp69 = tmp68.to(tl.float32)
|
| 257 |
+
tmp70 = tl.load(in_ptr0 + (4097 + 2*r0_4 + 128*x0 + 36864*x1), r0_mask & tmp16, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 258 |
+
tmp71 = tmp70.to(tl.float32)
|
| 259 |
+
tmp72 = (tmp4 / tmp19)
|
| 260 |
+
tmp73 = tmp72 + tmp21
|
| 261 |
+
tmp74 = libdevice.rsqrt(tmp73)
|
| 262 |
+
tmp75 = tmp71 * tmp74
|
| 263 |
+
tmp76 = tl.load(in_ptr4 + (tl.broadcast_to(1 + 2*r0_4, [XBLOCK, R0_BLOCK])), r0_mask & tmp16, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 264 |
+
tmp77 = tmp76.to(tl.float32)
|
| 265 |
+
tmp78 = tmp75 * tmp77
|
| 266 |
+
tmp79 = tmp78.to(tl.float32)
|
| 267 |
+
tmp80 = -tmp79
|
| 268 |
+
tmp81 = tl.full(tmp80.shape, 0.0, tmp80.dtype)
|
| 269 |
+
tmp82 = tl.where(tmp16, tmp80, tmp81)
|
| 270 |
+
tmp83 = tl.load(in_ptr0 + (4096 + 2*r0_4 + 128*x0 + 36864*x1), r0_mask & tmp32, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 271 |
+
tmp84 = tmp83.to(tl.float32)
|
| 272 |
+
tmp85 = (tmp4 / tmp37)
|
| 273 |
+
tmp86 = tmp85 + tmp39
|
| 274 |
+
tmp87 = libdevice.rsqrt(tmp86)
|
| 275 |
+
tmp88 = tmp84 * tmp87
|
| 276 |
+
tmp89 = tl.load(in_ptr4 + (tl.broadcast_to(2*r0_4, [XBLOCK, R0_BLOCK])), r0_mask & tmp32, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 277 |
+
tmp90 = tmp89.to(tl.float32)
|
| 278 |
+
tmp91 = tmp88 * tmp90
|
| 279 |
+
tmp92 = tmp91.to(tl.float32)
|
| 280 |
+
tmp93 = tl.full(tmp92.shape, 0.0, tmp92.dtype)
|
| 281 |
+
tmp94 = tl.where(tmp32, tmp92, tmp93)
|
| 282 |
+
tmp95 = tl.where(tmp16, tmp82, tmp94)
|
| 283 |
+
tmp97 = tmp96.to(tl.float32)
|
| 284 |
+
tmp98 = (tmp4 / tmp52)
|
| 285 |
+
tmp99 = tmp98 + tmp54
|
| 286 |
+
tmp100 = libdevice.rsqrt(tmp99)
|
| 287 |
+
tmp101 = tmp97 * tmp100
|
| 288 |
+
tmp103 = tmp102.to(tl.float32)
|
| 289 |
+
tmp104 = tmp101 * tmp103
|
| 290 |
+
tmp105 = tmp104.to(tl.float32)
|
| 291 |
+
tmp106 = tmp105.to(tl.float32)
|
| 292 |
+
tmp107 = tmp106 * tmp63
|
| 293 |
+
tmp108 = tmp95.to(tl.float32)
|
| 294 |
+
tmp109 = tmp108 * tmp66
|
| 295 |
+
tmp110 = tmp107 + tmp109
|
| 296 |
+
tmp111 = tmp110.to(tl.float32)
|
| 297 |
+
tl.store(in_out_ptr0 + (r0_2 + 128*x5), tmp69, r0_mask)
|
| 298 |
+
tl.store(in_out_ptr1 + (r0_2 + 128*x5), tmp111, r0_mask)
|
| 299 |
+
''', device_str='cuda')
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
async_compile.wait(globals())
|
| 303 |
+
del async_compile
|
| 304 |
+
|
| 305 |
+
class Runner:
|
| 306 |
+
def __init__(self, partitions):
|
| 307 |
+
self.partitions = partitions
|
| 308 |
+
|
| 309 |
+
def recursively_apply_fns(self, fns):
|
| 310 |
+
new_callables = []
|
| 311 |
+
for fn, c in zip(fns, self.partitions):
|
| 312 |
+
new_callables.append(fn(c))
|
| 313 |
+
self.partitions = new_callables
|
| 314 |
+
|
| 315 |
+
def call(self, args):
|
| 316 |
+
arg0_1, arg1_1, arg2_1, arg3_1, arg4_1 = args
|
| 317 |
+
args.clear()
|
| 318 |
+
assert_size_stride(arg0_1, (1, 2304, 36864), (84934656, 36864, 1))
|
| 319 |
+
assert_size_stride(arg1_1, (128, ), (1, ))
|
| 320 |
+
assert_size_stride(arg2_1, (128, ), (1, ))
|
| 321 |
+
assert_size_stride(arg3_1, (2304, 128), (128, 1))
|
| 322 |
+
assert_size_stride(arg4_1, (2304, 128), (128, 1))
|
| 323 |
+
with torch.cuda._DeviceGuard(0):
|
| 324 |
+
torch.cuda.set_device(0)
|
| 325 |
+
buf2 = empty_strided_cuda((1, 2304, 32, 64, 2), (9437184, 4096, 128, 2, 1), torch.bfloat16)
|
| 326 |
+
buf3 = reinterpret_tensor(buf2, (1, 2304, 32, 128), (9437184, 4096, 128, 1), 0); del buf2 # reuse
|
| 327 |
+
buf4 = empty_strided_cuda((1, 2304, 32, 64, 2), (9437184, 4096, 128, 2, 1), torch.bfloat16)
|
| 328 |
+
buf5 = reinterpret_tensor(buf4, (1, 2304, 32, 128), (9437184, 4096, 128, 1), 0); del buf4 # reuse
|
| 329 |
+
# Topologically Sorted Source Nodes: [split, chunk, query_1, query_2, reshape, unbind, key_1, key_2, reshape_1, unbind_1, float_1, cos, mul, neg, stack, x_rotated, float_2, sin, mul_1, add, out, float_3, cos_2, mul_2, neg_1, stack_1, x_rotated_1, float_4, sin_2, mul_3, add_1, out_1], Original ATen: [aten.split_with_sizes, aten.split, aten.view, aten._fused_rms_norm, aten.unbind, aten._to_copy, aten.unsqueeze, aten.mul, aten.neg, aten.stack, aten.add]
|
| 330 |
+
stream0 = get_raw_stream(0)
|
| 331 |
+
triton_red_fused__fused_rms_norm__to_copy_add_mul_neg_split_split_with_sizes_stack_unbind_unsqueeze_view_0.run(buf3, buf5, arg0_1, arg1_1, arg3_1, arg4_1, arg2_1, 73728, 128, stream=stream0)
|
| 332 |
+
del arg1_1
|
| 333 |
+
del arg2_1
|
| 334 |
+
del arg3_1
|
| 335 |
+
del arg4_1
|
| 336 |
+
return (buf3, buf5, reinterpret_tensor(arg0_1, (1, 2304, 32, 128), (84934656, 36864, 128, 1), 8192), reinterpret_tensor(arg0_1, (1, 2304, 24576), (84934656, 36864, 1), 12288), )
|
| 337 |
+
|
| 338 |
+
runner = Runner(partitions=[])
|
| 339 |
+
call = runner.call
|
| 340 |
+
recursively_apply_fns = runner.recursively_apply_fns
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def benchmark_compiled_module(times=10, repeat=10):
|
| 344 |
+
from torch._dynamo.testing import rand_strided
|
| 345 |
+
from torch._inductor.utils import print_performance
|
| 346 |
+
arg0_1 = rand_strided((1, 2304, 36864), (84934656, 36864, 1), device='cuda:0', dtype=torch.bfloat16)
|
| 347 |
+
arg1_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
|
| 348 |
+
arg2_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
|
| 349 |
+
arg3_1 = rand_strided((2304, 128), (128, 1), device='cuda:0', dtype=torch.float32)
|
| 350 |
+
arg4_1 = rand_strided((2304, 128), (128, 1), device='cuda:0', dtype=torch.float32)
|
| 351 |
+
fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1])
|
| 352 |
+
return print_performance(fn, times=times, repeat=repeat)
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
if __name__ == "__main__":
|
| 356 |
+
from torch._inductor.wrapper_benchmark import compiled_module_main
|
| 357 |
+
compiled_module_main('None', benchmark_compiled_module)
|
torchinductor/6k/abd9e26dfce6bf628201c09f1f90f4340fdaab3cc2dd99f7186afe82fe013d1a.best_config
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"XBLOCK": 1024, "num_warps": 4, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 35, "triton_cache_hash": "Q5QIKEPJDRH7FHZ6CDBLMD5Y4GTGU6Y7IAWNFLJIJRNGOB7RFV4Q"}
|
torchinductor/6k/c6kat5g7n3uukkfwgdxfwtmxblcmqzhbifo5g3rbaz3djxii35gi.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
|
| 5 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 6 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 7 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 8 |
+
triton_helpers.set_driver_to_gpu()
|
| 9 |
+
|
| 10 |
+
@triton_heuristics.pointwise(
|
| 11 |
+
size_hints={'x': 1048576},
|
| 12 |
+
filename=__file__,
|
| 13 |
+
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=128, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
|
| 14 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': '139C22A3A3C364569C9941DE9469DCB674B7A631E094782CBD415193800462F6', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 8396800}},
|
| 15 |
+
min_elem_per_thread=0
|
| 16 |
+
)
|
| 17 |
+
@triton.jit
|
| 18 |
+
def triton_poi_fused_add_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
|
| 19 |
+
xnumel = 1048576
|
| 20 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 21 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:]
|
| 22 |
+
xmask = tl.full([XBLOCK], True, tl.int1)
|
| 23 |
+
x2 = xindex
|
| 24 |
+
x0 = (xindex % 4096)
|
| 25 |
+
tmp0 = tl.load(in_ptr0 + (x2), None).to(tl.float32)
|
| 26 |
+
tmp1 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
|
| 27 |
+
tmp2 = tl.load(in_ptr2 + (x2), None).to(tl.float32)
|
| 28 |
+
tmp3 = tmp1 * tmp2
|
| 29 |
+
tmp4 = tmp0 + tmp3
|
| 30 |
+
tl.store(out_ptr0 + (x2), tmp4, None)
|
torchinductor/6w/4fb0f9adeff50e9452e8fd238a1808052c095c59a0b2f1d9f3f7d7106bd1ede5.best_config
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"XBLOCK": 1024, "num_warps": 4, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 68, "triton_cache_hash": "6DP457QWOYDYHZ7TARQ4OLPDLGEKSLMNVUE3G2KQSQUSVRN7FBVA"}
|
torchinductor/6w/c6w6sg4v3bcighwokzq6i43tl5xk5vz7zevuk7jhvlqyryyuhueo.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
|
| 5 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 6 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 7 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 8 |
+
triton_helpers.set_driver_to_gpu()
|
| 9 |
+
|
| 10 |
+
@triton_heuristics.pointwise(
|
| 11 |
+
size_hints={'x': 33554432},
|
| 12 |
+
filename=__file__,
|
| 13 |
+
triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=128, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
|
| 14 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_silu_split_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 2, 'num_store': 1, 'num_reduction': 0, 'backend_hash': '139C22A3A3C364569C9941DE9469DCB674B7A631E094782CBD415193800462F6', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 201326592}},
|
| 15 |
+
min_elem_per_thread=0
|
| 16 |
+
)
|
| 17 |
+
@triton.jit
|
| 18 |
+
def triton_poi_fused_mul_silu_split_0(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
|
| 19 |
+
xnumel = 25165824
|
| 20 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 21 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:]
|
| 22 |
+
xmask = tl.full([XBLOCK], True, tl.int1)
|
| 23 |
+
x0 = (xindex % 12288)
|
| 24 |
+
x1 = xindex // 12288
|
| 25 |
+
x2 = xindex
|
| 26 |
+
tmp0 = tl.load(in_ptr0 + (x0 + 24576*x1), None).to(tl.float32)
|
| 27 |
+
tmp5 = tl.load(in_ptr0 + (12288 + x0 + 24576*x1), None).to(tl.float32)
|
| 28 |
+
tmp1 = tmp0.to(tl.float32)
|
| 29 |
+
tmp2 = tl.sigmoid(tmp1)
|
| 30 |
+
tmp3 = tmp1 * tmp2
|
| 31 |
+
tmp4 = tmp3.to(tl.float32)
|
| 32 |
+
tmp6 = tmp4 * tmp5
|
| 33 |
+
tl.store(out_ptr0 + (x2), tmp6, None)
|
torchinductor/7f/be95397d0c18f43f4314e0cac66d456d9d3e2b12116963a4bf988016e97f7a5e.best_config
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"XBLOCK": 1024, "num_warps": 4, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 45, "triton_cache_hash": "EQOEBZDPMDVSX6EJFLBNKY5DUKJXFLSS4SF4QQZQUN6AV3JLHKJQ"}
|
torchinductor/7f/c7ff4ib6652ojllutm4c7mkzzpybond3pagu3glspw3sztkfe2za.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
|
| 5 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 6 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 7 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 8 |
+
triton_helpers.set_driver_to_gpu()
|
| 9 |
+
|
| 10 |
+
@triton_heuristics.pointwise(
|
| 11 |
+
size_hints={'x': 8388608},
|
| 12 |
+
filename=__file__,
|
| 13 |
+
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=128, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
|
| 14 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': '139C22A3A3C364569C9941DE9469DCB674B7A631E094782CBD415193800462F6', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 67117056}},
|
| 15 |
+
min_elem_per_thread=0
|
| 16 |
+
)
|
| 17 |
+
@triton.jit
|
| 18 |
+
def triton_poi_fused_add_mul_1(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
|
| 19 |
+
xnumel = 8388608
|
| 20 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 21 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:]
|
| 22 |
+
xmask = tl.full([XBLOCK], True, tl.int1)
|
| 23 |
+
x2 = xindex
|
| 24 |
+
x0 = (xindex % 4096)
|
| 25 |
+
tmp0 = tl.load(in_ptr0 + (x2), None).to(tl.float32)
|
| 26 |
+
tmp1 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
|
| 27 |
+
tmp2 = tl.load(in_ptr2 + (x2), None).to(tl.float32)
|
| 28 |
+
tmp3 = tmp1 * tmp2
|
| 29 |
+
tmp4 = tmp0 + tmp3
|
| 30 |
+
tl.store(out_ptr0 + (x2), tmp4, None)
|
torchinductor/a3/94dc88253134d772dc28ed260760d9a0059b054d472700be3c22dd06b228f22f.best_config
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"XBLOCK": 1, "R0_BLOCK": 2048, "num_warps": 16, "num_stages": 1, "configs_hash": "ba27f374f6982634f1ab959ad1e63f726920cfc2c7c821f8e68ec55c3d4d94fc", "found_by_coordesc": false, "time_taken_ms": 35, "triton_cache_hash": "H6VG26TW2DOV7R3PXVPFDX6HZCVIESL5ZYKZWLUWKZYONCE6NSLQ"}
|
torchinductor/a3/ca3menlfuldthgmncfpjk452xkros7idrmil6pcoeigraymcg4e6.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
|
| 5 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 6 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 7 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 8 |
+
triton_helpers.set_driver_to_gpu()
|
| 9 |
+
|
| 10 |
+
@triton_heuristics.reduction(
|
| 11 |
+
size_hints={'x': 256, 'r0_': 4096},
|
| 12 |
+
reduction_hint=ReductionHint.INNER,
|
| 13 |
+
filename=__file__,
|
| 14 |
+
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr0': '*bf16', 'out_ptr3': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=128, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
|
| 15 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_add_mul_native_layer_norm_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 6, 'num_store': 2, 'num_reduction': 2, 'backend_hash': '139C22A3A3C364569C9941DE9469DCB674B7A631E094782CBD415193800462F6', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'add_persistent_rblock': True, 'tiling_scores': {'x': 0, 'r0_': 12607488}}
|
| 16 |
+
)
|
| 17 |
+
@triton.jit
|
| 18 |
+
def triton_red_fused_add_mul_native_layer_norm_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr3, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
|
| 19 |
+
xnumel = 256
|
| 20 |
+
r0_numel = 4096
|
| 21 |
+
rnumel = r0_numel
|
| 22 |
+
RBLOCK: tl.constexpr = R0_BLOCK
|
| 23 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 24 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
|
| 25 |
+
xmask = xindex < xnumel
|
| 26 |
+
r0_base = tl.arange(0, R0_BLOCK)[None, :]
|
| 27 |
+
rbase = r0_base
|
| 28 |
+
x0 = xindex
|
| 29 |
+
tmp7_mean = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
|
| 30 |
+
tmp7_m2 = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
|
| 31 |
+
tmp7_weight = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
|
| 32 |
+
for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
|
| 33 |
+
r0_index = r0_offset + r0_base
|
| 34 |
+
r0_mask = r0_index < r0_numel
|
| 35 |
+
roffset = r0_offset
|
| 36 |
+
rindex = r0_index
|
| 37 |
+
r0_1 = r0_index
|
| 38 |
+
tmp0 = tl.load(in_ptr0 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
|
| 39 |
+
tmp1 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 40 |
+
tmp2 = tl.load(in_ptr2 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
|
| 41 |
+
tmp3 = tmp1 * tmp2
|
| 42 |
+
tmp4 = tmp0 + tmp3
|
| 43 |
+
tmp5 = tmp4.to(tl.float32)
|
| 44 |
+
tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK])
|
| 45 |
+
tmp7_mean_next, tmp7_m2_next, tmp7_weight_next = triton_helpers.welford_reduce(
|
| 46 |
+
tmp6, tmp7_mean, tmp7_m2, tmp7_weight, roffset == 0
|
| 47 |
+
)
|
| 48 |
+
tmp7_mean = tl.where(r0_mask & xmask, tmp7_mean_next, tmp7_mean)
|
| 49 |
+
tmp7_m2 = tl.where(r0_mask & xmask, tmp7_m2_next, tmp7_m2)
|
| 50 |
+
tmp7_weight = tl.where(r0_mask & xmask, tmp7_weight_next, tmp7_weight)
|
| 51 |
+
tl.store(out_ptr0 + (r0_1 + 4096*x0), tmp4, r0_mask & xmask)
|
| 52 |
+
tmp8, tmp9, tmp10 = triton_helpers.welford(tmp7_mean, tmp7_m2, tmp7_weight, 1)
|
| 53 |
+
tmp7 = tmp8[:, None]
|
| 54 |
+
tmp11 = tmp9[:, None]
|
| 55 |
+
tmp12 = tmp10[:, None]
|
| 56 |
+
for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
|
| 57 |
+
r0_index = r0_offset + r0_base
|
| 58 |
+
r0_mask = r0_index < r0_numel
|
| 59 |
+
roffset = r0_offset
|
| 60 |
+
rindex = r0_index
|
| 61 |
+
r0_1 = r0_index
|
| 62 |
+
tmp13 = tl.load(out_ptr0 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
|
| 63 |
+
tmp23 = tl.load(in_ptr3 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 64 |
+
tmp27 = tl.load(in_ptr4 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 65 |
+
tmp14 = tmp13.to(tl.float32)
|
| 66 |
+
tmp15 = tmp14 - tmp7
|
| 67 |
+
tmp16 = 4096.0
|
| 68 |
+
tmp17 = (tmp11 / tmp16)
|
| 69 |
+
tmp18 = 1e-06
|
| 70 |
+
tmp19 = tmp17 + tmp18
|
| 71 |
+
tmp20 = libdevice.rsqrt(tmp19)
|
| 72 |
+
tmp21 = tmp15 * tmp20
|
| 73 |
+
tmp22 = tmp21.to(tl.float32)
|
| 74 |
+
tmp24 = 1.0
|
| 75 |
+
tmp25 = tmp23 + tmp24
|
| 76 |
+
tmp26 = tmp22 * tmp25
|
| 77 |
+
tmp28 = tmp26 + tmp27
|
| 78 |
+
tl.store(out_ptr3 + (r0_1 + 4096*x0), tmp28, r0_mask & xmask)
|
torchinductor/aotautograd/a27rkqg32yfaub3aygtms2gl3oet2qxfcnp4zxa3zy5h6c3risxz/aw5eda3h36wpnnltujgkb4mvobznersd4fuvo2p7vy2quujasos
ADDED
|
Binary file (52.3 kB). View file
|
|
|
torchinductor/aotautograd/a3443o3ywoehrda4trn5q47mauudwcinftvd52hitdnfmakyhqc4/lw6yvpbd45y77sg6fh5v4otbinchkwbf7b56u3rh3wgq3x2wkhq
ADDED
|
Binary file (54.7 kB). View file
|
|
|
torchinductor/aotautograd/a3554ihbxq57jan4ib74iqo5mnaqevqume4yzewukzkm6ehpsilz/eubahghkef62rmchvnle5v6h3ddip4av5qqjxomdlm7ura45qve
ADDED
|
Binary file (54.9 kB). View file
|
|
|
torchinductor/aotautograd/a3hojixb5fzn7f7jfco3ddoohdsuggk4qbop3lcg7rjy3e7fkgfz/o7wvolbgborwtoofbovayor23y4ubooymfcvv6jeqm2wbx3n2cs
ADDED
|
Binary file (74.3 kB). View file
|
|
|
torchinductor/aotautograd/a54twb2qknddjxnxtmkoagy3umo5y3ptsesm2pdhy7nkefklf6wx/emxzj524wmpvifsxw4dsnnkzemqpzfgkenbo5obwmksvlhsr354
ADDED
|
Binary file (54.7 kB). View file
|
|
|
torchinductor/aotautograd/a5ksywxhfabbequvxwstheyyj5w3sinuubxcrypqjwbqsyw5la3l/ew4fxjyfoflznyuws2w2ylu4p7owpjuqoshsef75w43w2vvwejd
ADDED
|
Binary file (55.1 kB). View file
|
|
|
torchinductor/aotautograd/a7ptufzlocphh5n5o5u63gfzkf74tjb3l5is45u5hqjspv32qda6/an4kgppgf4vt5yfvvghrnmho6jc3qnj4l6c75zrsiotr5d4u5gv
ADDED
|
Binary file (55.1 kB). View file
|
|
|
torchinductor/aotautograd/aal6kceyfi7eazavxzpgcec5hzt32bkwo7p4doeyc56ubzlwuvx4/nkoni3ckgbheucucq64bmrta4lhz7x237lalaqcrejvdc3supg4
ADDED
|
Binary file (55.1 kB). View file
|
|
|
torchinductor/aotautograd/aan5kpy6i54rnpeu5vlzbx6i6blimsvhducl7futzdjr4xciy472/a35s4usnkzmh6ybhedo3b6zehfepmwdv2gxscayjeeuucr3zat7
ADDED
|
Binary file (54.6 kB). View file
|
|
|
torchinductor/aotautograd/aesonb7djseswkbtu2qzhvg6ikd5rewxnqlt6pwuytadpxxmjcod/lap2sypphhofd6d5rhojruk2vfyvw2olc7gtulmom4i5y7ix2cp
ADDED
|
Binary file (62.5 kB). View file
|
|
|
torchinductor/aotautograd/age65c4dyk2rxcqufpxd6bsafzao7tacrsvejbf3pjbsngnoashv/upzttal3jaj233iyzyps7mjpq75jt6qi6rzramvgyyewfg76h6s
ADDED
|
Binary file (83.4 kB). View file
|
|
|
torchinductor/aotautograd/ahkpwjcp2qqyj6wu2ckjqlrit2pbb3ig3ddi75hgbkgngvvipwyq/ha76p7wv3nimmrgvx6kdiqikd6adbw7nlnaiars5ey4anx46mwn
ADDED
|
Binary file (54.4 kB). View file
|
|
|
torchinductor/aotautograd/aiojzczi5txclvaydkrk5g3qlf33pdkkhxtefkhfphkpc3o6rr4p/w3n37k3qhqfhuewneurnairyblp3h7nrak6oyp2p3um7uwnfcz5
ADDED
|
Binary file (52 kB). View file
|
|
|
torchinductor/aotautograd/ajdkg3gacw25klanvqotc3mkab3mi23jtjpagxrosdmqv3d4yg7v/ejzrqbsrchqzxfppkzo4ep7edhv7lrjjbcdxkxvodbk4vvk3b62
ADDED
|
Binary file (52.7 kB). View file
|
|
|
torchinductor/aotautograd/amb262dx57ptj6gg2ch6skr372w6arsr3i7i4ed5pljhiycuxduw/fntav2w4z5lvr443jxseqalau2vuzp7x7ljd3hanoqubtutjkvp
ADDED
|
Binary file (56.9 kB). View file
|
|
|
torchinductor/aotautograd/amjjivi2p6firai3idkjgfxyy6z4prevujsjdno2uuchwvd7xqll/enc6ruqcyggs4mnt54tjdd2lvexcvipd5vhhamxwcj77g5fpyof
ADDED
|
Binary file (54.7 kB). View file
|
|
|
torchinductor/aotautograd/apfaqlwe555qd2zoz575w5mvoxoiasmcomkv76mhz5zvnm5jok66/epmli5r46rzrqf73pqrnb5tratdg3mbbwdf5vyzqr6ejyhnooye
ADDED
|
Binary file (55.1 kB). View file
|
|
|
torchinductor/aotautograd/asjbg7f735jw54kcldmvv5uost22wzpy3hkxgaihos4rllvagheu/lwqpsnp52rszp2nlwkgi33embno5st2u5bxfm4rpyoy6fql5aor
ADDED
|
Binary file (55.1 kB). View file
|
|
|
torchinductor/aotautograd/atc2ggqhejcse5aydwh2wjakijsc2dyhqjxwdqrwpra3mgjwe4st/xwy7lzraqocjillvk4s2yc2qhpkx43s2nbkxmeb2wpph3sgyc7n
ADDED
|
Binary file (56.1 kB). View file
|
|
|
torchinductor/aotautograd/atsevoi6zqdcnehuxassvjosi3j5vrk54uisibylfgspeewp6vyx/4sfzv7d6ch2yoi6nnr5ym3i6yibku3vfveyrr6sx6dqbmavxo32
ADDED
|
Binary file (55.1 kB). View file
|
|
|
torchinductor/aotautograd/ax7bbwqbruobasu7vagn2oj2owh5vgosxbjelta324rvf4tkesd4/ipnutob47ydixp2zetluyw4apg7fe5sfkkiianwaawh6yq3uang
ADDED
|
Binary file (52.8 kB). View file
|
|
|
torchinductor/aotautograd/ay26zyuzpll2prvy7zzoeydo7r47lrr6s6jcmzi2zmytjxzebmnz/nzx7lukg3r25p6sjlwtqmkf6gmgzuq7iwagwki2x4kvhw5ducr5
ADDED
|
Binary file (59.5 kB). View file
|
|
|
torchinductor/aotautograd/ay65riayezoo7bqggl72pzrzdi6lvy5mp23ajx4f453ylzpmve3s/p7clvcke3bsgsaumutstrxc7bkq4tq6yoia7nwigana3n3unini
ADDED
|
Binary file (56.5 kB). View file
|
|
|
torchinductor/aotautograd/azyih32olvhzuay5zpfypzhk2cdlosvaqxdhcnjzlwfs6k3a2ne6/5sz2kjdze7ixdny7hz24p4uma7uup7chdcpiumqznifqn4mpmqb
ADDED
|
Binary file (62.7 kB). View file
|
|
|
torchinductor/av/cavoaz6e7kbk5wq2n7vz6rxhcrwdu2trazexubdq5qwyv2ajmbkz.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
|
| 5 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 6 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 7 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 8 |
+
triton_helpers.set_driver_to_gpu()
|
| 9 |
+
|
| 10 |
+
@triton_heuristics.reduction(
|
| 11 |
+
size_hints={'x': 256, 'r0_': 4096},
|
| 12 |
+
reduction_hint=ReductionHint.INNER,
|
| 13 |
+
filename=__file__,
|
| 14 |
+
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr2': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=128, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
|
| 15 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_add_mul_native_layer_norm_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 4, 'num_store': 1, 'num_reduction': 2, 'backend_hash': '139C22A3A3C364569C9941DE9469DCB674B7A631E094782CBD415193800462F6', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'add_persistent_rblock': True, 'tiling_scores': {'x': 0, 'r0_': 6307840}}
|
| 16 |
+
)
|
| 17 |
+
@triton.jit
|
| 18 |
+
def triton_red_fused_add_mul_native_layer_norm_1(in_ptr0, in_ptr1, in_ptr2, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
|
| 19 |
+
xnumel = 256
|
| 20 |
+
r0_numel = 4096
|
| 21 |
+
rnumel = r0_numel
|
| 22 |
+
RBLOCK: tl.constexpr = R0_BLOCK
|
| 23 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 24 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
|
| 25 |
+
xmask = xindex < xnumel
|
| 26 |
+
r0_base = tl.arange(0, R0_BLOCK)[None, :]
|
| 27 |
+
rbase = r0_base
|
| 28 |
+
x0 = xindex
|
| 29 |
+
tmp3_mean = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
|
| 30 |
+
tmp3_m2 = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
|
| 31 |
+
tmp3_weight = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
|
| 32 |
+
for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
|
| 33 |
+
r0_index = r0_offset + r0_base
|
| 34 |
+
r0_mask = r0_index < r0_numel
|
| 35 |
+
roffset = r0_offset
|
| 36 |
+
rindex = r0_index
|
| 37 |
+
r0_1 = r0_index
|
| 38 |
+
tmp0 = tl.load(in_ptr0 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 39 |
+
tmp1 = tmp0.to(tl.float32)
|
| 40 |
+
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
|
| 41 |
+
tmp3_mean_next, tmp3_m2_next, tmp3_weight_next = triton_helpers.welford_reduce(
|
| 42 |
+
tmp2, tmp3_mean, tmp3_m2, tmp3_weight, roffset == 0
|
| 43 |
+
)
|
| 44 |
+
tmp3_mean = tl.where(r0_mask & xmask, tmp3_mean_next, tmp3_mean)
|
| 45 |
+
tmp3_m2 = tl.where(r0_mask & xmask, tmp3_m2_next, tmp3_m2)
|
| 46 |
+
tmp3_weight = tl.where(r0_mask & xmask, tmp3_weight_next, tmp3_weight)
|
| 47 |
+
tmp4, tmp5, tmp6 = triton_helpers.welford(tmp3_mean, tmp3_m2, tmp3_weight, 1)
|
| 48 |
+
tmp3 = tmp4[:, None]
|
| 49 |
+
tmp7 = tmp5[:, None]
|
| 50 |
+
tmp8 = tmp6[:, None]
|
| 51 |
+
for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
|
| 52 |
+
r0_index = r0_offset + r0_base
|
| 53 |
+
r0_mask = r0_index < r0_numel
|
| 54 |
+
roffset = r0_offset
|
| 55 |
+
rindex = r0_index
|
| 56 |
+
r0_1 = r0_index
|
| 57 |
+
tmp9 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 58 |
+
tmp12 = tl.load(in_ptr0 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
|
| 59 |
+
tmp23 = tl.load(in_ptr2 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 60 |
+
tmp10 = 1.0
|
| 61 |
+
tmp11 = tmp9 + tmp10
|
| 62 |
+
tmp13 = tmp12.to(tl.float32)
|
| 63 |
+
tmp14 = tmp13 - tmp3
|
| 64 |
+
tmp15 = 4096.0
|
| 65 |
+
tmp16 = (tmp7 / tmp15)
|
| 66 |
+
tmp17 = 1e-06
|
| 67 |
+
tmp18 = tmp16 + tmp17
|
| 68 |
+
tmp19 = libdevice.rsqrt(tmp18)
|
| 69 |
+
tmp20 = tmp14 * tmp19
|
| 70 |
+
tmp21 = tmp20.to(tl.float32)
|
| 71 |
+
tmp22 = tmp11 * tmp21
|
| 72 |
+
tmp24 = tmp22 + tmp23
|
| 73 |
+
tl.store(out_ptr2 + (r0_1 + 4096*x0), tmp24, r0_mask & xmask)
|
torchinductor/av/d186a24d3c8af5514b42dea48fc981efd3f5afb7bba6c30406e42c75862888b1.best_config
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"XBLOCK": 1, "R0_BLOCK": 4096, "num_warps": 16, "num_stages": 1, "configs_hash": "ba27f374f6982634f1ab959ad1e63f726920cfc2c7c821f8e68ec55c3d4d94fc", "found_by_coordesc": false, "time_taken_ms": 33, "triton_cache_hash": "CYKNGA4OMPRI7EV7H5FM47DKU7VFZ4Q5NYQGPNW6ZIVYBLBWPVMA"}
|
torchinductor/ay/cayicsdjyjxzpcmkvjbneubnqkuhs3y37qiwy5qlel3z2loa4qav.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# AOT ID: ['1_inference']
|
| 2 |
+
from ctypes import c_void_p, c_long, c_int
|
| 3 |
+
import torch
|
| 4 |
+
import math
|
| 5 |
+
import random
|
| 6 |
+
import os
|
| 7 |
+
import tempfile
|
| 8 |
+
from math import inf, nan
|
| 9 |
+
from cmath import nanj
|
| 10 |
+
from torch._inductor.hooks import run_intermediate_hooks
|
| 11 |
+
from torch._inductor.utils import maybe_profile
|
| 12 |
+
from torch._inductor.codegen.memory_planning import _align as align
|
| 13 |
+
from torch import device, empty_strided
|
| 14 |
+
from torch._inductor.async_compile import AsyncCompile
|
| 15 |
+
from torch._inductor.select_algorithm import extern_kernels
|
| 16 |
+
|
| 17 |
+
aten = torch.ops.aten
|
| 18 |
+
inductor_ops = torch.ops.inductor
|
| 19 |
+
_quantized = torch.ops._quantized
|
| 20 |
+
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
|
| 21 |
+
assert_alignment = torch._C._dynamo.guards.assert_alignment
|
| 22 |
+
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
|
| 23 |
+
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
|
| 24 |
+
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
|
| 25 |
+
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
|
| 26 |
+
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
|
| 27 |
+
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
|
| 28 |
+
alloc_from_pool = torch.ops.inductor._alloc_from_pool
|
| 29 |
+
async_compile = AsyncCompile()
|
| 30 |
+
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
async_compile.wait(globals())
|
| 34 |
+
del async_compile
|
| 35 |
+
|
| 36 |
+
class Runner:
|
| 37 |
+
def __init__(self, partitions):
|
| 38 |
+
self.partitions = partitions
|
| 39 |
+
|
| 40 |
+
def recursively_apply_fns(self, fns):
|
| 41 |
+
new_callables = []
|
| 42 |
+
for fn, c in zip(fns, self.partitions):
|
| 43 |
+
new_callables.append(fn(c))
|
| 44 |
+
self.partitions = new_callables
|
| 45 |
+
|
| 46 |
+
def call(self, args):
|
| 47 |
+
arg0_1, arg1_1 = args
|
| 48 |
+
args.clear()
|
| 49 |
+
assert_size_stride(arg0_1, (4096, 12288), (1, 4096))
|
| 50 |
+
assert_size_stride(arg1_1, (1, 1), (1, 1))
|
| 51 |
+
return (aten.view.dtype(reinterpret_tensor(arg0_1, (12288, 4096), (4096, 1), 0), torch.uint8), reinterpret_tensor(arg1_1, (1, ), (1, ), 0), )
|
| 52 |
+
|
| 53 |
+
runner = Runner(partitions=[])
|
| 54 |
+
call = runner.call
|
| 55 |
+
recursively_apply_fns = runner.recursively_apply_fns
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def benchmark_compiled_module(times=10, repeat=10):
|
| 59 |
+
from torch._dynamo.testing import rand_strided
|
| 60 |
+
from torch._inductor.utils import print_performance
|
| 61 |
+
arg0_1 = rand_strided((4096, 12288), (1, 4096), device='cuda:0', dtype=torch.float8_e4m3fn)
|
| 62 |
+
arg1_1 = rand_strided((1, 1), (1, 1), device='cuda:0', dtype=torch.float32)
|
| 63 |
+
fn = lambda: call([arg0_1, arg1_1])
|
| 64 |
+
return print_performance(fn, times=times, repeat=repeat)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
if __name__ == "__main__":
|
| 68 |
+
from torch._inductor.wrapper_benchmark import compiled_module_main
|
| 69 |
+
compiled_module_main('None', benchmark_compiled_module)
|
torchinductor/bv/7969eba2eb589b95d2894ee75ee67ba01cd2bee09cd64d315c70c0950888c19e.best_config
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"XBLOCK": 2, "R0_BLOCK": 128, "num_warps": 2, "num_stages": 1, "configs_hash": "6ffa43f2ca8cb1499f3ff3fbf8c975f2c07eef9b57fcecda113029ab12cbef66", "found_by_coordesc": false, "time_taken_ms": 307, "triton_cache_hash": "AQ3FCZKOYK5LBOX7RLBQGX5T77RKI4M7SEZTYJU34QROQSJNLP5A"}
|
torchinductor/bv/cbvqhjtyg7fvxzwtbtt4vrdkbnb6n32fnrijjpl3vv4cfqd4mznr.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
|
| 5 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 6 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 7 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 8 |
+
triton_helpers.set_driver_to_gpu()
|
| 9 |
+
|
| 10 |
+
@triton_heuristics.reduction(
|
| 11 |
+
size_hints={'x': 131072, 'r0_': 128},
|
| 12 |
+
reduction_hint=ReductionHint.DEFAULT,
|
| 13 |
+
filename=__file__,
|
| 14 |
+
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_out_ptr1': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'in_ptr4': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=128, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
|
| 15 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__fused_rms_norm__to_copy_add_mul_neg_split_split_with_sizes_stack_unbind_unsqueeze_view_0', 'mutated_arg_names': ['in_out_ptr0', 'in_out_ptr1'], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 16, 'num_store': 2, 'num_reduction': 2, 'backend_hash': '139C22A3A3C364569C9941DE9469DCB674B7A631E094782CBD415193800462F6', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 115606016}}
|
| 16 |
+
)
|
| 17 |
+
@triton.jit
|
| 18 |
+
def triton_red_fused__fused_rms_norm__to_copy_add_mul_neg_split_split_with_sizes_stack_unbind_unsqueeze_view_0(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
|
| 19 |
+
xnumel = 73728
|
| 20 |
+
r0_numel = 128
|
| 21 |
+
rnumel = r0_numel
|
| 22 |
+
RBLOCK: tl.constexpr = R0_BLOCK
|
| 23 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 24 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
|
| 25 |
+
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
|
| 26 |
+
r0_base = tl.arange(0, R0_BLOCK)[None, :]
|
| 27 |
+
rbase = r0_base
|
| 28 |
+
x0 = (xindex % 32)
|
| 29 |
+
x1 = xindex // 32
|
| 30 |
+
_tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
|
| 31 |
+
x5 = xindex
|
| 32 |
+
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
|
| 33 |
+
for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
|
| 34 |
+
r0_index = r0_offset + r0_base
|
| 35 |
+
r0_mask = r0_index < r0_numel
|
| 36 |
+
roffset = r0_offset
|
| 37 |
+
rindex = r0_index
|
| 38 |
+
r0_2 = r0_index
|
| 39 |
+
tmp0 = tl.load(in_ptr0 + (4096 + r0_2 + 128*x0 + 36864*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 40 |
+
tmp6 = tl.load(in_ptr0 + (r0_2 + 128*x0 + 36864*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 41 |
+
tmp1 = tmp0.to(tl.float32)
|
| 42 |
+
tmp2 = tmp1 * tmp1
|
| 43 |
+
tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
|
| 44 |
+
tmp5 = _tmp4 + tmp3
|
| 45 |
+
_tmp4 = tl.where(r0_mask, tmp5, _tmp4)
|
| 46 |
+
tmp7 = tmp6.to(tl.float32)
|
| 47 |
+
tmp8 = tmp7 * tmp7
|
| 48 |
+
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
|
| 49 |
+
tmp11 = _tmp10 + tmp9
|
| 50 |
+
_tmp10 = tl.where(r0_mask, tmp11, _tmp10)
|
| 51 |
+
tmp4 = tl.sum(_tmp4, 1)[:, None]
|
| 52 |
+
tmp10 = tl.sum(_tmp10, 1)[:, None]
|
| 53 |
+
for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
|
| 54 |
+
r0_index = r0_offset + r0_base
|
| 55 |
+
r0_mask = r0_index < r0_numel
|
| 56 |
+
roffset = r0_offset
|
| 57 |
+
rindex = r0_index
|
| 58 |
+
r0_3 = (r0_index % 2)
|
| 59 |
+
r0_4 = r0_index // 2
|
| 60 |
+
r0_2 = r0_index
|
| 61 |
+
tmp50 = tl.load(in_ptr0 + (r0_2 + 128*x0 + 36864*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 62 |
+
tmp58 = tl.load(in_ptr1 + (r0_2), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 63 |
+
tmp63 = tl.load(in_ptr2 + (r0_2 + 128*x1), r0_mask, eviction_policy='evict_last', other=0.0)
|
| 64 |
+
tmp66 = tl.load(in_ptr3 + (r0_2 + 128*x1), r0_mask, eviction_policy='evict_last', other=0.0)
|
| 65 |
+
tmp96 = tl.load(in_ptr0 + (4096 + r0_2 + 128*x0 + 36864*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
|
| 66 |
+
tmp102 = tl.load(in_ptr4 + (r0_2), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 67 |
+
tmp12 = r0_3
|
| 68 |
+
tmp13 = tl.full([1, 1], 0, tl.int64)
|
| 69 |
+
tmp14 = tmp12 >= tmp13
|
| 70 |
+
tmp15 = tl.full([1, 1], 1, tl.int64)
|
| 71 |
+
tmp16 = tmp12 < tmp15
|
| 72 |
+
tmp17 = tl.load(in_ptr0 + (1 + 2*r0_4 + 128*x0 + 36864*x1), r0_mask & tmp16, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 73 |
+
tmp18 = tmp17.to(tl.float32)
|
| 74 |
+
tmp19 = 128.0
|
| 75 |
+
tmp20 = (tmp10 / tmp19)
|
| 76 |
+
tmp21 = 1e-06
|
| 77 |
+
tmp22 = tmp20 + tmp21
|
| 78 |
+
tmp23 = libdevice.rsqrt(tmp22)
|
| 79 |
+
tmp24 = tmp18 * tmp23
|
| 80 |
+
tmp25 = tl.load(in_ptr1 + (tl.broadcast_to(1 + 2*r0_4, [XBLOCK, R0_BLOCK])), r0_mask & tmp16, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 81 |
+
tmp26 = tmp25.to(tl.float32)
|
| 82 |
+
tmp27 = tmp24 * tmp26
|
| 83 |
+
tmp28 = tmp27.to(tl.float32)
|
| 84 |
+
tmp29 = -tmp28
|
| 85 |
+
tmp30 = tl.full(tmp29.shape, 0.0, tmp29.dtype)
|
| 86 |
+
tmp31 = tl.where(tmp16, tmp29, tmp30)
|
| 87 |
+
tmp32 = tmp12 >= tmp15
|
| 88 |
+
tmp33 = tl.full([1, 1], 2, tl.int64)
|
| 89 |
+
tmp34 = tmp12 < tmp33
|
| 90 |
+
tmp35 = tl.load(in_ptr0 + (2*r0_4 + 128*x0 + 36864*x1), r0_mask & tmp32, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 91 |
+
tmp36 = tmp35.to(tl.float32)
|
| 92 |
+
tmp37 = 128.0
|
| 93 |
+
tmp38 = (tmp10 / tmp37)
|
| 94 |
+
tmp39 = 1e-06
|
| 95 |
+
tmp40 = tmp38 + tmp39
|
| 96 |
+
tmp41 = libdevice.rsqrt(tmp40)
|
| 97 |
+
tmp42 = tmp36 * tmp41
|
| 98 |
+
tmp43 = tl.load(in_ptr1 + (tl.broadcast_to(2*r0_4, [XBLOCK, R0_BLOCK])), r0_mask & tmp32, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 99 |
+
tmp44 = tmp43.to(tl.float32)
|
| 100 |
+
tmp45 = tmp42 * tmp44
|
| 101 |
+
tmp46 = tmp45.to(tl.float32)
|
| 102 |
+
tmp47 = tl.full(tmp46.shape, 0.0, tmp46.dtype)
|
| 103 |
+
tmp48 = tl.where(tmp32, tmp46, tmp47)
|
| 104 |
+
tmp49 = tl.where(tmp16, tmp31, tmp48)
|
| 105 |
+
tmp51 = tmp50.to(tl.float32)
|
| 106 |
+
tmp52 = 128.0
|
| 107 |
+
tmp53 = (tmp10 / tmp52)
|
| 108 |
+
tmp54 = 1e-06
|
| 109 |
+
tmp55 = tmp53 + tmp54
|
| 110 |
+
tmp56 = libdevice.rsqrt(tmp55)
|
| 111 |
+
tmp57 = tmp51 * tmp56
|
| 112 |
+
tmp59 = tmp58.to(tl.float32)
|
| 113 |
+
tmp60 = tmp57 * tmp59
|
| 114 |
+
tmp61 = tmp60.to(tl.float32)
|
| 115 |
+
tmp62 = tmp61.to(tl.float32)
|
| 116 |
+
tmp64 = tmp62 * tmp63
|
| 117 |
+
tmp65 = tmp49.to(tl.float32)
|
| 118 |
+
tmp67 = tmp65 * tmp66
|
| 119 |
+
tmp68 = tmp64 + tmp67
|
| 120 |
+
tmp69 = tmp68.to(tl.float32)
|
| 121 |
+
tmp70 = tl.load(in_ptr0 + (4097 + 2*r0_4 + 128*x0 + 36864*x1), r0_mask & tmp16, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 122 |
+
tmp71 = tmp70.to(tl.float32)
|
| 123 |
+
tmp72 = (tmp4 / tmp19)
|
| 124 |
+
tmp73 = tmp72 + tmp21
|
| 125 |
+
tmp74 = libdevice.rsqrt(tmp73)
|
| 126 |
+
tmp75 = tmp71 * tmp74
|
| 127 |
+
tmp76 = tl.load(in_ptr4 + (tl.broadcast_to(1 + 2*r0_4, [XBLOCK, R0_BLOCK])), r0_mask & tmp16, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 128 |
+
tmp77 = tmp76.to(tl.float32)
|
| 129 |
+
tmp78 = tmp75 * tmp77
|
| 130 |
+
tmp79 = tmp78.to(tl.float32)
|
| 131 |
+
tmp80 = -tmp79
|
| 132 |
+
tmp81 = tl.full(tmp80.shape, 0.0, tmp80.dtype)
|
| 133 |
+
tmp82 = tl.where(tmp16, tmp80, tmp81)
|
| 134 |
+
tmp83 = tl.load(in_ptr0 + (4096 + 2*r0_4 + 128*x0 + 36864*x1), r0_mask & tmp32, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 135 |
+
tmp84 = tmp83.to(tl.float32)
|
| 136 |
+
tmp85 = (tmp4 / tmp37)
|
| 137 |
+
tmp86 = tmp85 + tmp39
|
| 138 |
+
tmp87 = libdevice.rsqrt(tmp86)
|
| 139 |
+
tmp88 = tmp84 * tmp87
|
| 140 |
+
tmp89 = tl.load(in_ptr4 + (tl.broadcast_to(2*r0_4, [XBLOCK, R0_BLOCK])), r0_mask & tmp32, eviction_policy='evict_last', other=0.0).to(tl.float32)
|
| 141 |
+
tmp90 = tmp89.to(tl.float32)
|
| 142 |
+
tmp91 = tmp88 * tmp90
|
| 143 |
+
tmp92 = tmp91.to(tl.float32)
|
| 144 |
+
tmp93 = tl.full(tmp92.shape, 0.0, tmp92.dtype)
|
| 145 |
+
tmp94 = tl.where(tmp32, tmp92, tmp93)
|
| 146 |
+
tmp95 = tl.where(tmp16, tmp82, tmp94)
|
| 147 |
+
tmp97 = tmp96.to(tl.float32)
|
| 148 |
+
tmp98 = (tmp4 / tmp52)
|
| 149 |
+
tmp99 = tmp98 + tmp54
|
| 150 |
+
tmp100 = libdevice.rsqrt(tmp99)
|
| 151 |
+
tmp101 = tmp97 * tmp100
|
| 152 |
+
tmp103 = tmp102.to(tl.float32)
|
| 153 |
+
tmp104 = tmp101 * tmp103
|
| 154 |
+
tmp105 = tmp104.to(tl.float32)
|
| 155 |
+
tmp106 = tmp105.to(tl.float32)
|
| 156 |
+
tmp107 = tmp106 * tmp63
|
| 157 |
+
tmp108 = tmp95.to(tl.float32)
|
| 158 |
+
tmp109 = tmp108 * tmp66
|
| 159 |
+
tmp110 = tmp107 + tmp109
|
| 160 |
+
tmp111 = tmp110.to(tl.float32)
|
| 161 |
+
tl.store(in_out_ptr0 + (r0_2 + 128*x5), tmp69, r0_mask)
|
| 162 |
+
tl.store(in_out_ptr1 + (r0_2 + 128*x5), tmp111, r0_mask)
|
torchinductor/cr/ccr2gijy4jp6vvdbewmzgaogxbf5as7ytxtou4zo2yelawomrjjg.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# AOT ID: ['21_inference']
|
| 2 |
+
from ctypes import c_void_p, c_long, c_int
|
| 3 |
+
import torch
|
| 4 |
+
import math
|
| 5 |
+
import random
|
| 6 |
+
import os
|
| 7 |
+
import tempfile
|
| 8 |
+
from math import inf, nan
|
| 9 |
+
from cmath import nanj
|
| 10 |
+
from torch._inductor.hooks import run_intermediate_hooks
|
| 11 |
+
from torch._inductor.utils import maybe_profile
|
| 12 |
+
from torch._inductor.codegen.memory_planning import _align as align
|
| 13 |
+
from torch import device, empty_strided
|
| 14 |
+
from torch._inductor.async_compile import AsyncCompile
|
| 15 |
+
from torch._inductor.select_algorithm import extern_kernels
|
| 16 |
+
import triton
|
| 17 |
+
import triton.language as tl
|
| 18 |
+
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
|
| 19 |
+
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
|
| 20 |
+
|
| 21 |
+
aten = torch.ops.aten
|
| 22 |
+
inductor_ops = torch.ops.inductor
|
| 23 |
+
_quantized = torch.ops._quantized
|
| 24 |
+
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
|
| 25 |
+
assert_alignment = torch._C._dynamo.guards.assert_alignment
|
| 26 |
+
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
|
| 27 |
+
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
|
| 28 |
+
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
|
| 29 |
+
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
|
| 30 |
+
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
|
| 31 |
+
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
|
| 32 |
+
alloc_from_pool = torch.ops.inductor._alloc_from_pool
|
| 33 |
+
async_compile = AsyncCompile()
|
| 34 |
+
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# kernel path: /app/tensorrt_llm/visual_gen/compiled_cache/flux2_klein_9b_NVIDIA_GeForce_RTX_4090_sm89_torch2.10.0a0_b4e4ee81d3.nv25.12_cuda13_1/torchinductor/sy/csyae3ok2xnzuxhjkxzhdcpcz6jckcu3vv7eqb3pewhrvqmiergf.py
|
| 38 |
+
# Topologically Sorted Source Nodes: [chunk, silu, x], Original ATen: [aten.split, aten.silu, aten.mul]
|
| 39 |
+
# Source node to ATen node mapping:
|
| 40 |
+
# chunk => split
|
| 41 |
+
# silu => convert_element_type, convert_element_type_1, mul_6, sigmoid
|
| 42 |
+
# x => mul_10
|
| 43 |
+
# Graph fragment:
|
| 44 |
+
# %arg1_1 : Tensor "bf16[1, s67, 24576][24576*s67, 24576, 1]cuda:0" = PlaceHolder[target=arg1_1]
|
| 45 |
+
# %split : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%arg1_1, 12288, -1), kwargs = {})
|
| 46 |
+
# %convert_element_type : Tensor "f32[1, s67, 12288][12288*s67, 12288, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem, torch.float32), kwargs = {})
|
| 47 |
+
# %sigmoid : Tensor "f32[1, s67, 12288][12288*s67, 12288, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%convert_element_type,), kwargs = {})
|
| 48 |
+
# %mul_6 : Tensor "f32[1, s67, 12288][12288*s67, 12288, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %sigmoid), kwargs = {})
|
| 49 |
+
# %convert_element_type_1 : Tensor "bf16[1, s67, 12288][12288*s67, 12288, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_6, torch.bfloat16), kwargs = {})
|
| 50 |
+
# %mul_10 : Tensor "bf16[1, s67, 12288][12288*s67, 12288, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_1, %getitem_1), kwargs = {})
|
| 51 |
+
# return %mul_10
|
| 52 |
+
triton_poi_fused_mul_silu_split_0 = async_compile.triton('triton_poi_fused_mul_silu_split_0', '''
|
| 53 |
+
import triton
|
| 54 |
+
import triton.language as tl
|
| 55 |
+
|
| 56 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 57 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 58 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 59 |
+
triton_helpers.set_driver_to_gpu()
|
| 60 |
+
|
| 61 |
+
@triton_heuristics.pointwise(
|
| 62 |
+
size_hints={'x': 4194304},
|
| 63 |
+
filename=__file__,
|
| 64 |
+
triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=128, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
|
| 65 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_silu_split_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 2, 'num_store': 1, 'num_reduction': 0, 'backend_hash': '139C22A3A3C364569C9941DE9469DCB674B7A631E094782CBD415193800462F6', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False},
|
| 66 |
+
min_elem_per_thread=0
|
| 67 |
+
)
|
| 68 |
+
@triton.jit
|
| 69 |
+
def triton_poi_fused_mul_silu_split_0(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
|
| 70 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 71 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:]
|
| 72 |
+
xmask = tl.full([XBLOCK], True, tl.int1)
|
| 73 |
+
x0 = (xindex % 12288)
|
| 74 |
+
x1 = xindex // 12288
|
| 75 |
+
x2 = xindex
|
| 76 |
+
tmp0 = tl.load(in_ptr0 + (x0 + 24576*x1), None).to(tl.float32)
|
| 77 |
+
tmp5 = tl.load(in_ptr0 + (12288 + x0 + 24576*x1), None).to(tl.float32)
|
| 78 |
+
tmp1 = tmp0.to(tl.float32)
|
| 79 |
+
tmp2 = tl.sigmoid(tmp1)
|
| 80 |
+
tmp3 = tmp1 * tmp2
|
| 81 |
+
tmp4 = tmp3.to(tl.float32)
|
| 82 |
+
tmp6 = tmp4 * tmp5
|
| 83 |
+
tl.store(out_ptr0 + (x2), tmp6, None)
|
| 84 |
+
''', device_str='cuda')
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
async_compile.wait(globals())
|
| 88 |
+
del async_compile
|
| 89 |
+
|
| 90 |
+
class Runner:
|
| 91 |
+
def __init__(self, partitions):
|
| 92 |
+
self.partitions = partitions
|
| 93 |
+
|
| 94 |
+
def recursively_apply_fns(self, fns):
|
| 95 |
+
new_callables = []
|
| 96 |
+
for fn, c in zip(fns, self.partitions):
|
| 97 |
+
new_callables.append(fn(c))
|
| 98 |
+
self.partitions = new_callables
|
| 99 |
+
|
| 100 |
+
def call(self, args):
|
| 101 |
+
arg0_1, arg1_1 = args
|
| 102 |
+
args.clear()
|
| 103 |
+
s67 = arg0_1
|
| 104 |
+
assert_size_stride(arg1_1, (1, s67, 24576), (24576*s67, 24576, 1))
|
| 105 |
+
with torch.cuda._DeviceGuard(0):
|
| 106 |
+
torch.cuda.set_device(0)
|
| 107 |
+
buf0 = empty_strided_cuda((1, s67, 12288), (12288*s67, 12288, 1), torch.bfloat16)
|
| 108 |
+
# Topologically Sorted Source Nodes: [chunk, silu, x], Original ATen: [aten.split, aten.silu, aten.mul]
|
| 109 |
+
triton_poi_fused_mul_silu_split_0_xnumel = 12288*s67
|
| 110 |
+
stream0 = get_raw_stream(0)
|
| 111 |
+
triton_poi_fused_mul_silu_split_0.run(arg1_1, buf0, triton_poi_fused_mul_silu_split_0_xnumel, stream=stream0)
|
| 112 |
+
del arg1_1
|
| 113 |
+
return (buf0, )
|
| 114 |
+
|
| 115 |
+
runner = Runner(partitions=[])
|
| 116 |
+
call = runner.call
|
| 117 |
+
recursively_apply_fns = runner.recursively_apply_fns
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def benchmark_compiled_module(times=10, repeat=10):
|
| 121 |
+
from torch._dynamo.testing import rand_strided
|
| 122 |
+
from torch._inductor.utils import print_performance
|
| 123 |
+
arg0_1 = 256
|
| 124 |
+
arg1_1 = rand_strided((1, 256, 24576), (6291456, 24576, 1), device='cuda:0', dtype=torch.bfloat16)
|
| 125 |
+
fn = lambda: call([arg0_1, arg1_1])
|
| 126 |
+
return print_performance(fn, times=times, repeat=repeat)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
if __name__ == "__main__":
|
| 130 |
+
from torch._inductor.wrapper_benchmark import compiled_module_main
|
| 131 |
+
compiled_module_main('None', benchmark_compiled_module)
|
torchinductor/cz/bb6645c6be31f426023ec47eef09e354ad9fa8b2d59e6e45ab49b803eb34d44e.best_config
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 41, "triton_cache_hash": "SWIO2NFSYH3NKX6EWLJXN7WN2QH2K7ETN3JE2BQRCZXLIIDUWOOA"}
|
torchinductor/cz/cczg7tpituprwgqpuajzy2nylfk43mdozd5vwo77muq3kospnf7b.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
|
| 5 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 6 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 7 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 8 |
+
triton_helpers.set_driver_to_gpu()
|
| 9 |
+
|
| 10 |
+
@triton_heuristics.pointwise(
|
| 11 |
+
size_hints={'x': 8388608},
|
| 12 |
+
filename=__file__,
|
| 13 |
+
triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=128, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
|
| 14 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': '139C22A3A3C364569C9941DE9469DCB674B7A631E094782CBD415193800462F6', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 50331648}},
|
| 15 |
+
min_elem_per_thread=0
|
| 16 |
+
)
|
| 17 |
+
@triton.jit
|
| 18 |
+
def triton_poi_fused_clone_0(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
|
| 19 |
+
xnumel = 8388608
|
| 20 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 21 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:]
|
| 22 |
+
xmask = tl.full([XBLOCK], True, tl.int1)
|
| 23 |
+
x0 = xindex
|
| 24 |
+
tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
|
| 25 |
+
tl.store(out_ptr0 + (x0), tmp0, None)
|