carpedm20 committed on
Commit
29c669a
·
verified ·
1 Parent(s): 77ac065

Add Flux2 Klein compiled caches

Browse files
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full set of changes.
Files changed (50) hide show
  1. .gitattributes +14 -0
  2. meta.json +44 -0
  3. torchinductor/2h/a581feca05a976cd76073f2f954a7641097b9c5775b12cf6831b3149d528a8b4.best_config +1 -0
  4. torchinductor/2h/c2hij3hmloumxdmhuezsyhkmnqgnfa5ivre27uosymam3dr7a5xb.py +70 -0
  5. torchinductor/2o/c2oduffhka4c52657rppatcdtgtnibm42qywfo2spmul2dpsj6jj.py +297 -0
  6. torchinductor/3i/c3imyfibcq3zwrc5gvfscvpsaxdzjijymrnt2lfahtrjmrbtlmhe.py +45 -0
  7. torchinductor/3i/cf1587a2fd240ce39177274973308f6fd100d746bf6716a8d96ed4fd12c89d55.best_config +1 -0
  8. torchinductor/3v/4a00da1b5d4ce251d2cb392c24118fc2e6c3818f25b8457665f0d53e12234277.best_config +1 -0
  9. torchinductor/3v/c3vjilvcy7sdqcaxfspffadbsry3sqm4j7qp6u3tusr6p34sbiar.py +28 -0
  10. torchinductor/4y/c4ykjyk6fv6enet6mgkj5bsan42tc6rsdfs7aaskpjgv5rzw7tbr.py +357 -0
  11. torchinductor/6k/abd9e26dfce6bf628201c09f1f90f4340fdaab3cc2dd99f7186afe82fe013d1a.best_config +1 -0
  12. torchinductor/6k/c6kat5g7n3uukkfwgdxfwtmxblcmqzhbifo5g3rbaz3djxii35gi.py +30 -0
  13. torchinductor/6w/4fb0f9adeff50e9452e8fd238a1808052c095c59a0b2f1d9f3f7d7106bd1ede5.best_config +1 -0
  14. torchinductor/6w/c6w6sg4v3bcighwokzq6i43tl5xk5vz7zevuk7jhvlqyryyuhueo.py +33 -0
  15. torchinductor/7f/be95397d0c18f43f4314e0cac66d456d9d3e2b12116963a4bf988016e97f7a5e.best_config +1 -0
  16. torchinductor/7f/c7ff4ib6652ojllutm4c7mkzzpybond3pagu3glspw3sztkfe2za.py +30 -0
  17. torchinductor/a3/94dc88253134d772dc28ed260760d9a0059b054d472700be3c22dd06b228f22f.best_config +1 -0
  18. torchinductor/a3/ca3menlfuldthgmncfpjk452xkros7idrmil6pcoeigraymcg4e6.py +78 -0
  19. torchinductor/aotautograd/a27rkqg32yfaub3aygtms2gl3oet2qxfcnp4zxa3zy5h6c3risxz/aw5eda3h36wpnnltujgkb4mvobznersd4fuvo2p7vy2quujasos +0 -0
  20. torchinductor/aotautograd/a3443o3ywoehrda4trn5q47mauudwcinftvd52hitdnfmakyhqc4/lw6yvpbd45y77sg6fh5v4otbinchkwbf7b56u3rh3wgq3x2wkhq +0 -0
  21. torchinductor/aotautograd/a3554ihbxq57jan4ib74iqo5mnaqevqume4yzewukzkm6ehpsilz/eubahghkef62rmchvnle5v6h3ddip4av5qqjxomdlm7ura45qve +0 -0
  22. torchinductor/aotautograd/a3hojixb5fzn7f7jfco3ddoohdsuggk4qbop3lcg7rjy3e7fkgfz/o7wvolbgborwtoofbovayor23y4ubooymfcvv6jeqm2wbx3n2cs +0 -0
  23. torchinductor/aotautograd/a54twb2qknddjxnxtmkoagy3umo5y3ptsesm2pdhy7nkefklf6wx/emxzj524wmpvifsxw4dsnnkzemqpzfgkenbo5obwmksvlhsr354 +0 -0
  24. torchinductor/aotautograd/a5ksywxhfabbequvxwstheyyj5w3sinuubxcrypqjwbqsyw5la3l/ew4fxjyfoflznyuws2w2ylu4p7owpjuqoshsef75w43w2vvwejd +0 -0
  25. torchinductor/aotautograd/a7ptufzlocphh5n5o5u63gfzkf74tjb3l5is45u5hqjspv32qda6/an4kgppgf4vt5yfvvghrnmho6jc3qnj4l6c75zrsiotr5d4u5gv +0 -0
  26. torchinductor/aotautograd/aal6kceyfi7eazavxzpgcec5hzt32bkwo7p4doeyc56ubzlwuvx4/nkoni3ckgbheucucq64bmrta4lhz7x237lalaqcrejvdc3supg4 +0 -0
  27. torchinductor/aotautograd/aan5kpy6i54rnpeu5vlzbx6i6blimsvhducl7futzdjr4xciy472/a35s4usnkzmh6ybhedo3b6zehfepmwdv2gxscayjeeuucr3zat7 +0 -0
  28. torchinductor/aotautograd/aesonb7djseswkbtu2qzhvg6ikd5rewxnqlt6pwuytadpxxmjcod/lap2sypphhofd6d5rhojruk2vfyvw2olc7gtulmom4i5y7ix2cp +0 -0
  29. torchinductor/aotautograd/age65c4dyk2rxcqufpxd6bsafzao7tacrsvejbf3pjbsngnoashv/upzttal3jaj233iyzyps7mjpq75jt6qi6rzramvgyyewfg76h6s +0 -0
  30. torchinductor/aotautograd/ahkpwjcp2qqyj6wu2ckjqlrit2pbb3ig3ddi75hgbkgngvvipwyq/ha76p7wv3nimmrgvx6kdiqikd6adbw7nlnaiars5ey4anx46mwn +0 -0
  31. torchinductor/aotautograd/aiojzczi5txclvaydkrk5g3qlf33pdkkhxtefkhfphkpc3o6rr4p/w3n37k3qhqfhuewneurnairyblp3h7nrak6oyp2p3um7uwnfcz5 +0 -0
  32. torchinductor/aotautograd/ajdkg3gacw25klanvqotc3mkab3mi23jtjpagxrosdmqv3d4yg7v/ejzrqbsrchqzxfppkzo4ep7edhv7lrjjbcdxkxvodbk4vvk3b62 +0 -0
  33. torchinductor/aotautograd/amb262dx57ptj6gg2ch6skr372w6arsr3i7i4ed5pljhiycuxduw/fntav2w4z5lvr443jxseqalau2vuzp7x7ljd3hanoqubtutjkvp +0 -0
  34. torchinductor/aotautograd/amjjivi2p6firai3idkjgfxyy6z4prevujsjdno2uuchwvd7xqll/enc6ruqcyggs4mnt54tjdd2lvexcvipd5vhhamxwcj77g5fpyof +0 -0
  35. torchinductor/aotautograd/apfaqlwe555qd2zoz575w5mvoxoiasmcomkv76mhz5zvnm5jok66/epmli5r46rzrqf73pqrnb5tratdg3mbbwdf5vyzqr6ejyhnooye +0 -0
  36. torchinductor/aotautograd/asjbg7f735jw54kcldmvv5uost22wzpy3hkxgaihos4rllvagheu/lwqpsnp52rszp2nlwkgi33embno5st2u5bxfm4rpyoy6fql5aor +0 -0
  37. torchinductor/aotautograd/atc2ggqhejcse5aydwh2wjakijsc2dyhqjxwdqrwpra3mgjwe4st/xwy7lzraqocjillvk4s2yc2qhpkx43s2nbkxmeb2wpph3sgyc7n +0 -0
  38. torchinductor/aotautograd/atsevoi6zqdcnehuxassvjosi3j5vrk54uisibylfgspeewp6vyx/4sfzv7d6ch2yoi6nnr5ym3i6yibku3vfveyrr6sx6dqbmavxo32 +0 -0
  39. torchinductor/aotautograd/ax7bbwqbruobasu7vagn2oj2owh5vgosxbjelta324rvf4tkesd4/ipnutob47ydixp2zetluyw4apg7fe5sfkkiianwaawh6yq3uang +0 -0
  40. torchinductor/aotautograd/ay26zyuzpll2prvy7zzoeydo7r47lrr6s6jcmzi2zmytjxzebmnz/nzx7lukg3r25p6sjlwtqmkf6gmgzuq7iwagwki2x4kvhw5ducr5 +0 -0
  41. torchinductor/aotautograd/ay65riayezoo7bqggl72pzrzdi6lvy5mp23ajx4f453ylzpmve3s/p7clvcke3bsgsaumutstrxc7bkq4tq6yoia7nwigana3n3unini +0 -0
  42. torchinductor/aotautograd/azyih32olvhzuay5zpfypzhk2cdlosvaqxdhcnjzlwfs6k3a2ne6/5sz2kjdze7ixdny7hz24p4uma7uup7chdcpiumqznifqn4mpmqb +0 -0
  43. torchinductor/av/cavoaz6e7kbk5wq2n7vz6rxhcrwdu2trazexubdq5qwyv2ajmbkz.py +73 -0
  44. torchinductor/av/d186a24d3c8af5514b42dea48fc981efd3f5afb7bba6c30406e42c75862888b1.best_config +1 -0
  45. torchinductor/ay/cayicsdjyjxzpcmkvjbneubnqkuhs3y37qiwy5qlel3z2loa4qav.py +69 -0
  46. torchinductor/bv/7969eba2eb589b95d2894ee75ee67ba01cd2bee09cd64d315c70c0950888c19e.best_config +1 -0
  47. torchinductor/bv/cbvqhjtyg7fvxzwtbtt4vrdkbnb6n32fnrijjpl3vv4cfqd4mznr.py +162 -0
  48. torchinductor/cr/ccr2gijy4jp6vvdbewmzgaogxbf5as7ytxtou4zo2yelawomrjjg.py +131 -0
  49. torchinductor/cz/bb6645c6be31f426023ec47eef09e354ad9fa8b2d59e6e45ab49b803eb34d44e.best_config +1 -0
  50. torchinductor/cz/cczg7tpituprwgqpuajzy2nylfk43mdozd5vwo77muq3kospnf7b.py +25 -0
