officeuseaitf2024 commited on
Commit
d59afab
·
verified ·
1 Parent(s): 34b3110

initial Commit

Browse files
Files changed (2) hide show
  1. compare_models_v3_fixed.py +218 -0
  2. model_matching.tflite +3 -0
compare_models_v3_fixed.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import numpy as np
3
+ from transformers import AutoTokenizer, AutoModel
4
+ from numpy.linalg import norm
5
+ import torch
6
+
7
+ # -----------------------------------------------------------
8
+ # CONFIG
9
+ # -----------------------------------------------------------
10
+ targetSentence = "multiply button"
11
+
12
+ candidateSentences = [
13
+ "add button",
14
+ ]
15
+
16
+ # -----------------------------------------------------------
17
+ # LOAD TFLITE MODEL
18
+ # -----------------------------------------------------------
19
+ print("="*80)
20
+ print("LOADING TFLITE MODEL")
21
+ print("="*80)
22
+
23
+ interpreter = tf.lite.Interpreter(model_path="ai-edge-torch/model_matching.tflite")
24
+ interpreter.allocate_tensors()
25
+
26
+ input_details = interpreter.get_input_details()
27
+ output_details = interpreter.get_output_details()
28
+
29
+ print(f"\nTFLite Model has {len(output_details)} outputs:")
30
+ for i, detail in enumerate(output_details):
31
+ print(f" Output {i}: {detail['name']} - Shape: {detail['shape']}")
32
+
33
+ tflite_tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
34
+
35
+ # -----------------------------------------------------------
36
+ # LOAD HF MODEL FOR EMBEDDINGS
37
+ # -----------------------------------------------------------
38
+ print("\nLOADING HUGGINGFACE MODEL...")
39
+ hf_tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
40
+ hf_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
41
+
42
+
43
+ # -----------------------------------------------------------
44
+ # UTIL: Cosine Similarity
45
+ # -----------------------------------------------------------
46
+ def cosine_similarity(a, b):
47
+ return np.dot(a, b) / (norm(a) * norm(b) + 1e-8)
48
+
49
+
50
+ # -----------------------------------------------------------
51
+ # UTIL: Mean Pooling (with attention mask)
52
+ # -----------------------------------------------------------
53
+ def mean_pooling(last_hidden_state, attention_mask):
54
+ """
55
+ Perform mean pooling on the last_hidden_state using attention mask.
56
+ This is the standard approach for sentence-transformers models.
57
+ """
58
+ # Expand attention mask to match hidden state dimensions
59
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
60
+
61
+ # Sum embeddings weighted by attention mask
62
+ sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
63
+
64
+ # Sum attention mask (to get the actual length)
65
+ sum_mask = input_mask_expanded.sum(1)
66
+ sum_mask = torch.clamp(sum_mask, min=1e-9) # Avoid division by zero
67
+
68
+ # Divide to get mean
69
+ return sum_embeddings / sum_mask
70
+
71
+
72
+ def mean_pooling_numpy(last_hidden_state, attention_mask):
73
+ """
74
+ NumPy version of mean pooling for TFLite outputs.
75
+ last_hidden_state: [batch, seq_len, hidden_dim]
76
+ attention_mask: [batch, seq_len]
77
+ """
78
+ # Expand attention mask to match hidden state dimensions
79
+ input_mask_expanded = np.expand_dims(attention_mask, axis=-1) # [batch, seq_len, 1]
80
+ input_mask_expanded = np.broadcast_to(input_mask_expanded, last_hidden_state.shape) # [batch, seq_len, hidden_dim]
81
+
82
+ # Sum embeddings weighted by attention mask
83
+ sum_embeddings = np.sum(last_hidden_state * input_mask_expanded, axis=1) # [batch, hidden_dim]
84
+
85
+ # Sum attention mask
86
+ sum_mask = np.sum(input_mask_expanded, axis=1) # [batch, hidden_dim]
87
+ sum_mask = np.clip(sum_mask, a_min=1e-9, a_max=None) # Avoid division by zero
88
+
89
+ # Divide to get mean
90
+ return sum_embeddings / sum_mask
91
+
92
+
93
+ # -----------------------------------------------------------
94
+ # TFLITE ENCODING (CORRECTED)
95
+ # -----------------------------------------------------------
96
+ def encode_tflite(sentence):
97
+ """
98
+ Encode a sentence using TFLite model with proper mean pooling.
99
+ """
100
+ tokens = tflite_tokenizer(sentence, return_tensors="np",
101
+ padding="max_length", max_length=512, truncation=True)
102
+
103
+ interpreter.set_tensor(input_details[0]['index'], tokens["input_ids"].astype(np.int64))
104
+ interpreter.set_tensor(input_details[1]['index'], tokens["attention_mask"].astype(np.int64))
105
+ interpreter.invoke()
106
+
107
+ # Get the last_hidden_state (first output)
108
+ last_hidden_state = interpreter.get_tensor(output_details[0]['index'])
109
+
110
+ # Apply mean pooling with attention mask
111
+ pooled_output = mean_pooling_numpy(last_hidden_state, tokens["attention_mask"])
112
+
113
+ return pooled_output.reshape(-1)
114
+
115
+
116
+ # -----------------------------------------------------------
117
+ # HF PYTORCH ENCODING (CORRECTED)
118
+ # -----------------------------------------------------------
119
+ def encode_hf(sentence):
120
+ """
121
+ Encode a sentence using HuggingFace model with proper mean pooling.
122
+ """
123
+ tokens = hf_tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
124
+ with torch.no_grad():
125
+ model_output = hf_model(**tokens)
126
+
127
+ # Apply mean pooling with attention mask
128
+ embeddings = mean_pooling(model_output.last_hidden_state, tokens['attention_mask'])
129
+
130
+ return embeddings[0].numpy()
131
+
132
+
133
+ # -----------------------------------------------------------
134
+ # COMPUTE EMBEDDINGS
135
+ # -----------------------------------------------------------
136
+ print("\n" + "="*80)
137
+ print("ENCODING SENTENCES")
138
+ print("="*80)
139
+
140
+ print(f"\nTarget: '{targetSentence}'")
141
+ target_emb_tflite = encode_tflite(targetSentence)
142
+ target_emb_hf = encode_hf(targetSentence)
143
+
144
+ print(f"\nCandidates: {candidateSentences}")
145
+ candidate_embs_hf = [(sent, encode_hf(sent)) for sent in candidateSentences]
146
+ candidate_embs_tf = [(sent, encode_tflite(sent)) for sent in candidateSentences]
147
+
148
+
149
+ # -----------------------------------------------------------
150
+ # VERIFY CONVERSION CORRECTNESS
151
+ # -----------------------------------------------------------
152
+ print("\n" + "="*80)
153
+ print("VERIFYING TFLITE CONVERSION CORRECTNESS")
154
+ print("="*80)
155
+
156
+ # Compare embeddings for the same sentence from both models
157
+ similarity = cosine_similarity(target_emb_tflite, target_emb_hf)
158
+ print(f"\nTarget sentence embedding similarity (TFLite vs HF): {similarity:.6f}")
159
+
160
+ if similarity > 0.99:
161
+ print("✓ EXCELLENT: TFLite model conversion is highly accurate!")
162
+ elif similarity > 0.95:
163
+ print("✓ GOOD: TFLite model conversion is accurate (minor numerical differences)")
164
+ elif similarity > 0.90:
165
+ print("⚠ WARNING: TFLite model has some differences from original model")
166
+ else:
167
+ print("✗ ERROR: TFLite model outputs are significantly different from original model")
168
+
169
+ # Check candidate embeddings too
170
+ print("\nCandidate embeddings similarity:")
171
+ for i, (sent, _) in enumerate(candidate_embs_tf):
172
+ sim = cosine_similarity(candidate_embs_tf[i][1], candidate_embs_hf[i][1])
173
+ print(f" '{sent}': {sim:.6f}")
174
+
175
+
176
+ # -----------------------------------------------------------
177
+ # SIMILARITY COMPARISON
178
+ # -----------------------------------------------------------
179
+ print("\n" + "="*80)
180
+ print("SIMILARITY SCORES - HUGGINGFACE MODEL")
181
+ print("="*80)
182
+
183
+ for sent, emb in candidate_embs_hf:
184
+ score = cosine_similarity(target_emb_hf, emb)
185
+ print(f"\nTarget: \"{targetSentence}\"")
186
+ print(f"Candidate: \"{sent}\"")
187
+ print(f"Similarity Score: {score:.4f}")
188
+ print("-" * 80)
189
+
190
+
191
+ print("\n" + "="*80)
192
+ print("SIMILARITY SCORES - TFLITE MODEL")
193
+ print("="*80)
194
+
195
+ for sent, emb in candidate_embs_tf:
196
+ score = cosine_similarity(target_emb_tflite, emb)
197
+ print(f"\nTarget: \"{targetSentence}\"")
198
+ print(f"Candidate: \"{sent}\"")
199
+ print(f"Similarity Score: {score:.4f}")
200
+ print("-" * 80)
201
+
202
+
203
+ # -----------------------------------------------------------
204
+ # SUMMARY
205
+ # -----------------------------------------------------------
206
+ print("\n" + "="*80)
207
+ print("SUMMARY")
208
+ print("="*80)
209
+
210
+ print("\n✓ POST-PROCESSING APPLIED:")
211
+ print(" - TFLite: Mean pooling with attention mask on last_hidden_state")
212
+ print(" - HuggingFace: Mean pooling with attention mask on last_hidden_state")
213
+ print("\n✓ Both models now use the SAME pooling strategy")
214
+ print("\n✓ This is the standard approach for sentence-transformers models")
215
+
216
+ print("\n" + "="*80)
217
+ print("Completed.")
218
+ print("="*80)
model_matching.tflite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7d641155bee873a43689822ddce665958cd9020b2ad67050b3f00ec08d22a551
3
+ size 91095724