Spaces:
Build error
Build error
| import nltk | |
| from nltk.collocations import BigramAssocMeasures, BigramCollocationFinder | |
| from nltk.corpus import stopwords | |
| import spacy | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from datetime import datetime, timedelta | |
| import numpy as np | |
| import heapq | |
| from concurrent.futures import ThreadPoolExecutor | |
| from annoy import AnnoyIndex | |
| from transformers import pipeline | |
| from rank_bm25 import BM25Okapi | |
| from functools import partial | |
| import ssl | |
# Work around environments with a broken/missing certificate store so that
# nltk.download() can fetch corpora over HTTPS.
# NOTE(review): this disables TLS certificate verification process-wide —
# acceptable for a local script, but should not ship as-is.
try:
    _create_unverified_https_context = ssl._create_unverified_context
except AttributeError:
    # Older Pythons without the private helper: keep default verification.
    pass
else:
    ssl._create_default_https_context = _create_unverified_https_context
| import sys | |
| import subprocess | |
def download_spacy_model(model_name):
    """Install a spaCy model by running `python -m spacy download <model_name>`
    with the current interpreter; raises CalledProcessError on failure."""
    print(f"Downloading spaCy model: {model_name}")
    command = [sys.executable, "-m", "spacy", "download", model_name]
    subprocess.check_call(command)
    print(f"Model {model_name} downloaded successfully")
# Usage
# Load the spaCy English pipeline; on the first run the model may be absent,
# in which case it is downloaded and loading is retried once.
try:
    nlp = spacy.load('en_core_web_sm')
except OSError:
    # If the model is not found, download it
    download_spacy_model('en_core_web_sm')
    # Try loading again
    import spacy
    nlp = spacy.load('en_core_web_sm')
# Now you can use the model
print("spaCy model loaded successfully")
# Fetch the NLTK data used below (stop-word list and the punkt tokenizer).
nltk.download('stopwords')
nltk.download('punkt')
# Global call counter used by get_keywords for progress printing.
q = 0
def _english_stopwords():
    """Return the English stop-word set, computed once and memoized.

    The original recomputed stopwords.words('english') on every call to
    get_keywords, which is pure overhead on a hot path.
    """
    if not hasattr(_english_stopwords, '_cached'):
        _english_stopwords._cached = set(stopwords.words('english'))
    return _english_stopwords._cached


def get_keywords(text, cache):
    """Extract keywords from `text`: lowercased nouns/proper nouns/verbs
    (minus English stop words) plus the 10 highest-PMI bigrams.

    Results are memoized in `cache`, keyed by the raw text. Uses the
    module-level spaCy pipeline `nlp`. Prints a progress counter (global
    `q`) every 1000 calls.
    """
    global q
    if q % 1000 == 0:
        print(q)
    q += 1
    if text in cache:
        return cache[text]
    doc = nlp(text)
    stop_words = _english_stopwords()
    # POS filter and stop-word filter combined into one pass.
    keywords = [
        token.text.lower()
        for token in doc
        if token.pos_ in ('NOUN', 'PROPN', 'VERB')
        and token.text.lower() not in stop_words
    ]
    # Add the top-10 PMI bigrams computed over the raw token sequence.
    bigram_measures = BigramAssocMeasures()
    finder = BigramCollocationFinder.from_words([token.text for token in doc])
    keywords.extend(' '.join(bigram) for bigram in finder.nbest(bigram_measures.pmi, 10))
    cache[text] = keywords
    return keywords
def calculate_weight(message, sender_messages, cache):
    """Score `message` by how often its keywords occur in the sender's
    messages sent within 5 hours of it (in either direction).

    message: (sender, 'YYYY-mm-dd HH:MM:SS', text, tag) tuple
    sender_messages: numpy object array of such tuples
    cache: dict memoizing get_keywords results
    Returns an int weight (0 when sender_messages is empty).
    """
    if len(sender_messages) == 0:
        # Guard: subtracting a datetime from an empty float array raises.
        return 0
    message_time = datetime.strptime(message[1], '%Y-%m-%d %H:%M:%S')
    times = np.array([datetime.strptime(m[1], '%Y-%m-%d %H:%M:%S') for m in sender_messages])
    # Bug fix: the original applied np.abs to the *boolean* result of the
    # comparison (np.abs(delta <= limit)), which is a no-op on the mask and
    # silently admitted every message older than message_time. The absolute
    # value must be taken on the time delta itself.
    delta_seconds = np.abs((times - message_time).astype('timedelta64[s]').astype(int))
    recent_messages = sender_messages[delta_seconds <= 5 * 3600]
    recent_keywords = [get_keywords(m[2], cache) for m in recent_messages]
    return sum(
        sum(kws.count(keyword) for kws in recent_keywords)
        for keyword in get_keywords(message[2], cache)
    )
