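# MiniLmTflite / compare_models_v3_fixed.py
# Compares sentence embeddings produced by a TFLite conversion of
# sentence-transformers/all-MiniLM-L6-v2 with those of the original
# HuggingFace model.
# Author: officeuseaitf2024 — initial commit d59afab.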
import tensorflow as tf
import numpy as np
from transformers import AutoTokenizer, AutoModel
from numpy.linalg import norm
import torch
# -----------------------------------------------------------
# CONFIG
# -----------------------------------------------------------
# Reference sentence that every candidate is scored against.
targetSentence = "multiply button"
# Sentences to compare with targetSentence; extend this list freely.
candidateSentences = ["add button"]
# -----------------------------------------------------------
# LOAD TFLITE MODEL
# -----------------------------------------------------------
for banner_line in ("=" * 80, "LOADING TFLITE MODEL", "=" * 80):
    print(banner_line)

# Load the converted model and allocate its tensors so input/output
# metadata (indices, shapes, dtypes) can be queried.
interpreter = tf.lite.Interpreter(model_path="ai-edge-torch/model_matching.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# List every output tensor so the hidden-state output can be identified.
print(f"\nTFLite Model has {len(output_details)} outputs:")
for out_idx, out_info in enumerate(output_details):
    print(f" Output {out_idx}: {out_info['name']} - Shape: {out_info['shape']}")

# Tokenizer matching the original checkpoint the TFLite model was built from.
tflite_tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
# -----------------------------------------------------------
# LOAD HF MODEL FOR EMBEDDINGS
# -----------------------------------------------------------
print("\nLOADING HUGGINGFACE MODEL...")
# The PyTorch reference model and its tokenizer come from the same checkpoint.
_HF_CHECKPOINT = "sentence-transformers/all-MiniLM-L6-v2"
hf_tokenizer = AutoTokenizer.from_pretrained(_HF_CHECKPOINT)
hf_model = AutoModel.from_pretrained(_HF_CHECKPOINT)
# -----------------------------------------------------------
# UTIL: Cosine Similarity
# -----------------------------------------------------------
def cosine_similarity(a, b):
    """Return the cosine similarity of vectors ``a`` and ``b``.

    A small epsilon is added to the norm product, so a zero vector yields
    0.0 instead of a division-by-zero NaN.
    """
    denominator = norm(a) * norm(b) + 1e-8
    return np.dot(a, b) / denominator
# -----------------------------------------------------------
# UTIL: Mean Pooling (with attention mask)
# -----------------------------------------------------------
def mean_pooling(last_hidden_state, attention_mask):
    """
    Average token embeddings, ignoring positions where the attention mask is 0.

    last_hidden_state: torch tensor [batch, seq_len, hidden_dim]
    attention_mask:    torch tensor [batch, seq_len] of 0/1 values
    Returns a [batch, hidden_dim] float tensor.
    """
    # Broadcast the mask across the hidden dimension so it can weight each value.
    weights = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    token_sum = (last_hidden_state * weights).sum(dim=1)
    # Count real tokens per row; the floor keeps all-padding rows finite.
    token_count = weights.sum(dim=1).clamp(min=1e-9)
    return token_sum / token_count
def mean_pooling_numpy(last_hidden_state, attention_mask):
    """
    NumPy counterpart of ``mean_pooling``, used for TFLite outputs.

    last_hidden_state: ndarray [batch, seq_len, hidden_dim]
    attention_mask:    array-like [batch, seq_len] of 0/1 values
    Returns an ndarray [batch, hidden_dim] of masked token means.
    """
    # Add a trailing axis and stretch the mask over the hidden dimension.
    weights = np.broadcast_to(np.asarray(attention_mask)[..., np.newaxis],
                              last_hidden_state.shape)
    totals = (last_hidden_state * weights).sum(axis=1)
    # Number of real tokens per row, floored so all-padding rows don't divide by 0.
    counts = np.maximum(weights.sum(axis=1), 1e-9)
    return totals / counts
# -----------------------------------------------------------
# TFLITE ENCODING (CORRECTED)
# -----------------------------------------------------------
def encode_tflite(sentence):
    """
    Encode a sentence with the TFLite model using attention-masked mean pooling.

    The tokenizer pads/truncates to the interpreter's static sequence length
    (512 when that dimension is dynamic). Each input is cast to the dtype its
    tensor declares. Models exported with another fixed length or with int32
    ids therefore work, and a 512/int64 model behaves exactly as before.

    Args:
        sentence: the text to embed.

    Returns:
        1-D numpy array holding the pooled embedding.
    """
    ids_detail, mask_detail = input_details[0], input_details[1]
    # NOTE(review): inputs are bound by position (0 = input_ids,
    # 1 = attention_mask); confirm this order for the exported model.

    # Use the model's declared sequence length; -1 means dynamic, keep 512.
    seq_dims = ids_detail.get("shape_signature", ids_detail["shape"])
    max_length = int(seq_dims[-1])
    if max_length <= 0:
        max_length = 512

    tokens = tflite_tokenizer(sentence, return_tensors="np",
                              padding="max_length", max_length=max_length,
                              truncation=True)
    interpreter.set_tensor(ids_detail["index"],
                           tokens["input_ids"].astype(ids_detail["dtype"]))
    interpreter.set_tensor(mask_detail["index"],
                           tokens["attention_mask"].astype(mask_detail["dtype"]))
    interpreter.invoke()

    # Mean pooling needs the [batch, seq_len, hidden] tensor. TFLite does not
    # guarantee output order, so take the first rank-3 output and fall back
    # to output 0 (the previous assumption) if none is rank 3.
    hidden_detail = next(
        (detail for detail in output_details if len(detail["shape"]) == 3),
        output_details[0],
    )
    last_hidden_state = interpreter.get_tensor(hidden_detail["index"])

    pooled_output = mean_pooling_numpy(last_hidden_state, tokens["attention_mask"])
    return pooled_output.reshape(-1)
# -----------------------------------------------------------
# HF PYTORCH ENCODING (CORRECTED)
# -----------------------------------------------------------
def encode_hf(sentence):
    """
    Embed ``sentence`` with the HuggingFace PyTorch model.

    Runs the model without gradient tracking, mean-pools the last hidden
    state over non-padding tokens, and returns the first (only) row as a
    1-D numpy array.
    """
    encoded = hf_tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        hidden = hf_model(**encoded).last_hidden_state
    pooled = mean_pooling(hidden, encoded['attention_mask'])
    return pooled[0].numpy()
# -----------------------------------------------------------
# COMPUTE EMBEDDINGS
# -----------------------------------------------------------
_rule = "=" * 80
print("\n" + _rule)
print("ENCODING SENTENCES")
print(_rule)

# Target sentence through both backends.
print(f"\nTarget: '{targetSentence}'")
target_emb_tflite = encode_tflite(targetSentence)
target_emb_hf = encode_hf(targetSentence)

# Candidate sentences paired with their embeddings: all HF first, then TFLite.
print(f"\nCandidates: {candidateSentences}")
candidate_embs_hf = list(zip(candidateSentences, map(encode_hf, candidateSentences)))
candidate_embs_tf = list(zip(candidateSentences, map(encode_tflite, candidateSentences)))
# -----------------------------------------------------------
# VERIFY CONVERSION CORRECTNESS
# -----------------------------------------------------------
_section_rule = "=" * 80
print("\n" + _section_rule)
print("VERIFYING TFLITE CONVERSION CORRECTNESS")
print(_section_rule)

# The same sentence through both backends should give a similarity near 1.0.
similarity = cosine_similarity(target_emb_tflite, target_emb_hf)
print(f"\nTarget sentence embedding similarity (TFLite vs HF): {similarity:.6f}")

# Thresholds are checked from strictest to loosest; the first exceeded wins.
_verdicts = (
    (0.99, "✓ EXCELLENT: TFLite model conversion is highly accurate!"),
    (0.95, "✓ GOOD: TFLite model conversion is accurate (minor numerical differences)"),
    (0.90, "⚠ WARNING: TFLite model has some differences from original model"),
)
print(next(
    (message for cutoff, message in _verdicts if similarity > cutoff),
    "✗ ERROR: TFLite model outputs are significantly different from original model",
))

# Repeat the cross-backend check for every candidate sentence.
print("\nCandidate embeddings similarity:")
for (sent, tf_emb), (_, hf_emb) in zip(candidate_embs_tf, candidate_embs_hf):
    sim = cosine_similarity(tf_emb, hf_emb)
    print(f" '{sent}': {sim:.6f}")
# -----------------------------------------------------------
# SIMILARITY COMPARISON
# -----------------------------------------------------------
# Score every candidate against the target, once per backend, in the same order.
_score_sections = (
    ("SIMILARITY SCORES - HUGGINGFACE MODEL", target_emb_hf, candidate_embs_hf),
    ("SIMILARITY SCORES - TFLITE MODEL", target_emb_tflite, candidate_embs_tf),
)
for _heading, _target_vec, _candidates in _score_sections:
    print("\n" + "=" * 80)
    print(_heading)
    print("=" * 80)
    for sent, emb in _candidates:
        score = cosine_similarity(_target_vec, emb)
        print(f"\nTarget: \"{targetSentence}\"")
        print(f"Candidate: \"{sent}\"")
        print(f"Similarity Score: {score:.4f}")
        print("-" * 80)
# -----------------------------------------------------------
# SUMMARY
# -----------------------------------------------------------
# Emitted as one joined string; output is identical to one print per line.
_summary_lines = (
    "\n" + "=" * 80,
    "SUMMARY",
    "=" * 80,
    "\n✓ POST-PROCESSING APPLIED:",
    " - TFLite: Mean pooling with attention mask on last_hidden_state",
    " - HuggingFace: Mean pooling with attention mask on last_hidden_state",
    "\n✓ Both models now use the SAME pooling strategy",
    "\n✓ This is the standard approach for sentence-transformers models",
    "\n" + "=" * 80,
    "Completed.",
    "=" * 80,
)
print("\n".join(_summary_lines))