.gitattributes CHANGED
@@ -33,3 +33,17 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ torchinductor/fxgraph/22/f22et4hxfdbezzlil53lu2pcyq6hgd3sgpjtfwxqwvo3nhcfcvx3/iqgyrwyxmlbu22glkzz24rsbjiieww3sllxux2ttk4c6gtoiuba filter=lfs diff=lfs merge=lfs -text
37
+ torchinductor/fxgraph/6v/f6v2dym5xl4l4b2xlv35ic4ajld4mxcvhsvdsiwx2uug77q36cad/nhugjvtrt6cm53zizhnw673hj52m5m3kegaqvkjzkf4qh6d6rhm filter=lfs diff=lfs merge=lfs -text
38
+ torchinductor/fxgraph/ah/fahbtdmoejcqs352pnbnedqns63nbnu6hdbrwzvf6chptnsannjh/fhpbiokcxxh7ksbfgiljcvh7erywotuv4ddvlfcb4fk2ef7dd5c filter=lfs diff=lfs merge=lfs -text
39
+ torchinductor/fxgraph/bn/fbnlruhvmagcngqd5is2xjbucjaq7uf3sgsbdahfi6ovtehbhzyo/62gizmmmqz43cclymnr7ftyo5qt7ux4og6bc2xazrwstkjpsy2e filter=lfs diff=lfs merge=lfs -text
40
+ torchinductor/fxgraph/e2/fe2tjoiexjbavh5sakfaxvga43vsvwn5ev5bzhfjg76jvmtjqtbn/ejg3u4qymaxsvvl2vdequli7pwsrdjf5zdgqjkgbrxdsvgfv3h4 filter=lfs diff=lfs merge=lfs -text
41
+ torchinductor/fxgraph/fr/ffrx7clryowwzulnhruopihutvaxlycymqopsyoha6yecifyw2m2/g3wh462wylribiwz4th3gnlt5rtnrcb5bkad6w3yucxodv3q5ks filter=lfs diff=lfs merge=lfs -text
42
+ torchinductor/fxgraph/k6/fk6cfyjfeiu7xe6ebkapsnixuplqczgfc5534mitqsfkssbzjyak/4xid4w6sg2yg7xaseouf2vwhp2fyff56a2t6z6ownb3yw3g25rk filter=lfs diff=lfs merge=lfs -text
43
+ torchinductor/fxgraph/kj/fkjh2kykxecmnv6oe3zzwtjpek77nmrm35vgv2daxfgkim6xfk4u/gigbnwpmixz5epksvvrh4mtg3nxlpy724eojb3ajzvtepgvx7y4 filter=lfs diff=lfs merge=lfs -text
44
+ torchinductor/fxgraph/n6/fn6x7m44e35jdmh6iqj3eqiyrz7tbhzd3rqartt67myyrnickjmp/um3sgirsxogup4murdiaoy7dxu4ogolqsa343kh56kq24zd53fb filter=lfs diff=lfs merge=lfs -text
45
+ torchinductor/fxgraph/te/fte5y7bccssideiluepvpscj6srf7orxnfgql6to32ni27zf2uv2/vmrpdvw3meiqsf22oras6imorrybogkgp6jjr3ddtreaiuutais filter=lfs diff=lfs merge=lfs -text
46
+ torchinductor/fxgraph/tz/ftzd5ordyehsowurkwjjpkso24gayhyplcd6wz7xdv53fad276l6/36luuy7klcmb7554z63umrezysn6xbat5wxfaadxh4clxhcn2j7 filter=lfs diff=lfs merge=lfs -text
47
+ torchinductor/fxgraph/u4/fu47dchf76mmiajgnawm3xgek4ysnnmqaavupgy4cddyxygid6iq/725pjrxppjygb6kbca6zklwmn7iunv65thw23b5s4im6zi27j3i filter=lfs diff=lfs merge=lfs -text
48
+ torchinductor/fxgraph/uy/fuygegwmldon4qz3wvjs3cld4hnjz6yxh6aa2cmfsal4u3xxws43/l4w2mroymid3qdffvzt4wffavpm5it6rzi6lmbvxzzezfkbavuo filter=lfs diff=lfs merge=lfs -text
49
+ torchinductor/fxgraph/w5/fw5vzdkweh3kv3fm3mnal4wu63gxhw2anwx2pzuved4acfz4fdzm/n5dfreyro3slkufuydw2d54nm7bfiskxjjasoyg2z5yept5c3rf filter=lfs diff=lfs merge=lfs -text
meta.json ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cache_layout_version": 1,
3
+ "created_at": "2026-01-23T07:13:39Z",
4
+ "model_path": "/root/.cache/huggingface/hub/models--black-forest-labs--FLUX.2-klein-9B/snapshots/cd1bba5810fe2aba6666d9cf7352e25436426039",
5
+ "compile_command": [
6
+ "/usr/bin/python",
7
+ "/app/tensorrt_llm/visual_gen/examples/flux2_klein_9b.py",
8
+ "--model_path",
9
+ "/root/.cache/huggingface/hub/models--black-forest-labs--FLUX.2-klein-9B/snapshots/cd1bba5810fe2aba6666d9cf7352e25436426039",
10
+ "--height",
11
+ "512",
12
+ "--width",
13
+ "1024",
14
+ "--num_inference_steps",
15
+ "4",
16
+ "--num_images",
17
+ "6",
18
+ "--linear_type",
19
+ "te-fp8-per-tensor",
20
+ "--fallback_linear_type",
21
+ "default",
22
+ "--torch_compile_mode",
23
+ "default",
24
+ "--offload_text_encoder"
25
+ ],
26
+ "height": 512,
27
+ "width": 1024,
28
+ "num_inference_steps": 4,
29
+ "num_images": 6,
30
+ "linear_type": "te-fp8-per-tensor",
31
+ "fallback_linear_type": "default",
32
+ "torch_compile_mode": "default",
33
+ "offload_text_encoder": true,
34
+ "offload_vae": false,
35
+ "disable_cuda_graph": false,
36
+ "disable_teacache": false,
37
+ "torch_version": "2.10.0a0+b4e4ee81d3.nv25.12",
38
+ "cuda_version": "13.1",
39
+ "device_name": "NVIDIA GeForce RTX 4090",
40
+ "device_capability": [
41
+ 8,
42
+ 9
43
+ ]
44
+ }
torchinductor/2h/a581feca05a976cd76073f2f954a7641097b9c5775b12cf6831b3149d528a8b4.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 64, "YBLOCK": 64, "num_warps": 8, "num_stages": 1, "configs_hash": "1ce421918d79ed0f7edb09d0ee64f016daf650a007a21866fe52d592be55380c", "found_by_coordesc": false, "time_taken_ms": 143, "triton_cache_hash": "RNNMPWWZPRYLZDDP3QNL7R5SV7EYTG7WXIUJKWKAEGE4BUI424IA"}
torchinductor/2h/c2hij3hmloumxdmhuezsyhkmnqgnfa5ivre27uosymam3dr7a5xb.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'y': 131072, 'x': 128}, tile_hint=TileHint.DEFAULT,
12
+ filename=__file__,
13
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*fp32', 'in_ptr5': '*bf16', 'out_ptr0': '*bf16', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=128, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
14
+ inductor_meta={'grid_type': 'Grid2DWithYZOverflow', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__fused_rms_norm_cat_view_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 6, 'num_store': 1, 'num_reduction': 0, 'backend_hash': '139C22A3A3C364569C9941DE9469DCB674B7A631E094782CBD415193800462F6', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 589824, 'x': 75497984}},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused__fused_rms_norm_cat_view_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
19
+ ynumel = 73728
20
+ xnumel = 128
21
+ yoffset = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)) * YBLOCK
22
+ yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
23
+ ymask = yindex < ynumel
24
+ xoffset = tl.program_id(0) * XBLOCK
25
+ xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
26
+ xmask = xindex < xnumel
27
+ y1 = yindex // 32
28
+ x2 = xindex
29
+ y0 = (yindex % 32)
30
+ y3 = yindex
31
+ tmp0 = y1
32
+ tmp1 = tl.full([1, 1], 0, tl.int64)
33
+ tmp2 = tmp0 >= tmp1
34
+ tmp3 = tl.full([1, 1], 256, tl.int64)
35
+ tmp4 = tmp0 < tmp3
36
+ tmp5 = tl.load(in_ptr0 + (x2 + 128*y0 + 12288*(y1)), tmp4 & xmask & ymask, eviction_policy='evict_last', other=0.0).to(tl.float32)
37
+ tmp6 = tmp5.to(tl.float32)
38
+ tmp7 = tl.load(in_ptr1 + (tl.broadcast_to(y0 + 32*(y1), [YBLOCK, XBLOCK])), tmp4 & xmask & ymask, eviction_policy='evict_last', other=0.0)
39
+ tmp8 = 128.0
40
+ tmp9 = (tmp7 / tmp8)
41
+ tmp10 = 1e-06
42
+ tmp11 = tmp9 + tmp10
43
+ tmp12 = libdevice.rsqrt(tmp11)
44
+ tmp13 = tmp6 * tmp12
45
+ tmp14 = tl.load(in_ptr2 + (tl.broadcast_to(x2, [YBLOCK, XBLOCK])), tmp4 & xmask & ymask, eviction_policy='evict_last', other=0.0).to(tl.float32)
46
+ tmp15 = tmp14.to(tl.float32)
47
+ tmp16 = tmp13 * tmp15
48
+ tmp17 = tmp16.to(tl.float32)
49
+ tmp18 = tl.full(tmp17.shape, 0.0, tmp17.dtype)
50
+ tmp19 = tl.where(tmp4, tmp17, tmp18)
51
+ tmp20 = tmp0 >= tmp3
52
+ tmp21 = tl.full([1, 1], 2304, tl.int64)
53
+ tmp22 = tmp0 < tmp21
54
+ tmp23 = tl.load(in_ptr3 + (x2 + 128*y0 + 12288*((-256) + y1)), tmp20 & xmask & ymask, eviction_policy='evict_last', other=0.0).to(tl.float32)
55
+ tmp24 = tmp23.to(tl.float32)
56
+ tmp25 = tl.load(in_ptr4 + (tl.broadcast_to(y0 + 32*((-256) + y1), [YBLOCK, XBLOCK])), tmp20 & xmask & ymask, eviction_policy='evict_last', other=0.0)
57
+ tmp26 = 128.0
58
+ tmp27 = (tmp25 / tmp26)
59
+ tmp28 = 1e-06
60
+ tmp29 = tmp27 + tmp28
61
+ tmp30 = libdevice.rsqrt(tmp29)
62
+ tmp31 = tmp24 * tmp30
63
+ tmp32 = tl.load(in_ptr5 + (tl.broadcast_to(x2, [YBLOCK, XBLOCK])), tmp20 & xmask & ymask, eviction_policy='evict_last', other=0.0).to(tl.float32)
64
+ tmp33 = tmp32.to(tl.float32)
65
+ tmp34 = tmp31 * tmp33
66
+ tmp35 = tmp34.to(tl.float32)
67
+ tmp36 = tl.full(tmp35.shape, 0.0, tmp35.dtype)
68
+ tmp37 = tl.where(tmp20, tmp35, tmp36)
69
+ tmp38 = tl.where(tmp4, tmp19, tmp37)
70
+ tl.store(out_ptr0 + (x2 + 128*y3), tmp38, xmask & ymask)
torchinductor/2o/c2oduffhka4c52657rppatcdtgtnibm42qywfo2spmul2dpsj6jj.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['0_inference']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /app/tensorrt_llm/visual_gen/compiled_cache/flux2_klein_9b_NVIDIA_GeForce_RTX_4090_sm89_torch2.10.0a0_b4e4ee81d3.nv25.12_cuda13_1/torchinductor/ww/cwwizzjwmd4ajlubxpvxidjiy3ldv5eflwludgizcahvsp4i75s2.py
38
+ # Topologically Sorted Source Nodes: [norm_hidden_states, add, mul, norm_hidden_states_1], Original ATen: [aten.native_layer_norm, aten.add, aten.mul]
39
+ # Source node to ATen node mapping:
40
+ # add => add_1
41
+ # mul => mul_1
42
+ # norm_hidden_states => add, convert_element_type, convert_element_type_1, mul, rsqrt, sub, var_mean
43
+ # norm_hidden_states_1 => add_2
44
+ # Graph fragment:
45
+ # %arg0_1 : Tensor "bf16[1, 2048, 4096][8388608, 4096, 1]cuda:0" = PlaceHolder[target=arg0_1]
46
+ # %arg1_1 : Tensor "bf16[1, 1, 4096][24576, 24576, 1]cuda:0" = PlaceHolder[target=arg1_1]
47
+ # %getitem_1 : Tensor "f32[1, 2048, 1][2048, 1, 2048]cuda:0" = PlaceHolder[target=getitem_1]
48
+ # %buf1 : Tensor "f32[1, 2048, 1][2048, 1, 2048]cuda:0" = PlaceHolder[target=buf1]
49
+ # %arg2_1 : Tensor "bf16[1, 1, 4096][24576, 24576, 1]cuda:0" = PlaceHolder[target=arg2_1]
50
+ # %convert_element_type : Tensor "f32[1, 2048, 4096][8388608, 4096, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg0_1, torch.float32), kwargs = {})
51
+ # %var_mean : [num_users=2] = call_function[target=torch.ops.aten.var_mean.correction](args = (%convert_element_type, [2]), kwargs = {correction: 0, keepdim: True})
52
+ # %add_1 : Tensor "bf16[1, 1, 4096][4096, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg1_1, 1), kwargs = {})
53
+ # %sub : Tensor "f32[1, 2048, 4096][8388608, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type, %getitem_1), kwargs = {})
54
+ # %add : Tensor "f32[1, 2048, 1][2048, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%getitem, 1e-06), kwargs = {})
55
+ # %rsqrt : Tensor "f32[1, 2048, 1][2048, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add,), kwargs = {})
56
+ # %mul : Tensor "f32[1, 2048, 4096][8388608, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub, %rsqrt), kwargs = {})
57
+ # %convert_element_type_1 : Tensor "bf16[1, 2048, 4096][8388608, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {})
58
+ # %mul_1 : Tensor "bf16[1, 2048, 4096][8388608, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_1, %convert_element_type_1), kwargs = {})
59
+ # %add_2 : Tensor "bf16[1, 2048, 4096][8388608, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_1, %arg2_1), kwargs = {})
60
+ # return %getitem_1,%buf1,%add_2
61
+ triton_red_fused_add_mul_native_layer_norm_0 = async_compile.triton('triton_red_fused_add_mul_native_layer_norm_0', '''
62
+ import triton
63
+ import triton.language as tl
64
+
65
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
66
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
67
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
68
+ triton_helpers.set_driver_to_gpu()
69
+
70
+ @triton_heuristics.reduction(
71
+ size_hints={'x': 2048, 'r0_': 4096},
72
+ reduction_hint=ReductionHint.INNER,
73
+ filename=__file__,
74
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr2': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=128, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
75
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_add_mul_native_layer_norm_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 4, 'num_store': 1, 'num_reduction': 2, 'backend_hash': '139C22A3A3C364569C9941DE9469DCB674B7A631E094782CBD415193800462F6', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'add_persistent_rblock': True, 'tiling_scores': {'x': 0, 'r0_': 50348032}}
76
+ )
77
+ @triton.jit
78
+ def triton_red_fused_add_mul_native_layer_norm_0(in_ptr0, in_ptr1, in_ptr2, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
79
+ xnumel = 2048
80
+ r0_numel = 4096
81
+ rnumel = r0_numel
82
+ RBLOCK: tl.constexpr = R0_BLOCK
83
+ xoffset = tl.program_id(0) * XBLOCK
84
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
85
+ xmask = xindex < xnumel
86
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
87
+ rbase = r0_base
88
+ x0 = xindex
89
+ tmp3_mean = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
90
+ tmp3_m2 = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
91
+ tmp3_weight = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
92
+ for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
93
+ r0_index = r0_offset + r0_base
94
+ r0_mask = r0_index < r0_numel
95
+ roffset = r0_offset
96
+ rindex = r0_index
97
+ r0_1 = r0_index
98
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
99
+ tmp1 = tmp0.to(tl.float32)
100
+ tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
101
+ tmp3_mean_next, tmp3_m2_next, tmp3_weight_next = triton_helpers.welford_reduce(
102
+ tmp2, tmp3_mean, tmp3_m2, tmp3_weight, roffset == 0
103
+ )
104
+ tmp3_mean = tl.where(r0_mask & xmask, tmp3_mean_next, tmp3_mean)
105
+ tmp3_m2 = tl.where(r0_mask & xmask, tmp3_m2_next, tmp3_m2)
106
+ tmp3_weight = tl.where(r0_mask & xmask, tmp3_weight_next, tmp3_weight)
107
+ tmp4, tmp5, tmp6 = triton_helpers.welford(tmp3_mean, tmp3_m2, tmp3_weight, 1)
108
+ tmp3 = tmp4[:, None]
109
+ tmp7 = tmp5[:, None]
110
+ tmp8 = tmp6[:, None]
111
+ for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
112
+ r0_index = r0_offset + r0_base
113
+ r0_mask = r0_index < r0_numel
114
+ roffset = r0_offset
115
+ rindex = r0_index
116
+ r0_1 = r0_index
117
+ tmp9 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
118
+ tmp12 = tl.load(in_ptr0 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
119
+ tmp23 = tl.load(in_ptr2 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
120
+ tmp10 = 1.0
121
+ tmp11 = tmp9 + tmp10
122
+ tmp13 = tmp12.to(tl.float32)
123
+ tmp14 = tmp13 - tmp3
124
+ tmp15 = 4096.0
125
+ tmp16 = (tmp7 / tmp15)
126
+ tmp17 = 1e-06
127
+ tmp18 = tmp16 + tmp17
128
+ tmp19 = libdevice.rsqrt(tmp18)
129
+ tmp20 = tmp14 * tmp19
130
+ tmp21 = tmp20.to(tl.float32)
131
+ tmp22 = tmp11 * tmp21
132
+ tmp24 = tmp22 + tmp23
133
+ tl.store(out_ptr2 + (r0_1 + 4096*x0), tmp24, r0_mask & xmask)
134
+ ''', device_str='cuda')
135
+
136
+
137
+ # kernel path: /app/tensorrt_llm/visual_gen/compiled_cache/flux2_klein_9b_NVIDIA_GeForce_RTX_4090_sm89_torch2.10.0a0_b4e4ee81d3.nv25.12_cuda13_1/torchinductor/av/cavoaz6e7kbk5wq2n7vz6rxhcrwdu2trazexubdq5qwyv2ajmbkz.py
138
+ # Topologically Sorted Source Nodes: [norm_encoder_hidden_states, add_2, mul_1, norm_encoder_hidden_states_1], Original ATen: [aten.native_layer_norm, aten.add, aten.mul]
139
+ # Source node to ATen node mapping:
140
+ # add_2 => add_4
141
+ # mul_1 => mul_3
142
+ # norm_encoder_hidden_states => add_3, convert_element_type_2, convert_element_type_3, mul_2, rsqrt_1, sub_1, var_mean_1
143
+ # norm_encoder_hidden_states_1 => add_5
144
+ # Graph fragment:
145
+ # %arg3_1 : Tensor "bf16[1, 256, 4096][1048576, 4096, 1]cuda:0" = PlaceHolder[target=arg3_1]
146
+ # %arg4_1 : Tensor "bf16[1, 1, 4096][24576, 24576, 1]cuda:0" = PlaceHolder[target=arg4_1]
147
+ # %getitem_3 : Tensor "f32[1, 256, 1][256, 1, 256]cuda:0" = PlaceHolder[target=getitem_3]
148
+ # %buf4 : Tensor "f32[1, 256, 1][256, 1, 256]cuda:0" = PlaceHolder[target=buf4]
149
+ # %arg5_1 : Tensor "bf16[1, 1, 4096][24576, 24576, 1]cuda:0" = PlaceHolder[target=arg5_1]
150
+ # %convert_element_type_2 : Tensor "f32[1, 256, 4096][1048576, 4096, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg3_1, torch.float32), kwargs = {})
151
+ # %var_mean_1 : [num_users=2] = call_function[target=torch.ops.aten.var_mean.correction](args = (%convert_element_type_2, [2]), kwargs = {correction: 0, keepdim: True})
152
+ # %add_4 : Tensor "bf16[1, 1, 4096][4096, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg4_1, 1), kwargs = {})
153
+ # %sub_1 : Tensor "f32[1, 256, 4096][1048576, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type_2, %getitem_3), kwargs = {})
154
+ # %add_3 : Tensor "f32[1, 256, 1][256, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%getitem_2, 1e-06), kwargs = {})
155
+ # %rsqrt_1 : Tensor "f32[1, 256, 1][256, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_3,), kwargs = {})
156
+ # %mul_2 : Tensor "f32[1, 256, 4096][1048576, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_1, %rsqrt_1), kwargs = {})
157
+ # %convert_element_type_3 : Tensor "bf16[1, 256, 4096][1048576, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_2, torch.bfloat16), kwargs = {})
158
+ # %mul_3 : Tensor "bf16[1, 256, 4096][1048576, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_4, %convert_element_type_3), kwargs = {})
159
+ # %add_5 : Tensor "bf16[1, 256, 4096][1048576, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_3, %arg5_1), kwargs = {})
160
+ # return %getitem_3,%buf4,%add_5
161
+ triton_red_fused_add_mul_native_layer_norm_1 = async_compile.triton('triton_red_fused_add_mul_native_layer_norm_1', '''
162
+ import triton
163
+ import triton.language as tl
164
+
165
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
166
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
167
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
168
+ triton_helpers.set_driver_to_gpu()
169
+
170
+ @triton_heuristics.reduction(
171
+ size_hints={'x': 256, 'r0_': 4096},
172
+ reduction_hint=ReductionHint.INNER,
173
+ filename=__file__,
174
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr2': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=128, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
175
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_add_mul_native_layer_norm_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 4, 'num_store': 1, 'num_reduction': 2, 'backend_hash': '139C22A3A3C364569C9941DE9469DCB674B7A631E094782CBD415193800462F6', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'add_persistent_rblock': True, 'tiling_scores': {'x': 0, 'r0_': 6307840}}
176
+ )
177
+ @triton.jit
178
+ def triton_red_fused_add_mul_native_layer_norm_1(in_ptr0, in_ptr1, in_ptr2, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
179
+ xnumel = 256
180
+ r0_numel = 4096
181
+ rnumel = r0_numel
182
+ RBLOCK: tl.constexpr = R0_BLOCK
183
+ xoffset = tl.program_id(0) * XBLOCK
184
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
185
+ xmask = xindex < xnumel
186
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
187
+ rbase = r0_base
188
+ x0 = xindex
189
+ tmp3_mean = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
190
+ tmp3_m2 = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
191
+ tmp3_weight = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
192
+ for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
193
+ r0_index = r0_offset + r0_base
194
+ r0_mask = r0_index < r0_numel
195
+ roffset = r0_offset
196
+ rindex = r0_index
197
+ r0_1 = r0_index
198
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
199
+ tmp1 = tmp0.to(tl.float32)
200
+ tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
201
+ tmp3_mean_next, tmp3_m2_next, tmp3_weight_next = triton_helpers.welford_reduce(
202
+ tmp2, tmp3_mean, tmp3_m2, tmp3_weight, roffset == 0
203
+ )
204
+ tmp3_mean = tl.where(r0_mask & xmask, tmp3_mean_next, tmp3_mean)
205
+ tmp3_m2 = tl.where(r0_mask & xmask, tmp3_m2_next, tmp3_m2)
206
+ tmp3_weight = tl.where(r0_mask & xmask, tmp3_weight_next, tmp3_weight)
207
+ tmp4, tmp5, tmp6 = triton_helpers.welford(tmp3_mean, tmp3_m2, tmp3_weight, 1)
208
+ tmp3 = tmp4[:, None]
209
+ tmp7 = tmp5[:, None]
210
+ tmp8 = tmp6[:, None]
211
+ for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
212
+ r0_index = r0_offset + r0_base
213
+ r0_mask = r0_index < r0_numel
214
+ roffset = r0_offset
215
+ rindex = r0_index
216
+ r0_1 = r0_index
217
+ tmp9 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
218
+ tmp12 = tl.load(in_ptr0 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
219
+ tmp23 = tl.load(in_ptr2 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
220
+ tmp10 = 1.0
221
+ tmp11 = tmp9 + tmp10
222
+ tmp13 = tmp12.to(tl.float32)
223
+ tmp14 = tmp13 - tmp3
224
+ tmp15 = 4096.0
225
+ tmp16 = (tmp7 / tmp15)
226
+ tmp17 = 1e-06
227
+ tmp18 = tmp16 + tmp17
228
+ tmp19 = libdevice.rsqrt(tmp18)
229
+ tmp20 = tmp14 * tmp19
230
+ tmp21 = tmp20.to(tl.float32)
231
+ tmp22 = tmp11 * tmp21
232
+ tmp24 = tmp22 + tmp23
233
+ tl.store(out_ptr2 + (r0_1 + 4096*x0), tmp24, r0_mask & xmask)
234
+ ''', device_str='cuda')
235
+
236
+
237
# Hand this module's globals to the AsyncCompile instance before anything below runs.
# NOTE(review): presumably this blocks until the kernels registered through
# async_compile.triton(...) have finished compiling and rebinds their names in
# globals() to the compiled objects — confirm against AsyncCompile.wait.
async_compile.wait(globals())
# The compiler handle is not referenced again in this module, so drop the name.
del async_compile
239
+
240
class Runner:
    """Entry point of the compiled graph.

    ``call`` validates the six graph inputs and launches the two fused
    layer-norm/modulation kernels; ``partitions`` is an auxiliary list of
    callables that ``recursively_apply_fns`` can rewrite in place.
    """

    def __init__(self, partitions):
        # Stored as given; only recursively_apply_fns reads or replaces it.
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        """Replace every partition with ``fn(partition)``, pairing by position.

        Pairing goes through ``zip``, so the new list is as long as the
        shorter of ``fns`` and the current partitions.
        """
        self.partitions = [
            fn(partition) for fn, partition in zip(fns, self.partitions)
        ]

    def call(self, args):
        """Run both fused kernels and return their outputs.

        Args:
            args: list of six tensors ``[arg0_1, ..., arg5_1]``. The list is
                emptied by this call. Required size/stride pairs:
                arg0_1 ``(1, 2048, 4096)`` / ``(8388608, 4096, 1)``,
                arg3_1 ``(1, 256, 4096)`` / ``(1048576, 4096, 1)``,
                arg1_1, arg2_1, arg4_1, arg5_1 ``(1, 1, 4096)`` /
                ``(24576, 24576, 1)``.

        Returns:
            A 2-tuple of freshly allocated bfloat16 CUDA tensors with the
            same layouts as arg0_1 and arg3_1 respectively.
        """
        arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1 = args
        # Empty the caller's list so it no longer keeps the inputs alive; the
        # `del` statements below then drop this frame's references as well.
        args.clear()

        # The four per-channel vectors share one layout: 4096 contiguous values
        # with outer strides of 24576.
        # NOTE(review): presumably slices of a wider 24576-element buffer — confirm.
        vector_layout = ((1, 1, 4096), (24576, 24576, 1))
        assert_size_stride(arg0_1, (1, 2048, 4096), (8388608, 4096, 1))
        assert_size_stride(arg1_1, *vector_layout)
        assert_size_stride(arg2_1, *vector_layout)
        assert_size_stride(arg3_1, (1, 256, 4096), (1048576, 4096, 1))
        assert_size_stride(arg4_1, *vector_layout)
        assert_size_stride(arg5_1, *vector_layout)

        with torch.cuda._DeviceGuard(0):
            torch.cuda.set_device(0)

            hidden_out = empty_strided_cuda((1, 2048, 4096), (8388608, 4096, 1), torch.bfloat16)
            # Topologically Sorted Source Nodes: [norm_hidden_states, add, mul, norm_hidden_states_1], Original ATen: [aten.native_layer_norm, aten.add, aten.mul]
            raw_stream = get_raw_stream(0)
            triton_red_fused_add_mul_native_layer_norm_0.run(arg0_1, arg1_1, arg2_1, hidden_out, 2048, 4096, stream=raw_stream)
            # Inputs of the first kernel are not needed any more.
            del arg0_1, arg1_1, arg2_1

            encoder_out = empty_strided_cuda((1, 256, 4096), (1048576, 4096, 1), torch.bfloat16)
            # Topologically Sorted Source Nodes: [norm_encoder_hidden_states, add_2, mul_1, norm_encoder_hidden_states_1], Original ATen: [aten.native_layer_norm, aten.add, aten.mul]
            raw_stream = get_raw_stream(0)
            triton_red_fused_add_mul_native_layer_norm_1.run(arg3_1, arg4_1, arg5_1, encoder_out, 256, 4096, stream=raw_stream)
            del arg3_1, arg4_1, arg5_1
        return (hidden_out, encoder_out)