class ChatDatabase:
    """Chat history persisted as tab-separated lines:
    sender<TAB>time<TAB>text[<TAB>tag].

    Offers keyword-based retrieval (BM25 relevance + recency + tag boost),
    nearest-neighbour response prediction, and text generation on top of the
    in-memory log.
    """

    def __init__(self, filename):
        self.filename = filename
        self.messages = []           # list of (sender, time, text, tag) tuples
        self.messages_array = None   # numpy object-array view of self.messages
        self.sender_array = None     # column 0 of messages_array
        self.load_messages()
        self.index = None            # Annoy index, built externally
        self.tfidf = None            # fitted TfidfVectorizer, built externally

    def _rebuild_arrays(self):
        # Keep the numpy views in sync with self.messages.
        self.messages_array = np.array(self.messages, dtype=object)
        if len(self.messages_array) == 0:
            self.sender_array = np.array([], dtype=object)
        else:
            self.sender_array = self.messages_array[:, 0]

    def load_messages(self):
        """Read all messages from self.filename, creating the file if absent."""
        with open(self.filename, 'a'):
            pass  # touch the file so a fresh database starts empty
        with open(self.filename, 'r') as f:
            for line in f:
                parts = line.strip().split('\t')
                if len(parts) == 4:
                    sender, time, text, tag = parts
                elif len(parts) == 3:
                    sender, time, text = parts
                    tag = None
                else:
                    # Robustness fix: skip blank/malformed lines instead of
                    # raising on tuple unpacking as the original did.
                    continue
                self.messages.append((sender, time, text, tag))
        self._rebuild_arrays()
        print(f'Database loaded. Number of messages: {len(self.messages_array)}')

    def add_message(self, sender, time, text, tag=None):
        """Append one message to memory and to the backing file."""
        # Bug fix: the original np.append(self.messages_array, message, axis=0)
        # raises for a 1-D message against a 2-D array; rebuild the views from
        # the canonical list instead.
        self.messages.append((sender, time, text, tag))
        self._rebuild_arrays()
        with open(self.filename, 'a') as f:
            f.write(f'{sender}\t{time}\t{text}\t{tag}\n')

    def predict_response_separate(self, query, sender, cache):
        """Return the first other-sender message following the sender message
        most similar to `query`, or None.

        Requires self.tfidf and self.index to have been built externally.
        """
        if self.messages_array is None or len(self.messages_array) == 0:
            print("Error: messages_array is None")
            return None
        if self.tfidf is None or self.index is None:
            # Guard: the original dereferenced these unconditionally.
            print("Error: tfidf/index not built")
            return None
        sender_positions = np.where(self.sender_array == sender)[0]
        if len(sender_positions) == 0:
            print(f"No messages found for sender: {sender}")
            return None
        query_keywords = ' '.join(get_keywords(query, cache))
        query_vector = self.tfidf.transform([query_keywords]).toarray()[0]
        # Assumes the Annoy index was built over this sender's messages in
        # log order — TODO confirm against the (external) build step.
        relevant_indices = self.index.get_nns_by_vector(query_vector, 1)
        if not relevant_indices or relevant_indices[0] >= len(sender_positions):
            return None
        relevant_pos = sender_positions[relevant_indices[0]]
        # Bug fix: the original took the first other-sender message in the
        # whole log; the observed "response" is the first other-sender
        # message AFTER the matched message.
        later = np.where(self.sender_array[relevant_pos + 1:] != sender)[0]
        if len(later) == 0:
            return None
        return tuple(self.messages_array[relevant_pos + 1 + later[0]])

    def get_relevant_messages(self, sender, query, N, cache, query_tag=None, n_threads=30, tag_boost=1.5):
        """Return up to N of the sender's messages ranked by a blend of BM25
        keyword relevance (0.6), recency (0.2) and tag match (0.2, where a
        matching tag scores tag_boost instead of 1)."""
        if self.messages_array is None or len(self.messages_array) == 0:
            print("Error: messages_array is None")
            return []
        query_keywords = query.lower().split()
        # Filter by sender AND at least one query keyword in the text.
        # (The original computed a sender-only subset first and discarded it.)
        keyword_mask = np.array(
            [any(kw in text.lower() for kw in query_keywords)
             for text in self.messages_array[:, 2]]
        )
        sender_messages = self.messages_array[(self.sender_array == sender) & keyword_mask]
        if len(sender_messages) == 0:
            print(f"No messages found for sender: {sender} with the given keywords")
            return []
        print(len(sender_messages))

        def process_batch(batch, query_keywords, current_time, query_tag):
            # Score one batch: BM25 text relevance, 1/(1+age_days) recency,
            # and a flat boost for exact tag equality.
            batch_keywords = [get_keywords(message[2], cache) for message in batch]
            bm25 = BM25Okapi(batch_keywords)
            bm25_scores = bm25.get_scores(query_keywords)
            ages = (current_time - np.array(
                [datetime.strptime(m[1], '%Y-%m-%d %H:%M:%S') for m in batch]
            )).astype('timedelta64[D]').astype(int)
            # Clamp future-dated messages to age 0 so 1/(1+age) cannot
            # divide by zero or flip sign.
            time_scores = 1 / (1 + np.maximum(ages, 0))
            tag_scores = np.where(np.array([m[3] for m in batch]) == query_tag, tag_boost, 1)
            combined = 0.6 * np.array(bm25_scores) + 0.2 * time_scores + 0.2 * tag_scores
            return combined, batch

        current_time = datetime.now()
        batch_size = max(1, len(sender_messages) // n_threads)
        batches = [sender_messages[i:i + batch_size]
                   for i in range(0, len(sender_messages), batch_size)]
        with ThreadPoolExecutor(max_workers=n_threads) as executor:
            process_func = partial(process_batch, query_keywords=query_keywords,
                                   current_time=current_time, query_tag=query_tag)
            results = list(executor.map(process_func, batches))
        all_scores = np.concatenate([r[0] for r in results])
        all_messages = np.concatenate([r[1] for r in results])
        top_indices = np.argsort(all_scores)[-N:][::-1]
        return all_messages[top_indices].tolist()

    def generate_response(self, query, sender, cache, query_tag=None):
        """Generate a reply to `query` conditioned on the sender's 5 most
        relevant messages, using a GPT-Neo text-generation pipeline."""
        relevant_messages = self.get_relevant_messages(sender, query, 5, cache, query_tag)
        context = ' '.join(message[2] for message in relevant_messages)
        # Performance fix: building the HF pipeline (loading a 2.7B-parameter
        # model) on every call is very expensive; cache it on the instance.
        if getattr(self, '_generator', None) is None:
            self._generator = pipeline('text-generation', model='EleutherAI/gpt-neo-2.7B')
        response = self._generator(f'{context} {query}', max_length=100, do_sample=True)[0]['generated_text']
        response = response.split(query)[-1].strip()
        return response
| # Usage example remains the same | |
| ''' | |
| # Usage example | |
| db = ChatDatabase('memory.txt') | |
| # Example 1: Get relevant messages | |
| query = 'fisical' | |
| sender = 'Arcana' | |
| N = 10 | |
| cache = {} | |
| query_tag = None | |
| relevant_messages = db.get_relevant_messages(sender, query, N, cache, query_tag) | |
| print("Relevant messages:") | |
| for message in relevant_messages: | |
| print(f"Sender: {message[0]}, Time: {message[1]}, Tag: {message[3]}") | |
| print(f"Message: {message[2][:100]}...") | |
| print() | |
| # Example 2: Predict response (using the original method) | |
| query = "what was that?" | |
| sender = 'David' | |
| db.build_index_separate(cache) | |
| predicted_response = db.predict_response_separate(query, sender, cache) | |
| print("\nPredicted response:") | |
| if predicted_response is not None: | |
| print(f"Sender: {predicted_response[0]}, Time: {predicted_response[1]}, Tag: {predicted_response[3]}") | |
| print(f"Message: {predicted_response[2][:100]}...") | |
| else: | |
| print('No predicted response found') | |
| # Example 3: Generate response | |
| query = "Let's plan a trip" | |
| sender = 'Alice' | |
| query_tag = 'travel' | |
| generated_response = db.generate_response(query, sender, cache, query_tag) | |
| print("\nGenerated response:") | |
| print(generated_response) | |
| ''' | |