File size: 2,376 Bytes
c1af2fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#pragma once

#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>

#if !defined(__MAC_14_0) && (!defined(MAC_OS_X_VERSION_14_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_14_0))

/// Stand-in for the FFT scaling-mode enum. It is declared only when the SDK being
/// compiled against predates macOS 14 (see the surrounding preprocessor guard).
/// NOTE(review): the raw values presumably have to match the macOS 14 SDK's own
/// MPSGraphFFTScalingMode, because they are handed to the framework at runtime — confirm.
typedef NS_ENUM(NSUInteger, MPSGraphFFTScalingMode)
{
  MPSGraphFFTScalingModeNone    = 0,
  MPSGraphFFTScalingModeSize    = 1,
  MPSGraphFFTScalingModeUnitary = 2,
};

/// Interface-only declaration of an FFT descriptor class for SDKs older than macOS 14.
/// No @implementation appears in this header.
/// NOTE(review): the implementation presumably lives in a .mm elsewhere; code reaches
/// this class through the MPSGraphFFTDescriptor alias declared further down — confirm.
@interface FakeMPSGraphFFTDescriptor : NSObject <NSCopying>

/// Factory; declared as possibly returning nil.
+ (nullable instancetype)descriptor;

// Scalar, read/write settings. Presumably these mirror the macOS 14 SDK properties of
// the same names — verify against the SDK header.
@property (nonatomic, assign) BOOL inverse;
@property (nonatomic, assign) MPSGraphFFTScalingMode scalingMode;
@property (nonatomic, assign) BOOL roundToOddHermitean;

@end

// Makes the class name MPSGraphFFTDescriptor (used by the MPSGraph category below)
// resolve to FakeMPSGraphFFTDescriptor when compiling against a pre-macOS 14 SDK.
@compatibility_alias MPSGraphFFTDescriptor FakeMPSGraphFFTDescriptor;

/// Compile-time declarations of MPSGraph selectors for SDKs that predate macOS 14.
/// This header only declares them; it provides no @implementation.
/// NOTE(review): these presumably name methods the framework itself provides at runtime
/// on macOS 14+, so callers should check OS availability before invoking them — confirm.
@interface MPSGraph (SonomaOps)

- (nonnull MPSGraphTensor *)conjugateWithTensor:(nonnull MPSGraphTensor *)tensor
                                           name:(nullable NSString *)name;

- (nonnull MPSGraphTensor *)realPartOfTensor:(nonnull MPSGraphTensor *)tensor
                                        name:(nullable NSString *)name;

// FFT entry points. Each takes an input tensor, an array of axes, an
// MPSGraphFFTDescriptor (aliased to FakeMPSGraphFFTDescriptor on older SDKs) and an
// optional name.
- (nonnull MPSGraphTensor *)fastFourierTransformWithTensor:(nonnull MPSGraphTensor *)tensor
                                                      axes:(nonnull NSArray<NSNumber *> *)axes
                                                descriptor:(nonnull MPSGraphFFTDescriptor *)descriptor
                                                      name:(nullable NSString *)name;

- (nonnull MPSGraphTensor *)realToHermiteanFFTWithTensor:(nonnull MPSGraphTensor *)tensor
                                                    axes:(nonnull NSArray<NSNumber *> *)axes
                                              descriptor:(nonnull MPSGraphFFTDescriptor *)descriptor
                                                    name:(nullable NSString *)name;

- (nonnull MPSGraphTensor *)HermiteanToRealFFTWithTensor:(nonnull MPSGraphTensor *)tensor
                                                    axes:(nonnull NSArray<NSNumber *> *)axes
                                              descriptor:(nonnull MPSGraphFFTDescriptor *)descriptor
                                                    name:(nullable NSString *)name;

@end

// Fallback values for constants the guard above assumes only the macOS 14 SDK declares.
// They are visible only when building against an older SDK (e.g. macOS 13).
// bfloat16 data type: the float16 type tag with the alternate-encoding bit set.
#define MPSDataTypeBFloat16 ((MPSDataType)(MPSDataTypeFloat16 | MPSDataTypeAlternateEncodingBit))

// Metal Shading Language 3.1, encoded as (major << 16) | minor.
#define MTLLanguageVersion3_1 ((MTLLanguageVersion)((3 << 16) | 1))
#endif