276
+
277
# Single module-level Runner, created with an empty partition list.
runner = Runner(partitions=[])
# Expose the bound methods under plain module-level names, so other code can use
# `call(args)` / `recursively_apply_fns(fns)` without touching the Runner object.
call = runner.call
recursively_apply_fns = runner.recursively_apply_fns
280
+
281
+
282
def benchmark_compiled_module(times=10, repeat=10):
    """Time ``call`` on random bfloat16 inputs placed on ``cuda:0``.

    Args:
        times: forwarded to ``print_performance`` as ``times``.
        repeat: forwarded to ``print_performance`` as ``repeat``.

    Returns:
        Whatever ``print_performance`` returns for the timed closure.
    """
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance

    vector_layout = ((1, 1, 4096), (24576, 24576, 1))
    # (size, stride) per graph input, in argument order arg0_1 .. arg5_1.
    layouts = (
        ((1, 2048, 4096), (8388608, 4096, 1)),
        vector_layout,
        vector_layout,
        ((1, 256, 4096), (1048576, 4096, 1)),
        vector_layout,
        vector_layout,
    )
    inputs = [
        rand_strided(size, stride, device='cuda:0', dtype=torch.bfloat16)
        for size, stride in layouts
    ]

    def run_once():
        # call() empties the list it is given, so pass a fresh list each time.
        return call(list(inputs))

    return print_performance(run_once, times=times, repeat=repeat)
293
+
294
+
295
# Script entry point: hand the benchmark function above to Inductor's
# compiled-module driver.
if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    # NOTE(review): the first argument is the string 'None', not the None object;
    # presumably a benchmark-name placeholder emitted by the code generator — confirm.
    compiled_module_main('None', benchmark_compiled_module)
