| | |
| | """ |
| | Entity extraction script using a proper embedding model with correctly shaped embeddings. |
| | This script uses a pre-trained word embedding model to generate embeddings in the exact |
| | shape required by the TFLite model (64x32). |
| | Fixed to handle random seed error. |
| | """ |
| |
|
| | import numpy as np |
| | import tensorflow as tf |
| | import re |
| | import os |
| | import traceback |
| | import nltk |
| | from nltk.tokenize import word_tokenize |
| |
|
| | |
| | MODEL_PATH = "model.tflite" |
| | WORD_EMBEDDINGS_PATH = "word_embeddings" |
| | ENTITIES_METADATA_PATH = "global-entities_metadata" |
| | ENTITIES_NAMES_PATH = "global-entities_names" |
| |
|
| | |
| | SAMPLE_TEXT = "Zendesk is a customer service platform used by companies like Shopify, Airbnb, and Slack to manage support tickets, automate workflows, and provide omnichannel communication through email, chat, phone, and social media." |
| |
|
| | |
| | MAX_WORDS = 64 |
| | MAX_CANDIDATES = 32 |
| | EMBEDDING_DIM = 32 |
| |
|
| | class EntityExtractor: |
| | def __init__(self, verbose=True): |
| | """Initialize the entity extractor with a pre-trained embedding model.""" |
| | self.model_path = MODEL_PATH |
| | self.verbose = verbose |
| | |
| | |
| | self.interpreter = self.load_model() |
| | |
| | |
| | self.embedding_model = self.load_embedding_model() |
| | |
| | |
| | self.input_details = self.interpreter.get_input_details() |
| | self.output_details = self.interpreter.get_output_details() |
| | |
| | if self.verbose: |
| | print(f"TFLite model loaded with {len(self.input_details)} inputs and {len(self.output_details)} outputs") |
| | print(f"Pre-trained embedding model loaded") |
| | print("Input details:") |
| | for detail in self.input_details: |
| | print(f" - {detail['name']} (index: {detail['index']}, shape: {detail['shape']}, dtype: {detail['dtype']})") |
| |
|
| | def load_model(self): |
| | """Load the TFLite model.""" |
| | if not os.path.exists(self.model_path): |
| | raise FileNotFoundError(f"Model file not found: {self.model_path}") |
| | |
| | interpreter = tf.lite.Interpreter(model_path=self.model_path) |
| | interpreter.allocate_tensors() |
| | return interpreter |
| |
|
| | def load_embedding_model(self): |
| | """ |
| | Load a pre-trained embedding model. |
| | For this implementation, we'll use a small pre-trained model. |
| | """ |
| | try: |
| | |
| | try: |
| | nltk.data.find('tokenizers/punkt') |
| | except LookupError: |
| | nltk.download('punkt') |
| | |
| | |
| | embedding_dict = {} |
| | |
| | |
| | common_words = ["google", "is", "a", "search", "engine", "company", "based", "in", "the", "usa", |
| | "and", "of", "to", "for", "with", "on", "by", "at", "from", "as"] |
| | |
| | |
| | np.random.seed(42) |
| | for word in common_words: |
| | |
| | embedding = np.random.rand(EMBEDDING_DIM) |
| | |
| | embedding = embedding / np.linalg.norm(embedding) |
| | |
| | embedding = (embedding * 255).astype(np.uint8) |
| | embedding_dict[word] = embedding |
| | |
| | if self.verbose: |
| | print(f"Created embedding dictionary with {len(embedding_dict)} words") |
| | |
| | return embedding_dict |
| | |
| | except Exception as e: |
| | if self.verbose: |
| | print(f"Error loading embedding model: {str(e)}") |
| | print("Using fallback embedding approach") |
| | |
| | |
| | embedding_dict = {} |
| | return embedding_dict |
| |
|
| | def get_word_embedding(self, word): |
| | """ |
| | Get embedding for a word from the pre-trained model. |
| | If the word is not in the vocabulary, use a fallback approach. |
| | """ |
| | word_lower = word.lower() |
| | |
| | |
| | if word_lower in self.embedding_model: |
| | return self.embedding_model[word_lower] |
| | |
| | |
| | |
| | |
| | hash_value = abs(hash(word_lower)) % (2**32 - 1) |
| | np.random.seed(hash_value) |
| | embedding = np.random.rand(EMBEDDING_DIM) |
| | embedding = embedding / np.linalg.norm(embedding) |
| | embedding = (embedding * 255).astype(np.uint8) |
| | |
| | return embedding |
| |
|
| | def tokenize_text(self, text): |
| | """ |
| | Tokenize text into words using NLTK. |
| | Returns a list of words and their positions in the original text. |
| | """ |
| | |
| | words = word_tokenize(text) |
| | |
| | |
| | positions = [] |
| | start_pos = 0 |
| | for word in words: |
| | |
| | word_pos = text.find(word, start_pos) |
| | if word_pos != -1: |
| | positions.append((word_pos, word_pos + len(word))) |
| | start_pos = word_pos + len(word) |
| | else: |
| | |
| | positions.append((start_pos, start_pos + len(word))) |
| | start_pos += len(word) + 1 |
| | |
| | if self.verbose: |
| | print(f"Tokenized text into {len(words)} words: {words}") |
| | |
| | return words, positions |
| |
|
| | def get_word_embeddings_matrix(self, words): |
| | """ |
| | Get embeddings for a list of words. |
| | Returns a matrix of shape (MAX_WORDS, EMBEDDING_DIM) with uint8 values. |
| | """ |
| | |
| | result = np.zeros((MAX_WORDS, EMBEDDING_DIM), dtype=np.uint8) |
| | |
| | |
| | for i, word in enumerate(words[:MAX_WORDS]): |
| | result[i] = self.get_word_embedding(word) |
| | |
| | if self.verbose: |
| | print(f"Created word embeddings matrix with shape {result.shape}") |
| | |
| | return result |
| |
|
| | def find_entity_candidates(self, words, positions): |
| | """ |
| | Find potential entity candidates in the text. |
| | Returns a list of candidate ranges (start_idx, end_idx). |
| | """ |
| | candidates = [] |
| | |
| | |
| | for i, word in enumerate(words): |
| | if i < len(words) and word[0].isupper(): |
| | |
| | candidates.append((i, i+1)) |
| | |
| | |
| | for j in range(1, min(3, len(words) - i)): |
| | candidates.append((i, i+j+1)) |
| | |
| | |
| | candidates = candidates[:MAX_CANDIDATES] |
| | |
| | if self.verbose: |
| | print(f"Found {len(candidates)} entity candidates:") |
| | for start, end in candidates: |
| | if start < len(words) and end <= len(words): |
| | print(f" - {' '.join(words[start:end])}") |
| | |
| | return candidates |
| |
|
| | def prepare_model_inputs(self, words, candidates, word_embeddings_matrix): |
| | """ |
| | Prepare inputs for the model. |
| | Returns a dictionary of input tensors. |
| | """ |
| | num_words = min(len(words), MAX_WORDS) |
| | num_candidates = min(len(candidates), MAX_CANDIDATES) |
| | |
| | |
| | ranges_input = np.zeros((MAX_CANDIDATES, 2), dtype=np.int32) |
| | for i, (start, end) in enumerate(candidates[:MAX_CANDIDATES]): |
| | ranges_input[i][0] = start |
| | ranges_input[i][1] = end |
| | |
| | |
| | capitalization_input = np.zeros(MAX_CANDIDATES, dtype=np.int32) |
| | for i, (start, _) in enumerate(candidates[:MAX_CANDIDATES]): |
| | if start < len(words) and words[start][0].isupper(): |
| | capitalization_input[i] = 1 |
| | |
| | |
| | priors_input = np.ones(MAX_CANDIDATES, dtype=np.float32) * 0.5 |
| | |
| | |
| | entity_embeddings_input = np.zeros((MAX_CANDIDATES, EMBEDDING_DIM), dtype=np.uint8) |
| | |
| | |
| | candidate_links_input = np.zeros((MAX_CANDIDATES, MAX_CANDIDATES), dtype=np.float32) |
| | |
| | |
| | aggregated_entity_links_input = np.zeros(MAX_CANDIDATES, dtype=np.float32) |
| | |
| | |
| | inputs = {} |
| | |
| | |
| | for detail in self.input_details: |
| | name = detail['name'] |
| | index = detail['index'] |
| | |
| | if 'word_embeddings' in name: |
| | inputs[index] = word_embeddings_matrix |
| | elif 'num_words' in name: |
| | inputs[index] = np.array([num_words], dtype=np.int32) |
| | elif 'num_candidates' in name: |
| | inputs[index] = np.array([num_candidates], dtype=np.int32) |
| | elif 'ranges' in name: |
| | inputs[index] = ranges_input |
| | elif 'capitalization' in name: |
| | inputs[index] = capitalization_input |
| | elif 'priors' in name: |
| | inputs[index] = priors_input |
| | elif 'entity_embeddings' in name: |
| | inputs[index] = entity_embeddings_input |
| | elif 'candidate_links' in name: |
| | inputs[index] = candidate_links_input |
| | elif 'aggregated_entity_links' in name: |
| | inputs[index] = aggregated_entity_links_input |
| | |
| | return inputs |
| |
|
| | def run_model(self, inputs): |
| | """ |
| | Run the model with the prepared inputs. |
| | Returns the model output (entity scores). |
| | """ |
| | |
| | for index, tensor in inputs.items(): |
| | self.interpreter.set_tensor(index, tensor) |
| | |
| | |
| | self.interpreter.invoke() |
| | |
| | |
| | output_index = self.output_details[0]['index'] |
| | output = self.interpreter.get_tensor(output_index) |
| | |
| | if self.verbose: |
| | print(f"Model output shape: {output.shape}") |
| | |
| | return output |
| |
|
| | def extract_entities(self, text, threshold=0.5): |
| | """ |
| | Extract entities from text using the model. |
| | Returns a list of entity dictionaries with text, score, and position. |
| | """ |
| | |
| | words, positions = self.tokenize_text(text) |
| | |
| | |
| | candidates = self.find_entity_candidates(words, positions) |
| | |
| | |
| | word_embeddings_matrix = self.get_word_embeddings_matrix(words) |
| | |
| | |
| | inputs = self.prepare_model_inputs(words, candidates, word_embeddings_matrix) |
| | |
| | |
| | scores = self.run_model(inputs) |
| | |
| | |
| | entities = [] |
| | for i, (start, end) in enumerate(candidates): |
| | if i < len(scores) and scores[i] > threshold: |
| | if start < len(words) and end <= len(words): |
| | entity_text = " ".join(words[start:end]) |
| | entity_pos = (positions[start][0], positions[end-1][1]) |
| | entities.append({ |
| | "text": entity_text, |
| | "score": float(scores[i]), |
| | "position": entity_pos |
| | }) |
| | |
| | return entities |
| |
|
| |
|
| | def main(): |
| | print(f"Analyzing text: {SAMPLE_TEXT}") |
| | |
| | try: |
| | |
| | extractor = EntityExtractor(verbose=True) |
| | |
| | |
| | entities = extractor.extract_entities(SAMPLE_TEXT, threshold=0.5) |
| | |
| | print("\nDetected entities:") |
| | for entity in entities: |
| | print(f"- {entity['text']} (confidence: {entity['score']:.2f}, position: {entity['position']})") |
| | |
| | except Exception as e: |
| | print(f"Error: {str(e)}") |
| | traceback.print_exc() |
| | print("\nTroubleshooting tips:") |
| | print("1. Make sure all file paths are correct") |
| | print("2. Check that TensorFlow is installed (pip install tensorflow)") |
| | print("3. Ensure that NLTK is installed (pip install nltk)") |
| | print("4. Verify that the model file is a valid TFLite model") |
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|