torchinductor/3i/c3imyfibcq3zwrc5gvfscvpsaxdzjijymrnt2lfahtrjmrbtlmhe.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
# Pointwise kernel that writes a row-major output with 16384 columns per row:
#   * columns [0, 4096)      -> copied from in_ptr0 (row stride 4096)
#   * columns [4096, 16384)  -> a * sigmoid(a) * b, where a and b are read from the
#                               same in_ptr1 row (row stride 36864) at column offsets
#                               c and 12288 + c, with c = column - 4096
# All three pointers are bf16 per the signature below; arithmetic is done in float32.
@triton_heuristics.pointwise(
    size_hints={'x': 67108864},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=128, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_cat_mul_silu_split_view_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': '139C22A3A3C364569C9941DE9469DCB674B7A631E094782CBD415193800462F6', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 377487360}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_cat_mul_silu_split_view_0(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    # The element count is baked in and overrides the argument: 37748736 = 2304 * 16384.
    xnumel = 37748736
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    # All-true mask; it is not referenced again in this kernel.
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = (xindex % 16384)   # column inside the output row
    x1 = xindex // 16384    # output row
    x2 = xindex             # flat output offset
    tmp0 = x0
    tmp1 = tl.full([1], 0, tl.int64)
    tmp2 = tmp0 >= tmp1     # lower bound of the first segment; not used afterwards
    tmp3 = tl.full([1], 4096, tl.int64)
    tmp4 = tmp0 < tmp3      # True for the first 4096 columns (taken from in_ptr0)
    # Masked load: lanes outside the first segment read nothing and get 0.0.
    tmp5 = tl.load(in_ptr0 + (4096*x1 + (x0)), tmp4, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp6 = tmp0 >= tmp3     # True for columns 4096..16383 (computed from in_ptr1)
    tmp7 = tl.full([1], 16384, tl.int64)
    tmp8 = tmp0 < tmp7      # upper bound of the second segment; not used afterwards
    # First operand `a`: in_ptr1 row x1, column (x0 - 4096).
    tmp9 = tl.load(in_ptr1 + (36864*x1 + ((-4096) + x0)), tmp6, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp10 = tmp9.to(tl.float32)
    tmp11 = tl.sigmoid(tmp10)
    tmp12 = tmp10 * tmp11   # a * sigmoid(a)
    tmp13 = tmp12.to(tl.float32)
    # Second operand `b`: same row, 12288 elements further along.
    tmp14 = tl.load(in_ptr1 + (12288 + 36864*x1 + ((-4096) + x0)), tmp6, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp15 = tmp13 * tmp14
    tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype)
    tmp17 = tl.where(tmp6, tmp15, tmp16)
    # Pick the copied value for the first segment and the gated product otherwise.
    tmp18 = tl.where(tmp4, tmp5, tmp17)
    # Unmasked store: every lane writes.
    # NOTE(review): presumably the launch grid covers exactly xnumel elements — confirm.
    tl.store(out_ptr0 + (x2), tmp18, None)
torchinductor/3i/cf1587a2fd240ce39177274973308f6fd100d746bf6716a8d96ed4fd12c89d55.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 81, "triton_cache_hash": "PPN4SVQW2UFKVPWUB7HCOIHQMJON3EA6PX7FI3IPMCGAPBBOTNMQ"}
torchinductor/3v/4a00da1b5d4ce251d2cb392c24118fc2e6c3818f25b8457665f0d53e12234277.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 1024, "num_warps": 4, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 43, "triton_cache_hash": "SJ2F5NEEPBSFTTPVSLW22OOIZQR5FPT5YWSURMFRPHLWAFZ5VB7A"}
torchinductor/3v/c3vjilvcy7sdqcaxfspffadbsry3sqm4j7qp6u3tusr6p34sbiar.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
# Pointwise copy that materialises a strided view as a dense buffer. The output is
# indexed as [x2][x1][x0] with extents 32 x 2304 x 128 (9437184 elements in total);
# the source element for each position is in_ptr0[x0 + 128*x2 + ks0*x1], i.e. the
# x1 axis uses the runtime stride ks0 and the x2 axis a stride of 128.
# NOTE(review): judging by the kernel name this prepares an attention operand
# (x2 presumably heads, x1 sequence positions, x0 head dim) — confirm against the caller.
@triton_heuristics.pointwise(
    size_hints={'x': 16777216},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=128, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__scaled_dot_product_cudnn_attention_clone_permute_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': '139C22A3A3C364569C9941DE9469DCB674B7A631E094782CBD415193800462F6', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 37748736}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__scaled_dot_product_cudnn_attention_clone_permute_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    # The element count is baked in and overrides the argument: 9437184 = 32 * 2304 * 128.
    xnumel = 9437184
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    # All-true mask; it is not referenced again in this kernel.
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = (xindex % 128)             # innermost output index
    x1 = ((xindex // 128) % 2304)   # middle output index
    x2 = xindex // 294912           # outermost output index (294912 = 2304 * 128)
    x3 = xindex                     # flat output offset
    # Gather through the source strides; the load and the store are both unmasked.
    # NOTE(review): presumably the launch grid covers exactly xnumel elements — confirm.
    tmp0 = tl.load(in_ptr0 + (x0 + 128*x2 + ks0*x1), None).to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp0, None)
torchinductor/4y/c4ykjyk6fv6enet6mgkj5bsan42tc6rsdfs7aaskpjgv5rzw7tbr.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['25_inference']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
# Short module-level aliases for operator namespaces and private torch helpers.
# NOTE(review): presumably referenced by the generated wrapper code further down
# the file; not every alias is necessarily used — confirm before pruning.

# Operator namespaces.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
# Input-checking helpers from the dynamo guards extension module.
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
# Strided allocators, one alias per device type.
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
# Compiler handle that the async_compile.triton(...) registrations below go through.
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /app/tensorrt_llm/visual_gen/compiled_cache/flux2_klein_9b_NVIDIA_GeForce_RTX_4090_sm89_torch2.10.0a0_b4e4ee81d3.nv25.12_cuda13_1/torchinductor/bv/cbvqhjtyg7fvxzwtbtt4vrdkbnb6n32fnrijjpl3vv4cfqd4mznr.py
38
+ # Topologically Sorted Source Nodes: [split, chunk, query_1, query_2, reshape, unbind, key_1, key_2, reshape_1, unbind_1, float_1, cos, mul, neg, stack, x_rotated, float_2, sin, mul_1, add, out, float_3, cos_2, mul_2, neg_1, stack_1, x_rotated_1, float_4, sin_2, mul_3, add_1, out_1], Original ATen: [aten.split_with_sizes, aten.split, aten.view, aten._fused_rms_norm, aten.unbind, aten._to_copy, aten.unsqueeze, aten.mul, aten.neg, aten.stack, aten.add]
39
+ # Source node to ATen node mapping:
40
+ # add => add_2
41
+ # add_1 => add_3
42
+ # chunk => split
43
+ # cos => unsqueeze, unsqueeze_1
44
+ # cos_2 => unsqueeze_6, unsqueeze_7
45
+ # float_1 => convert_element_type_4
46
+ # float_2 => convert_element_type_5
47
+ # float_3 => convert_element_type_7
48
+ # float_4 => convert_element_type_8
49
+ # key_1 => view_1
50
+ # key_2 => add_1, convert_element_type_2, convert_element_type_3, mean_1, mul_2, mul_3, pow_2, rsqrt_1
51
+ # mul => mul_4
52
+ # mul_1 => mul_5
53
+ # mul_2 => mul_6
54
+ # mul_3 => mul_7
55
+ # neg => neg
56
+ # neg_1 => neg_1
57
+ # out => convert_element_type_6
58
+ # out_1 => convert_element_type_9
59
+ # query_1 => view
60
+ # query_2 => add, convert_element_type, convert_element_type_1, mean, mul, mul_1, pow_1, rsqrt
61
+ # reshape => view_3
62
+ # reshape_1 => view_5
63
+ # sin => unsqueeze_2, unsqueeze_3
64
+ # sin_2 => unsqueeze_8, unsqueeze_9
65
+ # split => split_with_sizes
66
+ # stack => cat, unsqueeze_4, unsqueeze_5
67
+ # stack_1 => cat_1, unsqueeze_10, unsqueeze_11
68
+ # unbind => unbind
69
+ # unbind_1 => unbind_1
70
+ # x_rotated => view_4
71
+ # x_rotated_1 => view_6
72
+ # Graph fragment:
73
+ # %arg0_1 : Tensor "bf16[1, 2304, 36864][84934656, 36864, 1]cuda:0" = PlaceHolder[target=arg0_1]
74
+ # %buf0 : Tensor "f32[1, 2304, 32, 1][73728, 32, 1, 73728]cuda:0" = PlaceHolder[target=buf0]
75
+ # %arg1_1 : Tensor "bf16[128][1]cuda:0" = PlaceHolder[target=arg1_1]
76
+ # %arg3_1 : Tensor "f32[2304, 128][128, 1]cuda:0" = PlaceHolder[target=arg3_1]
77
+ # %cat : Tensor "bf16[1, 2304, 32, 64, 2][9437184, 4096, 128, 2, 1]cuda:0" = PlaceHolder[target=cat]
78
+ # %arg4_1 : Tensor "f32[2304, 128][128, 1]cuda:0" = PlaceHolder[target=arg4_1]
79
+ # %buf1 : Tensor "f32[1, 2304, 32, 1][73728, 32, 1, 73728]cuda:0" = PlaceHolder[target=buf1]
80
+ # %arg2_1 : Tensor "bf16[128][1]cuda:0" = PlaceHolder[target=arg2_1]
81
+ # %cat_1 : Tensor "bf16[1, 2304, 32, 64, 2][9437184, 4096, 128, 2, 1]cuda:0" = PlaceHolder[target=cat_1]
82
+ # %split_with_sizes : [num_users=2] = call_function[target=torch.ops.aten.split_with_sizes.default](args = (%arg0_1, [12288, 24576], -1), kwargs = {})
83
+ # %split : [num_users=3] = call_function[target=torch.ops.aten.split.Tensor](args = (%getitem, 4096, -1), kwargs = {})
84
+ # %view : Tensor "bf16[1, 2304, 32, 128][84934656, 36864, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%getitem_2, [1, 2304, 32, 128]), kwargs = {})
85
+ # %convert_element_type : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view, torch.float32), kwargs = {})
86
+ # %pow_1 : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type, 2), kwargs = {})
87
+ # %mean : Tensor "f32[1, 2304, 32, 1][73728, 32, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_1, [3], True), kwargs = {})
88
+ # %add : Tensor "f32[1, 2304, 32, 1][73728, 32, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Scalar](args = (%mean, 1e-06), kwargs = {})
89
+ # %rsqrt : Tensor "f32[1, 2304, 32, 1][73728, 32, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add,), kwargs = {})
90
+ # %mul : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %rsqrt), kwargs = {})
91
+ # %mul_1 : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul, %arg1_1), kwargs = {})
92
+ # %convert_element_type_1 : Tensor "bf16[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_1, torch.bfloat16), kwargs = {})
93
+ # %view_3 : Tensor "bf16[1, 2304, 32, 64, 2][9437184, 4096, 128, 2, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%convert_element_type_1, [1, 2304, 32, -1, 2]), kwargs = {})
94
+ # %unbind : [num_users=2] = call_function[target=torch.ops.aten.unbind.int](args = (%view_3, -1), kwargs = {})
95
+ # %view_1 : Tensor "bf16[1, 2304, 32, 128][84934656, 36864, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%getitem_3, [1, 2304, 32, 128]), kwargs = {})
96
+ # %convert_element_type_2 : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_1, torch.float32), kwargs = {})
97
+ # %pow_2 : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_2, 2), kwargs = {})
98
+ # %mean_1 : Tensor "f32[1, 2304, 32, 1][73728, 32, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_2, [3], True), kwargs = {})
99
+ # %add_1 : Tensor "f32[1, 2304, 32, 1][73728, 32, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Scalar](args = (%mean_1, 1e-06), kwargs = {})
100
+ # %rsqrt_1 : Tensor "f32[1, 2304, 32, 1][73728, 32, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_1,), kwargs = {})
101
+ # %mul_2 : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_2, %rsqrt_1), kwargs = {})
102
+ # %mul_3 : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_2, %arg2_1), kwargs = {})
103
+ # %convert_element_type_3 : Tensor "bf16[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_3, torch.bfloat16), kwargs = {})
104
+ # %view_5 : Tensor "bf16[1, 2304, 32, 64, 2][9437184, 4096, 128, 2, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%convert_element_type_3, [1, 2304, 32, -1, 2]), kwargs = {})
105
+ # %unbind_1 : [num_users=2] = call_function[target=torch.ops.aten.unbind.int](args = (%view_5, -1), kwargs = {})
106
+ # %convert_element_type_4 : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_1, torch.float32), kwargs = {})
107
+ # %unsqueeze : Tensor "f32[1, 2304, 128][294912, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%arg3_1, 0), kwargs = {})
108
+ # %unsqueeze_1 : Tensor "f32[1, 2304, 1, 128][294912, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze, 2), kwargs = {})
109
+ # %mul_4 : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_4, %unsqueeze_1), kwargs = {})
110
+ # %neg : Tensor "bf16[1, 2304, 32, 64][4718592, 2048, 64, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_6,), kwargs = {})
111
+ # %unsqueeze_4 : Tensor "bf16[1, 2304, 32, 64, 1][4718592, 2048, 64, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%neg, 4), kwargs = {})
112
+ # %unsqueeze_5 : Tensor "bf16[1, 2304, 32, 64, 1][9437184, 4096, 128, 2, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%getitem_5, 4), kwargs = {})
113
+ # %cat : Tensor "bf16[1, 2304, 32, 64, 2][9437184, 4096, 128, 2, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%unsqueeze_4, %unsqueeze_5], -1), kwargs = {})
114
+ # %view_4 : Tensor "bf16[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%cat, [1, 2304, 32, 128]), kwargs = {})
115
+ # %convert_element_type_5 : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_4, torch.float32), kwargs = {})
116
+ # %unsqueeze_2 : Tensor "f32[1, 2304, 128][294912, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%arg4_1, 0), kwargs = {})
117
+ # %unsqueeze_3 : Tensor "f32[1, 2304, 1, 128][294912, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_2, 2), kwargs = {})
118
+ # %mul_5 : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_5, %unsqueeze_3), kwargs = {})
119
+ # %add_2 : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_4, %mul_5), kwargs = {})
120
+ # %convert_element_type_6 : Tensor "bf16[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_2, torch.bfloat16), kwargs = {})
121
+ # %convert_element_type_7 : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_3, torch.float32), kwargs = {})
122
+ # %unsqueeze_6 : Tensor "f32[1, 2304, 128][294912, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%arg3_1, 0), kwargs = {})
123
+ # %unsqueeze_7 : Tensor "f32[1, 2304, 1, 128][294912, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_6, 2), kwargs = {})
124
+ # %mul_6 : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_7, %unsqueeze_7), kwargs = {})
125
+ # %neg_1 : Tensor "bf16[1, 2304, 32, 64][4718592, 2048, 64, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_8,), kwargs = {})
126
+ # %unsqueeze_10 : Tensor "bf16[1, 2304, 32, 64, 1][4718592, 2048, 64, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%neg_1, 4), kwargs = {})
127
+ # %unsqueeze_11 : Tensor "bf16[1, 2304, 32, 64, 1][9437184, 4096, 128, 2, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%getitem_7, 4), kwargs = {})
128
+ # %cat_1 : Tensor "bf16[1, 2304, 32, 64, 2][9437184, 4096, 128, 2, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%unsqueeze_10, %unsqueeze_11], -1), kwargs = {})
129
+ # %view_6 : Tensor "bf16[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%cat_1, [1, 2304, 32, 128]), kwargs = {})
130
+ # %convert_element_type_8 : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_6, torch.float32), kwargs = {})
131
+ # %unsqueeze_8 : Tensor "f32[1, 2304, 128][294912, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%arg4_1, 0), kwargs = {})
132
+ # %unsqueeze_9 : Tensor "f32[1, 2304, 1, 128][294912, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_8, 2), kwargs = {})
133
+ # %mul_7 : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_8, %unsqueeze_9), kwargs = {})
134
+ # %add_3 : Tensor "f32[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_6, %mul_7), kwargs = {})
135
+ # %convert_element_type_9 : Tensor "bf16[1, 2304, 32, 128][9437184, 4096, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_3, torch.bfloat16), kwargs = {})
136
+ # return %buf1,%buf0,%cat,%convert_element_type_6,%cat_1,%convert_element_type_9
137
+ triton_red_fused__fused_rms_norm__to_copy_add_mul_neg_split_split_with_sizes_stack_unbind_unsqueeze_view_0 = async_compile.triton('triton_red_fused__fused_rms_norm__to_copy_add_mul_neg_split_split_with_sizes_stack_unbind_unsqueeze_view_0', '''
138
+ import triton
139
+ import triton.language as tl
140
+
141
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
142
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
143
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
144
+ triton_helpers.set_driver_to_gpu()
145
+
146
+ @triton_heuristics.reduction(
147
+ size_hints={'x': 131072, 'r0_': 128},
148
+ reduction_hint=ReductionHint.DEFAULT,
149
+ filename=__file__,
150
+ triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_out_ptr1': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'in_ptr4': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=128, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
151
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__fused_rms_norm__to_copy_add_mul_neg_split_split_with_sizes_stack_unbind_unsqueeze_view_0', 'mutated_arg_names': ['in_out_ptr0', 'in_out_ptr1'], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 16, 'num_store': 2, 'num_reduction': 2, 'backend_hash': '139C22A3A3C364569C9941DE9469DCB674B7A631E094782CBD415193800462F6', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 115606016}}
152
+ )
153
+ @triton.jit
154
+ def triton_red_fused__fused_rms_norm__to_copy_add_mul_neg_split_split_with_sizes_stack_unbind_unsqueeze_view_0(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
155
+ xnumel = 73728
156
+ r0_numel = 128
157
+ rnumel = r0_numel
158
+ RBLOCK: tl.constexpr = R0_BLOCK
159
+ xoffset = tl.program_id(0) * XBLOCK
160
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
161
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
162
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
163
+ rbase = r0_base
164
+ x0 = (xindex % 32)
165
+ x1 = xindex // 32
166
+ _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
167
+ x5 = xindex
168
+ _tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
169
+ for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
170
+ r0_index = r0_offset + r0_base
171
+ r0_mask = r0_index < r0_numel
172
+ roffset = r0_offset
173
+ rindex = r0_index
174
+ r0_2 = r0_index
175
+ tmp0 = tl.load(in_ptr0 + (4096 + r0_2 + 128*x0 + 36864*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
176
+ tmp6 = tl.load(in_ptr0 + (r0_2 + 128*x0 + 36864*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
177
+ tmp1 = tmp0.to(tl.float32)
178
+ tmp2 = tmp1 * tmp1
179
+ tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
180
+ tmp5 = _tmp4 + tmp3
181
+ _tmp4 = tl.where(r0_mask, tmp5, _tmp4)
182
+ tmp7 = tmp6.to(tl.float32)
183
+ tmp8 = tmp7 * tmp7
184
+ tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
185
+ tmp11 = _tmp10 + tmp9
186
+ _tmp10 = tl.where(r0_mask, tmp11, _tmp10)
187
+ tmp4 = tl.sum(_tmp4, 1)[:, None]
188
+ tmp10 = tl.sum(_tmp10, 1)[:, None]
189
+ for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
190
+ r0_index = r0_offset + r0_base
191
+ r0_mask = r0_index < r0_numel
192
+ roffset = r0_offset
193
+ rindex = r0_index
194
+ r0_3 = (r0_index % 2)
195
+ r0_4 = r0_index // 2
196
+ r0_2 = r0_index
197
+ tmp50 = tl.load(in_ptr0 + (r0_2 + 128*x0 + 36864*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
198
+ tmp58 = tl.load(in_ptr1 + (r0_2), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
199
+ tmp63 = tl.load(in_ptr2 + (r0_2 + 128*x1), r0_mask, eviction_policy='evict_last', other=0.0)
200
+ tmp66 = tl.load(in_ptr3 + (r0_2 + 128*x1), r0_mask, eviction_policy='evict_last', other=0.0)
201
+ tmp96 = tl.load(in_ptr0 + (4096 + r0_2 + 128*x0 + 36864*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
202
+ tmp102 = tl.load(in_ptr4 + (r0_2), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
203
+ tmp12 = r0_3
204
+ tmp13 = tl.full([1, 1], 0, tl.int64)
205
+ tmp14 = tmp12 >= tmp13
206
+ tmp15 = tl.full([1, 1], 1, tl.int64)
207
+ tmp16 = tmp12 < tmp15
208
+ tmp17 = tl.load(in_ptr0 + (1 + 2*r0_4 + 128*x0 + 36864*x1), r0_mask & tmp16, eviction_policy='evict_last', other=0.0).to(tl.float32)
209
+ tmp18 = tmp17.to(tl.float32)
210
+ tmp19 = 128.0
211
+ tmp20 = (tmp10 / tmp19)
212
+ tmp21 = 1e-06
213
+ tmp22 = tmp20 + tmp21
214
+ tmp23 = libdevice.rsqrt(tmp22)
215
+ tmp24 = tmp18 * tmp23
216
+ tmp25 = tl.load(in_ptr1 + (tl.broadcast_to(1 + 2*r0_4, [XBLOCK, R0_BLOCK])), r0_mask & tmp16, eviction_policy='evict_last', other=0.0).to(tl.float32)
217
+ tmp26 = tmp25.to(tl.float32)
218
+ tmp27 = tmp24 * tmp26
219
+ tmp28 = tmp27.to(tl.float32)
220
+ tmp29 = -tmp28
221
+ tmp30 = tl.full(tmp29.shape, 0.0, tmp29.dtype)
222
+ tmp31 = tl.where(tmp16, tmp29, tmp30)
223
+ tmp32 = tmp12 >= tmp15
224
+ tmp33 = tl.full([1, 1], 2, tl.int64)
225
+ tmp34 = tmp12 < tmp33
226
+ tmp35 = tl.load(in_ptr0 + (2*r0_4 + 128*x0 + 36864*x1), r0_mask & tmp32, eviction_policy='evict_last', other=0.0).to(tl.float32)
227
+ tmp36 = tmp35.to(tl.float32)
228
+ tmp37 = 128.0
229
+ tmp38 = (tmp10 / tmp37)
230
+ tmp39 = 1e-06
231
+ tmp40 = tmp38 + tmp39
232
+ tmp41 = libdevice.rsqrt(tmp40)
233
+ tmp42 = tmp36 * tmp41
234
+ tmp43 = tl.load(in_ptr1 + (tl.broadcast_to(2*r0_4, [XBLOCK, R0_BLOCK])), r0_mask & tmp32, eviction_policy='evict_last', other=0.0).to(tl.float32)
235
+ tmp44 = tmp43.to(tl.float32)
236
+ tmp45 = tmp42 * tmp44
237
+ tmp46 = tmp45.to(tl.float32)
238
+ tmp47 = tl.full(tmp46.shape, 0.0, tmp46.dtype)
239
+ tmp48 = tl.where(tmp32, tmp46, tmp47)
240
+ tmp49 = tl.where(tmp16, tmp31, tmp48)
241
+ tmp51 = tmp50.to(tl.float32)
242
+ tmp52 = 128.0
243
+ tmp53 = (tmp10 / tmp52)
244
+ tmp54 = 1e-06
245
+ tmp55 = tmp53 + tmp54
246
+ tmp56 = libdevice.rsqrt(tmp55)
247
+ tmp57 = tmp51 * tmp56
248
+ tmp59 = tmp58.to(tl.float32)
249
+ tmp60 = tmp57 * tmp59
250
+ tmp61 = tmp60.to(tl.float32)
251
+ tmp62 = tmp61.to(tl.float32)
252
+ tmp64 = tmp62 * tmp63
253
+ tmp65 = tmp49.to(tl.float32)
254
+ tmp67 = tmp65 * tmp66
255
+ tmp68 = tmp64 + tmp67
256
+ tmp69 = tmp68.to(tl.float32)
257
+ tmp70 = tl.load(in_ptr0 + (4097 + 2*r0_4 + 128*x0 + 36864*x1), r0_mask & tmp16, eviction_policy='evict_last', other=0.0).to(tl.float32)
258
+ tmp71 = tmp70.to(tl.float32)
259
+ tmp72 = (tmp4 / tmp19)
260
+ tmp73 = tmp72 + tmp21
261
+ tmp74 = libdevice.rsqrt(tmp73)
262
+ tmp75 = tmp71 * tmp74
263
+ tmp76 = tl.load(in_ptr4 + (tl.broadcast_to(1 + 2*r0_4, [XBLOCK, R0_BLOCK])), r0_mask & tmp16, eviction_policy='evict_last', other=0.0).to(tl.float32)
264
+ tmp77 = tmp76.to(tl.float32)
265
+ tmp78 = tmp75 * tmp77
266
+ tmp79 = tmp78.to(tl.float32)
267
+ tmp80 = -tmp79
268
+ tmp81 = tl.full(tmp80.shape, 0.0, tmp80.dtype)
269
+ tmp82 = tl.where(tmp16, tmp80, tmp81)
270
+ tmp83 = tl.load(in_ptr0 + (4096 + 2*r0_4 + 128*x0 + 36864*x1), r0_mask & tmp32, eviction_policy='evict_last', other=0.0).to(tl.float32)
271
+ tmp84 = tmp83.to(tl.float32)
272
+ tmp85 = (tmp4 / tmp37)
273
+ tmp86 = tmp85 + tmp39
274
+ tmp87 = libdevice.rsqrt(tmp86)
275
+ tmp88 = tmp84 * tmp87
276
+ tmp89 = tl.load(in_ptr4 + (tl.broadcast_to(2*r0_4, [XBLOCK, R0_BLOCK])), r0_mask & tmp32, eviction_policy='evict_last', other=0.0).to(tl.float32)
277
+ tmp90 = tmp89.to(tl.float32)
278
+ tmp91 = tmp88 * tmp90
279
+ tmp92 = tmp91.to(tl.float32)
280
+ tmp93 = tl.full(tmp92.shape, 0.0, tmp92.dtype)
281
+ tmp94 = tl.where(tmp32, tmp92, tmp93)
282
+ tmp95 = tl.where(tmp16, tmp82, tmp94)
283
+ tmp97 = tmp96.to(tl.float32)
284
+ tmp98 = (tmp4 / tmp52)
285
+ tmp99 = tmp98 + tmp54
286
+ tmp100 = libdevice.rsqrt(tmp99)
287
+ tmp101 = tmp97 * tmp100
288
+ tmp103 = tmp102.to(tl.float32)
289
+ tmp104 = tmp101 * tmp103
290
+ tmp105 = tmp104.to(tl.float32)
291
+ tmp106 = tmp105.to(tl.float32)
292
+ tmp107 = tmp106 * tmp63
293
+ tmp108 = tmp95.to(tl.float32)
294
+ tmp109 = tmp108 * tmp66
295
+ tmp110 = tmp107 + tmp109
296
+ tmp111 = tmp110.to(tl.float32)
297
+ tl.store(in_out_ptr0 + (r0_2 + 128*x5), tmp69, r0_mask)
298
+ tl.store(in_out_ptr1 + (r0_2 + 128*x5), tmp111, r0_mask)
299
+ ''', device_str='cuda')
300
+
301
+
302
+ async_compile.wait(globals())
303
+ del async_compile
304
+
305
+ class Runner:
306
+ def __init__(self, partitions):
307
+ self.partitions = partitions
308
+
309
+ def recursively_apply_fns(self, fns):
310
+ new_callables = []
311
+ for fn, c in zip(fns, self.partitions):
312
+ new_callables.append(fn(c))
313
+ self.partitions = new_callables
314
+
315
+ def call(self, args):
316
+ arg0_1, arg1_1, arg2_1, arg3_1, arg4_1 = args
317
+ args.clear()
318
+ assert_size_stride(arg0_1, (1, 2304, 36864), (84934656, 36864, 1))
319
+ assert_size_stride(arg1_1, (128, ), (1, ))
320
+ assert_size_stride(arg2_1, (128, ), (1, ))
321
+ assert_size_stride(arg3_1, (2304, 128), (128, 1))
322
+ assert_size_stride(arg4_1, (2304, 128), (128, 1))
323
+ with torch.cuda._DeviceGuard(0):
324
+ torch.cuda.set_device(0)
325
+ buf2 = empty_strided_cuda((1, 2304, 32, 64, 2), (9437184, 4096, 128, 2, 1), torch.bfloat16)
326
+ buf3 = reinterpret_tensor(buf2, (1, 2304, 32, 128), (9437184, 4096, 128, 1), 0); del buf2 # reuse
327
+ buf4 = empty_strided_cuda((1, 2304, 32, 64, 2), (9437184, 4096, 128, 2, 1), torch.bfloat16)
328
+ buf5 = reinterpret_tensor(buf4, (1, 2304, 32, 128), (9437184, 4096, 128, 1), 0); del buf4 # reuse
329
+ # Topologically Sorted Source Nodes: [split, chunk, query_1, query_2, reshape, unbind, key_1, key_2, reshape_1, unbind_1, float_1, cos, mul, neg, stack, x_rotated, float_2, sin, mul_1, add, out, float_3, cos_2, mul_2, neg_1, stack_1, x_rotated_1, float_4, sin_2, mul_3, add_1, out_1], Original ATen: [aten.split_with_sizes, aten.split, aten.view, aten._fused_rms_norm, aten.unbind, aten._to_copy, aten.unsqueeze, aten.mul, aten.neg, aten.stack, aten.add]
330
+ stream0 = get_raw_stream(0)
331
+ triton_red_fused__fused_rms_norm__to_copy_add_mul_neg_split_split_with_sizes_stack_unbind_unsqueeze_view_0.run(buf3, buf5, arg0_1, arg1_1, arg3_1, arg4_1, arg2_1, 73728, 128, stream=stream0)
332
+ del arg1_1
333
+ del arg2_1
334
+ del arg3_1
335
+ del arg4_1
336
+ return (buf3, buf5, reinterpret_tensor(arg0_1, (1, 2304, 32, 128), (84934656, 36864, 128, 1), 8192), reinterpret_tensor(arg0_1, (1, 2304, 24576), (84934656, 36864, 1), 12288), )
337
+
338
+ runner = Runner(partitions=[])
339
+ call = runner.call
340
+ recursively_apply_fns = runner.recursively_apply_fns
341
+
342
+
343
+ def benchmark_compiled_module(times=10, repeat=10):
344
+ from torch._dynamo.testing import rand_strided
345
+ from torch._inductor.utils import print_performance
346
+ arg0_1 = rand_strided((1, 2304, 36864), (84934656, 36864, 1), device='cuda:0', dtype=torch.bfloat16)
347
+ arg1_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
348
+ arg2_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
349
+ arg3_1 = rand_strided((2304, 128), (128, 1), device='cuda:0', dtype=torch.float32)
350
+ arg4_1 = rand_strided((2304, 128), (128, 1), device='cuda:0', dtype=torch.float32)
351
+ fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1])
352
+ return print_performance(fn, times=times, repeat=repeat)
353
+
354
+
355
+ if __name__ == "__main__":
356
+ from torch._inductor.wrapper_benchmark import compiled_module_main
357
+ compiled_module_main('None', benchmark_compiled_module)
torchinductor/6k/abd9e26dfce6bf628201c09f1f90f4340fdaab3cc2dd99f7186afe82fe013d1a.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 1024, "num_warps": 4, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 35, "triton_cache_hash": "Q5QIKEPJDRH7FHZ6CDBLMD5Y4GTGU6Y7IAWNFLJIJRNGOB7RFV4Q"}
torchinductor/6k/c6kat5g7n3uukkfwgdxfwtmxblcmqzhbifo5g3rbaz3djxii35gi.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 1048576},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=128, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': '139C22A3A3C364569C9941DE9469DCB674B7A631E094782CBD415193800462F6', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 8396800}},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_add_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
19
+ xnumel = 1048576
20
+ xoffset = tl.program_id(0) * XBLOCK
21
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
22
+ xmask = tl.full([XBLOCK], True, tl.int1)
23
+ x2 = xindex
24
+ x0 = (xindex % 4096)
25
+ tmp0 = tl.load(in_ptr0 + (x2), None).to(tl.float32)
26
+ tmp1 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
27
+ tmp2 = tl.load(in_ptr2 + (x2), None).to(tl.float32)
28
+ tmp3 = tmp1 * tmp2
29
+ tmp4 = tmp0 + tmp3
30
+ tl.store(out_ptr0 + (x2), tmp4, None)
torchinductor/6w/4fb0f9adeff50e9452e8fd238a1808052c095c59a0b2f1d9f3f7d7106bd1ede5.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 1024, "num_warps": 4, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 68, "triton_cache_hash": "6DP457QWOYDYHZ7TARQ4OLPDLGEKSLMNVUE3G2KQSQUSVRN7FBVA"}
torchinductor/6w/c6w6sg4v3bcighwokzq6i43tl5xk5vz7zevuk7jhvlqyryyuhueo.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 33554432},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=128, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_silu_split_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 2, 'num_store': 1, 'num_reduction': 0, 'backend_hash': '139C22A3A3C364569C9941DE9469DCB674B7A631E094782CBD415193800462F6', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 201326592}},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_mul_silu_split_0(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
19
+ xnumel = 25165824
20
+ xoffset = tl.program_id(0) * XBLOCK
21
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
22
+ xmask = tl.full([XBLOCK], True, tl.int1)
23
+ x0 = (xindex % 12288)
24
+ x1 = xindex // 12288
25
+ x2 = xindex
26
+ tmp0 = tl.load(in_ptr0 + (x0 + 24576*x1), None).to(tl.float32)
27
+ tmp5 = tl.load(in_ptr0 + (12288 + x0 + 24576*x1), None).to(tl.float32)
28
+ tmp1 = tmp0.to(tl.float32)
29
+ tmp2 = tl.sigmoid(tmp1)
30
+ tmp3 = tmp1 * tmp2
31
+ tmp4 = tmp3.to(tl.float32)
32
+ tmp6 = tmp4 * tmp5
33
+ tl.store(out_ptr0 + (x2), tmp6, None)
torchinductor/7f/be95397d0c18f43f4314e0cac66d456d9d3e2b12116963a4bf988016e97f7a5e.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 1024, "num_warps": 4, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 45, "triton_cache_hash": "EQOEBZDPMDVSX6EJFLBNKY5DUKJXFLSS4SF4QQZQUN6AV3JLHKJQ"}
torchinductor/7f/c7ff4ib6652ojllutm4c7mkzzpybond3pagu3glspw3sztkfe2za.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 8388608},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=128, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': '139C22A3A3C364569C9941DE9469DCB674B7A631E094782CBD415193800462F6', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 67117056}},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_add_mul_1(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
19
+ xnumel = 8388608
20
+ xoffset = tl.program_id(0) * XBLOCK
21
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
22
+ xmask = tl.full([XBLOCK], True, tl.int1)
23
+ x2 = xindex
24
+ x0 = (xindex % 4096)
25
+ tmp0 = tl.load(in_ptr0 + (x2), None).to(tl.float32)
26
+ tmp1 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
27
+ tmp2 = tl.load(in_ptr2 + (x2), None).to(tl.float32)
28
+ tmp3 = tmp1 * tmp2
29
+ tmp4 = tmp0 + tmp3
30
+ tl.store(out_ptr0 + (x2), tmp4, None)
torchinductor/a3/94dc88253134d772dc28ed260760d9a0059b054d472700be3c22dd06b228f22f.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 1, "R0_BLOCK": 2048, "num_warps": 16, "num_stages": 1, "configs_hash": "ba27f374f6982634f1ab959ad1e63f726920cfc2c7c821f8e68ec55c3d4d94fc", "found_by_coordesc": false, "time_taken_ms": 35, "triton_cache_hash": "H6VG26TW2DOV7R3PXVPFDX6HZCVIESL5ZYKZWLUWKZYONCE6NSLQ"}
torchinductor/a3/ca3menlfuldthgmncfpjk452xkros7idrmil6pcoeigraymcg4e6.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 256, 'r0_': 4096},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr0': '*bf16', 'out_ptr3': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=128, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_add_mul_native_layer_norm_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 6, 'num_store': 2, 'num_reduction': 2, 'backend_hash': '139C22A3A3C364569C9941DE9469DCB674B7A631E094782CBD415193800462F6', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'add_persistent_rblock': True, 'tiling_scores': {'x': 0, 'r0_': 12607488}}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused_add_mul_native_layer_norm_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr3, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ xnumel = 256
20
+ r0_numel = 4096
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = xindex < xnumel
26
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
27
+ rbase = r0_base
28
+ x0 = xindex
29
+ tmp7_mean = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
30
+ tmp7_m2 = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
31
+ tmp7_weight = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
32
+ for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
33
+ r0_index = r0_offset + r0_base
34
+ r0_mask = r0_index < r0_numel
35
+ roffset = r0_offset
36
+ rindex = r0_index
37
+ r0_1 = r0_index
38
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
39
+ tmp1 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
40
+ tmp2 = tl.load(in_ptr2 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
41
+ tmp3 = tmp1 * tmp2
42
+ tmp4 = tmp0 + tmp3
43
+ tmp5 = tmp4.to(tl.float32)
44
+ tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK])
45
+ tmp7_mean_next, tmp7_m2_next, tmp7_weight_next = triton_helpers.welford_reduce(
46
+ tmp6, tmp7_mean, tmp7_m2, tmp7_weight, roffset == 0
47
+ )
48
+ tmp7_mean = tl.where(r0_mask & xmask, tmp7_mean_next, tmp7_mean)
49
+ tmp7_m2 = tl.where(r0_mask & xmask, tmp7_m2_next, tmp7_m2)
50
+ tmp7_weight = tl.where(r0_mask & xmask, tmp7_weight_next, tmp7_weight)
51
+ tl.store(out_ptr0 + (r0_1 + 4096*x0), tmp4, r0_mask & xmask)
52
+ tmp8, tmp9, tmp10 = triton_helpers.welford(tmp7_mean, tmp7_m2, tmp7_weight, 1)
53
+ tmp7 = tmp8[:, None]
54
+ tmp11 = tmp9[:, None]
55
+ tmp12 = tmp10[:, None]
56
+ for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
57
+ r0_index = r0_offset + r0_base
58
+ r0_mask = r0_index < r0_numel
59
+ roffset = r0_offset
60
+ rindex = r0_index
61
+ r0_1 = r0_index
62
+ tmp13 = tl.load(out_ptr0 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
63
+ tmp23 = tl.load(in_ptr3 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
64
+ tmp27 = tl.load(in_ptr4 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
65
+ tmp14 = tmp13.to(tl.float32)
66
+ tmp15 = tmp14 - tmp7
67
+ tmp16 = 4096.0
68
+ tmp17 = (tmp11 / tmp16)
69
+ tmp18 = 1e-06
70
+ tmp19 = tmp17 + tmp18
71
+ tmp20 = libdevice.rsqrt(tmp19)
72
+ tmp21 = tmp15 * tmp20
73
+ tmp22 = tmp21.to(tl.float32)
74
+ tmp24 = 1.0
75
+ tmp25 = tmp23 + tmp24
76
+ tmp26 = tmp22 * tmp25
77
+ tmp28 = tmp26 + tmp27
78
+ tl.store(out_ptr3 + (r0_1 + 4096*x0), tmp28, r0_mask & xmask)
torchinductor/aotautograd/a27rkqg32yfaub3aygtms2gl3oet2qxfcnp4zxa3zy5h6c3risxz/aw5eda3h36wpnnltujgkb4mvobznersd4fuvo2p7vy2quujasos ADDED
Binary file (52.3 kB). View file
 
torchinductor/aotautograd/a3443o3ywoehrda4trn5q47mauudwcinftvd52hitdnfmakyhqc4/lw6yvpbd45y77sg6fh5v4otbinchkwbf7b56u3rh3wgq3x2wkhq ADDED
Binary file (54.7 kB). View file
 
torchinductor/aotautograd/a3554ihbxq57jan4ib74iqo5mnaqevqume4yzewukzkm6ehpsilz/eubahghkef62rmchvnle5v6h3ddip4av5qqjxomdlm7ura45qve ADDED
Binary file (54.9 kB). View file
 
torchinductor/aotautograd/a3hojixb5fzn7f7jfco3ddoohdsuggk4qbop3lcg7rjy3e7fkgfz/o7wvolbgborwtoofbovayor23y4ubooymfcvv6jeqm2wbx3n2cs ADDED
Binary file (74.3 kB). View file
 
torchinductor/aotautograd/a54twb2qknddjxnxtmkoagy3umo5y3ptsesm2pdhy7nkefklf6wx/emxzj524wmpvifsxw4dsnnkzemqpzfgkenbo5obwmksvlhsr354 ADDED
Binary file (54.7 kB). View file
 
torchinductor/aotautograd/a5ksywxhfabbequvxwstheyyj5w3sinuubxcrypqjwbqsyw5la3l/ew4fxjyfoflznyuws2w2ylu4p7owpjuqoshsef75w43w2vvwejd ADDED
Binary file (55.1 kB). View file
 
torchinductor/aotautograd/a7ptufzlocphh5n5o5u63gfzkf74tjb3l5is45u5hqjspv32qda6/an4kgppgf4vt5yfvvghrnmho6jc3qnj4l6c75zrsiotr5d4u5gv ADDED
Binary file (55.1 kB). View file
 
torchinductor/aotautograd/aal6kceyfi7eazavxzpgcec5hzt32bkwo7p4doeyc56ubzlwuvx4/nkoni3ckgbheucucq64bmrta4lhz7x237lalaqcrejvdc3supg4 ADDED
Binary file (55.1 kB). View file
 
torchinductor/aotautograd/aan5kpy6i54rnpeu5vlzbx6i6blimsvhducl7futzdjr4xciy472/a35s4usnkzmh6ybhedo3b6zehfepmwdv2gxscayjeeuucr3zat7 ADDED
Binary file (54.6 kB). View file
 
torchinductor/aotautograd/aesonb7djseswkbtu2qzhvg6ikd5rewxnqlt6pwuytadpxxmjcod/lap2sypphhofd6d5rhojruk2vfyvw2olc7gtulmom4i5y7ix2cp ADDED
Binary file (62.5 kB). View file
 
torchinductor/aotautograd/age65c4dyk2rxcqufpxd6bsafzao7tacrsvejbf3pjbsngnoashv/upzttal3jaj233iyzyps7mjpq75jt6qi6rzramvgyyewfg76h6s ADDED
Binary file (83.4 kB). View file
 
torchinductor/aotautograd/ahkpwjcp2qqyj6wu2ckjqlrit2pbb3ig3ddi75hgbkgngvvipwyq/ha76p7wv3nimmrgvx6kdiqikd6adbw7nlnaiars5ey4anx46mwn ADDED
Binary file (54.4 kB). View file
 
torchinductor/aotautograd/aiojzczi5txclvaydkrk5g3qlf33pdkkhxtefkhfphkpc3o6rr4p/w3n37k3qhqfhuewneurnairyblp3h7nrak6oyp2p3um7uwnfcz5 ADDED
Binary file (52 kB). View file
 
torchinductor/aotautograd/ajdkg3gacw25klanvqotc3mkab3mi23jtjpagxrosdmqv3d4yg7v/ejzrqbsrchqzxfppkzo4ep7edhv7lrjjbcdxkxvodbk4vvk3b62 ADDED
Binary file (52.7 kB). View file
 
torchinductor/aotautograd/amb262dx57ptj6gg2ch6skr372w6arsr3i7i4ed5pljhiycuxduw/fntav2w4z5lvr443jxseqalau2vuzp7x7ljd3hanoqubtutjkvp ADDED
Binary file (56.9 kB). View file
 
torchinductor/aotautograd/amjjivi2p6firai3idkjgfxyy6z4prevujsjdno2uuchwvd7xqll/enc6ruqcyggs4mnt54tjdd2lvexcvipd5vhhamxwcj77g5fpyof ADDED
Binary file (54.7 kB). View file
 
torchinductor/aotautograd/apfaqlwe555qd2zoz575w5mvoxoiasmcomkv76mhz5zvnm5jok66/epmli5r46rzrqf73pqrnb5tratdg3mbbwdf5vyzqr6ejyhnooye ADDED
Binary file (55.1 kB). View file
 
torchinductor/aotautograd/asjbg7f735jw54kcldmvv5uost22wzpy3hkxgaihos4rllvagheu/lwqpsnp52rszp2nlwkgi33embno5st2u5bxfm4rpyoy6fql5aor ADDED
Binary file (55.1 kB). View file
 
torchinductor/aotautograd/atc2ggqhejcse5aydwh2wjakijsc2dyhqjxwdqrwpra3mgjwe4st/xwy7lzraqocjillvk4s2yc2qhpkx43s2nbkxmeb2wpph3sgyc7n ADDED
Binary file (56.1 kB). View file
 
torchinductor/aotautograd/atsevoi6zqdcnehuxassvjosi3j5vrk54uisibylfgspeewp6vyx/4sfzv7d6ch2yoi6nnr5ym3i6yibku3vfveyrr6sx6dqbmavxo32 ADDED
Binary file (55.1 kB). View file
 
torchinductor/aotautograd/ax7bbwqbruobasu7vagn2oj2owh5vgosxbjelta324rvf4tkesd4/ipnutob47ydixp2zetluyw4apg7fe5sfkkiianwaawh6yq3uang ADDED
Binary file (52.8 kB). View file
 
torchinductor/aotautograd/ay26zyuzpll2prvy7zzoeydo7r47lrr6s6jcmzi2zmytjxzebmnz/nzx7lukg3r25p6sjlwtqmkf6gmgzuq7iwagwki2x4kvhw5ducr5 ADDED
Binary file (59.5 kB). View file
 
torchinductor/aotautograd/ay65riayezoo7bqggl72pzrzdi6lvy5mp23ajx4f453ylzpmve3s/p7clvcke3bsgsaumutstrxc7bkq4tq6yoia7nwigana3n3unini ADDED
Binary file (56.5 kB). View file
 
torchinductor/aotautograd/azyih32olvhzuay5zpfypzhk2cdlosvaqxdhcnjzlwfs6k3a2ne6/5sz2kjdze7ixdny7hz24p4uma7uup7chdcpiumqznifqn4mpmqb ADDED
Binary file (62.7 kB). View file
 
torchinductor/av/cavoaz6e7kbk5wq2n7vz6rxhcrwdu2trazexubdq5qwyv2ajmbkz.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 256, 'r0_': 4096},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr2': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=128, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_add_mul_native_layer_norm_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 4, 'num_store': 1, 'num_reduction': 2, 'backend_hash': '139C22A3A3C364569C9941DE9469DCB674B7A631E094782CBD415193800462F6', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'add_persistent_rblock': True, 'tiling_scores': {'x': 0, 'r0_': 6307840}}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused_add_mul_native_layer_norm_1(in_ptr0, in_ptr1, in_ptr2, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ xnumel = 256
20
+ r0_numel = 4096
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = xindex < xnumel
26
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
27
+ rbase = r0_base
28
+ x0 = xindex
29
+ tmp3_mean = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
30
+ tmp3_m2 = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
31
+ tmp3_weight = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
32
+ for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
33
+ r0_index = r0_offset + r0_base
34
+ r0_mask = r0_index < r0_numel
35
+ roffset = r0_offset
36
+ rindex = r0_index
37
+ r0_1 = r0_index
38
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
39
+ tmp1 = tmp0.to(tl.float32)
40
+ tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
41
+ tmp3_mean_next, tmp3_m2_next, tmp3_weight_next = triton_helpers.welford_reduce(
42
+ tmp2, tmp3_mean, tmp3_m2, tmp3_weight, roffset == 0
43
+ )
44
+ tmp3_mean = tl.where(r0_mask & xmask, tmp3_mean_next, tmp3_mean)
45
+ tmp3_m2 = tl.where(r0_mask & xmask, tmp3_m2_next, tmp3_m2)
46
+ tmp3_weight = tl.where(r0_mask & xmask, tmp3_weight_next, tmp3_weight)
47
+ tmp4, tmp5, tmp6 = triton_helpers.welford(tmp3_mean, tmp3_m2, tmp3_weight, 1)
48
+ tmp3 = tmp4[:, None]
49
+ tmp7 = tmp5[:, None]
50
+ tmp8 = tmp6[:, None]
51
+ for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
52
+ r0_index = r0_offset + r0_base
53
+ r0_mask = r0_index < r0_numel
54
+ roffset = r0_offset
55
+ rindex = r0_index
56
+ r0_1 = r0_index
57
+ tmp9 = tl.load(in_ptr1 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
58
+ tmp12 = tl.load(in_ptr0 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
59
+ tmp23 = tl.load(in_ptr2 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
60
+ tmp10 = 1.0
61
+ tmp11 = tmp9 + tmp10
62
+ tmp13 = tmp12.to(tl.float32)
63
+ tmp14 = tmp13 - tmp3
64
+ tmp15 = 4096.0
65
+ tmp16 = (tmp7 / tmp15)
66
+ tmp17 = 1e-06
67
+ tmp18 = tmp16 + tmp17
68
+ tmp19 = libdevice.rsqrt(tmp18)
69
+ tmp20 = tmp14 * tmp19
70
+ tmp21 = tmp20.to(tl.float32)
71
+ tmp22 = tmp11 * tmp21
72
+ tmp24 = tmp22 + tmp23
73
+ tl.store(out_ptr2 + (r0_1 + 4096*x0), tmp24, r0_mask & xmask)
torchinductor/av/d186a24d3c8af5514b42dea48fc981efd3f5afb7bba6c30406e42c75862888b1.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 1, "R0_BLOCK": 4096, "num_warps": 16, "num_stages": 1, "configs_hash": "ba27f374f6982634f1ab959ad1e63f726920cfc2c7c821f8e68ec55c3d4d94fc", "found_by_coordesc": false, "time_taken_ms": 33, "triton_cache_hash": "CYKNGA4OMPRI7EV7H5FM47DKU7VFZ4Q5NYQGPNW6ZIVYBLBWPVMA"}
torchinductor/ay/cayicsdjyjxzpcmkvjbneubnqkuhs3y37qiwy5qlel3z2loa4qav.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['1_inference']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+
17
+ aten = torch.ops.aten
18
+ inductor_ops = torch.ops.inductor
19
+ _quantized = torch.ops._quantized
20
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
21
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
22
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
23
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
24
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
25
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
26
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
27
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
28
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
29
+ async_compile = AsyncCompile()
30
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
31
+
32
+
33
+ async_compile.wait(globals())
34
+ del async_compile
35
+
36
+ class Runner:
37
+ def __init__(self, partitions):
38
+ self.partitions = partitions
39
+
40
+ def recursively_apply_fns(self, fns):
41
+ new_callables = []
42
+ for fn, c in zip(fns, self.partitions):
43
+ new_callables.append(fn(c))
44
+ self.partitions = new_callables
45
+
46
+ def call(self, args):
47
+ arg0_1, arg1_1 = args
48
+ args.clear()
49
+ assert_size_stride(arg0_1, (4096, 12288), (1, 4096))
50
+ assert_size_stride(arg1_1, (1, 1), (1, 1))
51
+ return (aten.view.dtype(reinterpret_tensor(arg0_1, (12288, 4096), (4096, 1), 0), torch.uint8), reinterpret_tensor(arg1_1, (1, ), (1, ), 0), )
52
+
53
+ runner = Runner(partitions=[])
54
+ call = runner.call
55
+ recursively_apply_fns = runner.recursively_apply_fns
56
+
57
+
58
+ def benchmark_compiled_module(times=10, repeat=10):
59
+ from torch._dynamo.testing import rand_strided
60
+ from torch._inductor.utils import print_performance
61
+ arg0_1 = rand_strided((4096, 12288), (1, 4096), device='cuda:0', dtype=torch.float8_e4m3fn)
62
+ arg1_1 = rand_strided((1, 1), (1, 1), device='cuda:0', dtype=torch.float32)
63
+ fn = lambda: call([arg0_1, arg1_1])
64
+ return print_performance(fn, times=times, repeat=repeat)
65
+
66
+
67
+ if __name__ == "__main__":
68
+ from torch._inductor.wrapper_benchmark import compiled_module_main
69
+ compiled_module_main('None', benchmark_compiled_module)
torchinductor/bv/7969eba2eb589b95d2894ee75ee67ba01cd2bee09cd64d315c70c0950888c19e.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 2, "R0_BLOCK": 128, "num_warps": 2, "num_stages": 1, "configs_hash": "6ffa43f2ca8cb1499f3ff3fbf8c975f2c07eef9b57fcecda113029ab12cbef66", "found_by_coordesc": false, "time_taken_ms": 307, "triton_cache_hash": "AQ3FCZKOYK5LBOX7RLBQGX5T77RKI4M7SEZTYJU34QROQSJNLP5A"}
torchinductor/bv/cbvqhjtyg7fvxzwtbtt4vrdkbnb6n32fnrijjpl3vv4cfqd4mznr.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 131072, 'r0_': 128},
12
+ reduction_hint=ReductionHint.DEFAULT,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_out_ptr1': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'in_ptr4': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=128, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__fused_rms_norm__to_copy_add_mul_neg_split_split_with_sizes_stack_unbind_unsqueeze_view_0', 'mutated_arg_names': ['in_out_ptr0', 'in_out_ptr1'], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 16, 'num_store': 2, 'num_reduction': 2, 'backend_hash': '139C22A3A3C364569C9941DE9469DCB674B7A631E094782CBD415193800462F6', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 115606016}}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused__fused_rms_norm__to_copy_add_mul_neg_split_split_with_sizes_stack_unbind_unsqueeze_view_0(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ xnumel = 73728
20
+ r0_numel = 128
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
26
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
27
+ rbase = r0_base
28
+ x0 = (xindex % 32)
29
+ x1 = xindex // 32
30
+ _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
31
+ x5 = xindex
32
+ _tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
33
+ for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
34
+ r0_index = r0_offset + r0_base
35
+ r0_mask = r0_index < r0_numel
36
+ roffset = r0_offset
37
+ rindex = r0_index
38
+ r0_2 = r0_index
39
+ tmp0 = tl.load(in_ptr0 + (4096 + r0_2 + 128*x0 + 36864*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
40
+ tmp6 = tl.load(in_ptr0 + (r0_2 + 128*x0 + 36864*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
41
+ tmp1 = tmp0.to(tl.float32)
42
+ tmp2 = tmp1 * tmp1
43
+ tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
44
+ tmp5 = _tmp4 + tmp3
45
+ _tmp4 = tl.where(r0_mask, tmp5, _tmp4)
46
+ tmp7 = tmp6.to(tl.float32)
47
+ tmp8 = tmp7 * tmp7
48
+ tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
49
+ tmp11 = _tmp10 + tmp9
50
+ _tmp10 = tl.where(r0_mask, tmp11, _tmp10)
51
+ tmp4 = tl.sum(_tmp4, 1)[:, None]
52
+ tmp10 = tl.sum(_tmp10, 1)[:, None]
53
+ for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
54
+ r0_index = r0_offset + r0_base
55
+ r0_mask = r0_index < r0_numel
56
+ roffset = r0_offset
57
+ rindex = r0_index
58
+ r0_3 = (r0_index % 2)
59
+ r0_4 = r0_index // 2
60
+ r0_2 = r0_index
61
+ tmp50 = tl.load(in_ptr0 + (r0_2 + 128*x0 + 36864*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
62
+ tmp58 = tl.load(in_ptr1 + (r0_2), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
63
+ tmp63 = tl.load(in_ptr2 + (r0_2 + 128*x1), r0_mask, eviction_policy='evict_last', other=0.0)
64
+ tmp66 = tl.load(in_ptr3 + (r0_2 + 128*x1), r0_mask, eviction_policy='evict_last', other=0.0)
65
+ tmp96 = tl.load(in_ptr0 + (4096 + r0_2 + 128*x0 + 36864*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
66
+ tmp102 = tl.load(in_ptr4 + (r0_2), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
67
+ tmp12 = r0_3
68
+ tmp13 = tl.full([1, 1], 0, tl.int64)
69
+ tmp14 = tmp12 >= tmp13
70
+ tmp15 = tl.full([1, 1], 1, tl.int64)
71
+ tmp16 = tmp12 < tmp15
72
+ tmp17 = tl.load(in_ptr0 + (1 + 2*r0_4 + 128*x0 + 36864*x1), r0_mask & tmp16, eviction_policy='evict_last', other=0.0).to(tl.float32)
73
+ tmp18 = tmp17.to(tl.float32)
74
+ tmp19 = 128.0
75
+ tmp20 = (tmp10 / tmp19)
76
+ tmp21 = 1e-06
77
+ tmp22 = tmp20 + tmp21
78
+ tmp23 = libdevice.rsqrt(tmp22)
79
+ tmp24 = tmp18 * tmp23
80
+ tmp25 = tl.load(in_ptr1 + (tl.broadcast_to(1 + 2*r0_4, [XBLOCK, R0_BLOCK])), r0_mask & tmp16, eviction_policy='evict_last', other=0.0).to(tl.float32)
81
+ tmp26 = tmp25.to(tl.float32)
82
+ tmp27 = tmp24 * tmp26
83
+ tmp28 = tmp27.to(tl.float32)
84
+ tmp29 = -tmp28
85
+ tmp30 = tl.full(tmp29.shape, 0.0, tmp29.dtype)
86
+ tmp31 = tl.where(tmp16, tmp29, tmp30)
87
+ tmp32 = tmp12 >= tmp15
88
+ tmp33 = tl.full([1, 1], 2, tl.int64)
89
+ tmp34 = tmp12 < tmp33
90
+ tmp35 = tl.load(in_ptr0 + (2*r0_4 + 128*x0 + 36864*x1), r0_mask & tmp32, eviction_policy='evict_last', other=0.0).to(tl.float32)
91
+ tmp36 = tmp35.to(tl.float32)
92
+ tmp37 = 128.0
93
+ tmp38 = (tmp10 / tmp37)
94
+ tmp39 = 1e-06
95
+ tmp40 = tmp38 + tmp39
96
+ tmp41 = libdevice.rsqrt(tmp40)
97
+ tmp42 = tmp36 * tmp41
98
+ tmp43 = tl.load(in_ptr1 + (tl.broadcast_to(2*r0_4, [XBLOCK, R0_BLOCK])), r0_mask & tmp32, eviction_policy='evict_last', other=0.0).to(tl.float32)
99
+ tmp44 = tmp43.to(tl.float32)
100
+ tmp45 = tmp42 * tmp44
101
+ tmp46 = tmp45.to(tl.float32)
102
+ tmp47 = tl.full(tmp46.shape, 0.0, tmp46.dtype)
103
+ tmp48 = tl.where(tmp32, tmp46, tmp47)
104
+ tmp49 = tl.where(tmp16, tmp31, tmp48)
105
+ tmp51 = tmp50.to(tl.float32)
106
+ tmp52 = 128.0
107
+ tmp53 = (tmp10 / tmp52)
108
+ tmp54 = 1e-06
109
+ tmp55 = tmp53 + tmp54
110
+ tmp56 = libdevice.rsqrt(tmp55)
111
+ tmp57 = tmp51 * tmp56
112
+ tmp59 = tmp58.to(tl.float32)
113
+ tmp60 = tmp57 * tmp59
114
+ tmp61 = tmp60.to(tl.float32)
115
+ tmp62 = tmp61.to(tl.float32)
116
+ tmp64 = tmp62 * tmp63
117
+ tmp65 = tmp49.to(tl.float32)
118
+ tmp67 = tmp65 * tmp66
119
+ tmp68 = tmp64 + tmp67
120
+ tmp69 = tmp68.to(tl.float32)
121
+ tmp70 = tl.load(in_ptr0 + (4097 + 2*r0_4 + 128*x0 + 36864*x1), r0_mask & tmp16, eviction_policy='evict_last', other=0.0).to(tl.float32)
122
+ tmp71 = tmp70.to(tl.float32)
123
+ tmp72 = (tmp4 / tmp19)
124
+ tmp73 = tmp72 + tmp21
125
+ tmp74 = libdevice.rsqrt(tmp73)
126
+ tmp75 = tmp71 * tmp74
127
+ tmp76 = tl.load(in_ptr4 + (tl.broadcast_to(1 + 2*r0_4, [XBLOCK, R0_BLOCK])), r0_mask & tmp16, eviction_policy='evict_last', other=0.0).to(tl.float32)
128
+ tmp77 = tmp76.to(tl.float32)
129
+ tmp78 = tmp75 * tmp77
130
+ tmp79 = tmp78.to(tl.float32)
131
+ tmp80 = -tmp79
132
+ tmp81 = tl.full(tmp80.shape, 0.0, tmp80.dtype)
133
+ tmp82 = tl.where(tmp16, tmp80, tmp81)
134
+ tmp83 = tl.load(in_ptr0 + (4096 + 2*r0_4 + 128*x0 + 36864*x1), r0_mask & tmp32, eviction_policy='evict_last', other=0.0).to(tl.float32)
135
+ tmp84 = tmp83.to(tl.float32)
136
+ tmp85 = (tmp4 / tmp37)
137
+ tmp86 = tmp85 + tmp39
138
+ tmp87 = libdevice.rsqrt(tmp86)
139
+ tmp88 = tmp84 * tmp87
140
+ tmp89 = tl.load(in_ptr4 + (tl.broadcast_to(2*r0_4, [XBLOCK, R0_BLOCK])), r0_mask & tmp32, eviction_policy='evict_last', other=0.0).to(tl.float32)
141
+ tmp90 = tmp89.to(tl.float32)
142
+ tmp91 = tmp88 * tmp90
143
+ tmp92 = tmp91.to(tl.float32)
144
+ tmp93 = tl.full(tmp92.shape, 0.0, tmp92.dtype)
145
+ tmp94 = tl.where(tmp32, tmp92, tmp93)
146
+ tmp95 = tl.where(tmp16, tmp82, tmp94)
147
+ tmp97 = tmp96.to(tl.float32)
148
+ tmp98 = (tmp4 / tmp52)
149
+ tmp99 = tmp98 + tmp54
150
+ tmp100 = libdevice.rsqrt(tmp99)
151
+ tmp101 = tmp97 * tmp100
152
+ tmp103 = tmp102.to(tl.float32)
153
+ tmp104 = tmp101 * tmp103
154
+ tmp105 = tmp104.to(tl.float32)
155
+ tmp106 = tmp105.to(tl.float32)
156
+ tmp107 = tmp106 * tmp63
157
+ tmp108 = tmp95.to(tl.float32)
158
+ tmp109 = tmp108 * tmp66
159
+ tmp110 = tmp107 + tmp109
160
+ tmp111 = tmp110.to(tl.float32)
161
+ tl.store(in_out_ptr0 + (r0_2 + 128*x5), tmp69, r0_mask)
162
+ tl.store(in_out_ptr1 + (r0_2 + 128*x5), tmp111, r0_mask)
torchinductor/cr/ccr2gijy4jp6vvdbewmzgaogxbf5as7ytxtou4zo2yelawomrjjg.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['21_inference']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /app/tensorrt_llm/visual_gen/compiled_cache/flux2_klein_9b_NVIDIA_GeForce_RTX_4090_sm89_torch2.10.0a0_b4e4ee81d3.nv25.12_cuda13_1/torchinductor/sy/csyae3ok2xnzuxhjkxzhdcpcz6jckcu3vv7eqb3pewhrvqmiergf.py
38
+ # Topologically Sorted Source Nodes: [chunk, silu, x], Original ATen: [aten.split, aten.silu, aten.mul]
39
+ # Source node to ATen node mapping:
40
+ # chunk => split
41
+ # silu => convert_element_type, convert_element_type_1, mul_6, sigmoid
42
+ # x => mul_10
43
+ # Graph fragment:
44
+ # %arg1_1 : Tensor "bf16[1, s67, 24576][24576*s67, 24576, 1]cuda:0" = PlaceHolder[target=arg1_1]
45
+ # %split : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%arg1_1, 12288, -1), kwargs = {})
46
+ # %convert_element_type : Tensor "f32[1, s67, 12288][12288*s67, 12288, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem, torch.float32), kwargs = {})
47
+ # %sigmoid : Tensor "f32[1, s67, 12288][12288*s67, 12288, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%convert_element_type,), kwargs = {})
48
+ # %mul_6 : Tensor "f32[1, s67, 12288][12288*s67, 12288, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %sigmoid), kwargs = {})
49
+ # %convert_element_type_1 : Tensor "bf16[1, s67, 12288][12288*s67, 12288, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_6, torch.bfloat16), kwargs = {})
50
+ # %mul_10 : Tensor "bf16[1, s67, 12288][12288*s67, 12288, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_1, %getitem_1), kwargs = {})
51
+ # return %mul_10
52
+ triton_poi_fused_mul_silu_split_0 = async_compile.triton('triton_poi_fused_mul_silu_split_0', '''
53
+ import triton
54
+ import triton.language as tl
55
+
56
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
57
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
58
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
59
+ triton_helpers.set_driver_to_gpu()
60
+
61
+ @triton_heuristics.pointwise(
62
+ size_hints={'x': 4194304},
63
+ filename=__file__,
64
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=128, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
65
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_silu_split_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 2, 'num_store': 1, 'num_reduction': 0, 'backend_hash': '139C22A3A3C364569C9941DE9469DCB674B7A631E094782CBD415193800462F6', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False},
66
+ min_elem_per_thread=0
67
+ )
68
+ @triton.jit
69
+ def triton_poi_fused_mul_silu_split_0(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
70
+ xoffset = tl.program_id(0) * XBLOCK
71
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
72
+ xmask = tl.full([XBLOCK], True, tl.int1)
73
+ x0 = (xindex % 12288)
74
+ x1 = xindex // 12288
75
+ x2 = xindex
76
+ tmp0 = tl.load(in_ptr0 + (x0 + 24576*x1), None).to(tl.float32)
77
+ tmp5 = tl.load(in_ptr0 + (12288 + x0 + 24576*x1), None).to(tl.float32)
78
+ tmp1 = tmp0.to(tl.float32)
79
+ tmp2 = tl.sigmoid(tmp1)
80
+ tmp3 = tmp1 * tmp2
81
+ tmp4 = tmp3.to(tl.float32)
82
+ tmp6 = tmp4 * tmp5
83
+ tl.store(out_ptr0 + (x2), tmp6, None)
84
+ ''', device_str='cuda')
85
+
86
+
87
+ async_compile.wait(globals())
88
+ del async_compile
89
+
90
+ class Runner:
91
+ def __init__(self, partitions):
92
+ self.partitions = partitions
93
+
94
+ def recursively_apply_fns(self, fns):
95
+ new_callables = []
96
+ for fn, c in zip(fns, self.partitions):
97
+ new_callables.append(fn(c))
98
+ self.partitions = new_callables
99
+
100
+ def call(self, args):
101
+ arg0_1, arg1_1 = args
102
+ args.clear()
103
+ s67 = arg0_1
104
+ assert_size_stride(arg1_1, (1, s67, 24576), (24576*s67, 24576, 1))
105
+ with torch.cuda._DeviceGuard(0):
106
+ torch.cuda.set_device(0)
107
+ buf0 = empty_strided_cuda((1, s67, 12288), (12288*s67, 12288, 1), torch.bfloat16)
108
+ # Topologically Sorted Source Nodes: [chunk, silu, x], Original ATen: [aten.split, aten.silu, aten.mul]
109
+ triton_poi_fused_mul_silu_split_0_xnumel = 12288*s67
110
+ stream0 = get_raw_stream(0)
111
+ triton_poi_fused_mul_silu_split_0.run(arg1_1, buf0, triton_poi_fused_mul_silu_split_0_xnumel, stream=stream0)
112
+ del arg1_1
113
+ return (buf0, )
114
+
115
+ runner = Runner(partitions=[])
116
+ call = runner.call
117
+ recursively_apply_fns = runner.recursively_apply_fns
118
+
119
+
120
+ def benchmark_compiled_module(times=10, repeat=10):
121
+ from torch._dynamo.testing import rand_strided
122
+ from torch._inductor.utils import print_performance
123
+ arg0_1 = 256
124
+ arg1_1 = rand_strided((1, 256, 24576), (6291456, 24576, 1), device='cuda:0', dtype=torch.bfloat16)
125
+ fn = lambda: call([arg0_1, arg1_1])
126
+ return print_performance(fn, times=times, repeat=repeat)
127
+
128
+
129
+ if __name__ == "__main__":
130
+ from torch._inductor.wrapper_benchmark import compiled_module_main
131
+ compiled_module_main('None', benchmark_compiled_module)
torchinductor/cz/bb6645c6be31f426023ec47eef09e354ad9fa8b2d59e6e45ab49b803eb34d44e.best_config ADDED
@@ -0,0 +1 @@
 
 
1
+ {"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 41, "triton_cache_hash": "SWIO2NFSYH3NKX6EWLJXN7WN2QH2K7ETN3JE2BQRCZXLIIDUWOOA"}
torchinductor/cz/cczg7tpituprwgqpuajzy2nylfk43mdozd5vwo77muq3kospnf7b.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 8388608},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=128, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': '139C22A3A3C364569C9941DE9469DCB674B7A631E094782CBD415193800462F6', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 50331648}},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_clone_0(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
19
+ xnumel = 8388608
20
+ xoffset = tl.program_id(0) * XBLOCK
21
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
22
+ xmask = tl.full([XBLOCK], True, tl.int1)
23
+ x0 = xindex
24
+ tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
25
+ tl.store(out_ptr0 + (x0), tmp0, None)