|
|
import os |
|
|
import json |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import logging |
|
|
from collections import Counter |
|
|
from sentence_transformers import SentenceTransformer |
|
|
import warnings |
|
|
from datetime import datetime |
|
|
from sklearn.preprocessing import normalize |
|
|
import requests |
|
|
import json |
|
|
import argparse |
|
|
from openai import OpenAI |
|
|
|
|
|
from scripts.scripts.sign2text_mapping import sign2text |
|
|
|
|
|
# Silence FutureWarnings emitted by third-party libraries (pandas / sklearn /
# sentence-transformers) so the console output stays readable.
warnings.filterwarnings("ignore", category=FutureWarning)

# Log everything (DEBUG and up) to a file that is overwritten on each run.
logging.basicConfig(
    filename='AulSign.log',
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s',
    filemode='w'  # 'w': start a fresh log file every run
)

# OpenAI client configured from environment variables, created at import time.
# NOTE(review): construction succeeds even if OPENAI_API_KEY is unset —
# failures surface only when an API call is made.
client = OpenAI(
    organization=os.getenv("OPENAI_ORGANIZATION"),
    project=os.getenv("OPENAI_PROJECT"),
    api_key=os.getenv("OPENAI_API_KEY")
)

print('Inference started...')
|
|
|
|
|
def query_ollama(messages, model="mistral:7b-instruct-fp16", timeout=120):
    """Send a chat request to a local Ollama server and return the reply text.

    Args:
        messages: List of {"role": ..., "content": ...} chat messages.
        model: Ollama model tag to query.
        timeout: Seconds to wait for the HTTP response (new parameter,
            default 120) so a hung server cannot block inference forever.

    Returns:
        The assistant message content on success, or an
        "Error: <status>, <body>" string on a non-200 response.
    """
    url = "http://localhost:11434/api/chat"

    # Fixed seed + low temperature for (near-)deterministic generations.
    options = {"seed": 42, "temperature": 0.1}

    payload = {
        "model": model,
        "messages": messages,
        "options": options,
        "stream": False,  # request the whole answer in a single JSON body
    }

    # BUG FIX: the original call had no timeout, so a stalled Ollama server
    # would hang this function (and the whole inference loop) indefinitely.
    response = requests.post(url, json=payload, timeout=timeout)

    if response.status_code == 200:
        return response.json()["message"]["content"]
    else:
        # Keep the original error-string contract (callers treat the return
        # value as plain text) instead of raising.
        return f"Error: {response.status_code}, {response.text}"
|
|
|
|
|
def check_repetition(text, threshold=0.2):
    """Heuristically detect degenerate (repetitive) model output.

    The answer is expected to be a '#'-separated token sequence. It is
    flagged as repetitive when the ratio of unique tokens to total tokens
    falls below ``threshold``, or when '<unk>' appears at all.

    Args:
        text: '#'-separated decomposition string (may be empty or None).
        threshold: Minimum unique/total ratio below which the text is
            considered repetitive.

    Returns:
        True if the text looks repetitive or contains '<unk>';
        False otherwise (including for empty/None input).
    """
    if not text:
        return False

    # BUG FIX: the original used `word.strip` (the bound method, never
    # called), so `words` held method objects — '<unk>' could never match
    # and the unique-token count compared method identities, not strings.
    words = [word.strip() for word in text.split('#')]

    unique_words = len(set(words))
    total_words = len(words)

    if "<unk>" in words:
        logging.debug(f"Check repetition: '<unk>' was generated in the answer")
        return True

    is_repetitive = unique_words < total_words * threshold
    logging.debug(f"Check repetition: {is_repetitive} (Unique: {unique_words}, Total: {total_words})")
    return is_repetitive
|
|
|
|
|
|
|
|
|
|
|
def prepare_dataset(prediction: pd.DataFrame, validation: pd.DataFrame, modality: str) -> pd.DataFrame:
    """Join model predictions with their gold references for evaluation.

    Args:
        prediction: Prediction rows; must contain 'sentence' for
            text2sign or 'gold_cd' for sign2text.
        validation: Gold data with 'fsw', 'symbol', 'word', 'sentence'
            columns (text2sign) or 'word', 'sentence' (sign2text).
        modality: Either 'text2sign' or 'sign2text'.

    Returns:
        The merged DataFrame with gold columns attached.

    Raises:
        ValueError: If `modality` is not one of the two supported modes.
    """
    if modality == 'text2sign':
        # Gold columns are renamed with a 'gold_' prefix, then joined on the
        # shared source sentence.
        validation = validation.rename(columns={'fsw': 'gold_fsw_seq', 'symbol': 'gold_symbol_seq', 'word': 'gold_cd'})
        metrics = prediction.merge(validation[['gold_symbol_seq', 'gold_cd', 'sentence', 'gold_fsw_seq']], on=['sentence'])
    elif modality == 'sign2text':
        # Here the canonical decomposition is the join key instead.
        validation = validation.rename(columns={'word': 'gold_cd'})
        metrics = prediction.merge(validation[['sentence', 'gold_cd']], on=['gold_cd'])
    else:
        # BUG FIX: an unknown modality previously fell through and raised an
        # opaque UnboundLocalError on `metrics`; fail fast with a clear message.
        raise ValueError(f"Unknown modality: {modality!r} (expected 'text2sign' or 'sign2text')")
    return metrics
|
|
|
|
|
|
|
|
def cos_sim(a, b):
    """Cosine similarity between two 1-D vectors."""
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    return np.dot(a, b) / denom
|
|
|
|
|
def find_most_similar_sentence(user_embedding, train_sentences: pd.DataFrame, n=3, unk_threshold=7):
    """Retrieve the `n` training examples most similar to the query embedding.

    Entries whose decomposition contains more than `unk_threshold` '<unk>'
    tokens have their similarity zeroed so they are effectively excluded.

    Returns:
        Two parallel lists (decompositions, sentences), most similar first.
    """
    emb_matrix = np.vstack(train_sentences["embedding_sentence"].values)
    decs = train_sentences["decomposition"].values
    sents = train_sentences["sentence"].values

    # L2-normalize rows so the dot product below equals cosine similarity.
    emb_matrix = normalize(emb_matrix, axis=1)
    query = normalize(user_embedding.reshape(1, -1), axis=1)

    sims = np.dot(emb_matrix, query.T).flatten()

    # Suppress training examples riddled with '<unk>' placeholders.
    unk_per_entry = np.array([dec.count("<unk>") for dec in decs])
    sims[unk_per_entry > unk_threshold] = 0

    # Indices of the n highest similarities, best first.
    best = np.argsort(sims)[-n:][::-1]

    return [decs[i] for i in best], [sents[i] for i in best]
|
|
|
|
|
|
|
|
def find_most_similar_canonical_entry(user_embedding, vocabulary: pd.DataFrame, n=30):
    """Return up to `n` distinct canonical vocabulary entries most similar
    to the query embedding.

    Each row's 'word' value is collapsed into a canonical label via
    get_most_freq(); duplicates are skipped so the result holds `n`
    *distinct* labels, ordered from most to least similar.

    Args:
        user_embedding: 1-D query embedding.
        vocabulary: DataFrame with 'embedding' and 'word' columns.
        n: Maximum number of distinct canonical entries to return.

    Returns:
        List of canonical-label strings, most similar first.
    """
    vocabulary_embeddings = np.vstack(vocabulary["embedding"].values)
    vocabulary_words = vocabulary["word"].values

    # L2-normalize so dot products are cosine similarities.
    vocabulary_embeddings = normalize(vocabulary_embeddings, axis=1)
    user_embedding = normalize(user_embedding.reshape(1, -1), axis=1)

    similarities = np.dot(vocabulary_embeddings, user_embedding.T).flatten()

    # Scan candidates from most to least similar.
    sorted_indices = np.argsort(similarities)[::-1]

    # CLEANUP: the original also collected the similarity scores into a
    # parallel list that was never used or returned; that dead code is removed.
    canonical_list = []
    for idx in sorted_indices:
        if len(canonical_list) >= n:
            break

        canonical_entry = get_most_freq(vocabulary_words[idx])

        # Deduplicate: keep only the first (highest-similarity) occurrence.
        if canonical_entry not in canonical_list:
            canonical_list.append(canonical_entry)

    return canonical_list
|
|
|
|
|
|
|
|
def get_most_freq(lista: list):
    """Collapse a list of surface forms into a canonical label.

    Forms are lowercased/stripped and de-duplicated in order of first
    appearance; the first two distinct forms are joined as 'a|b'
    (or just 'a' when only one distinct form exists).
    """
    seen = []
    for raw in lista:
        cleaned = raw.lower().strip()
        if cleaned not in seen:
            seen.append(cleaned)

    # After de-duplication every count is 1, so most_common() simply
    # yields the forms in first-appearance order.
    ranked = Counter(seen).most_common(2)

    if len(ranked) >= 2:
        return ranked[0][0] + '|' + ranked[1][0]
    else:
        return ranked[0][0]
|
|
|
|
|
def get_most_freq_fsw(lista_fsw):
    """Return the most frequent FSW string from a list, or pass a lone
    string through unchanged."""
    if isinstance(lista_fsw, str):
        return lista_fsw
    # most_common(1) -> [(value, count)]; ties break by first appearance.
    counts = Counter(lista_fsw)
    return counts.most_common(1)[0][0]
|
|
|
|
|
|
|
|
def get_fsw_exact(vocabulary: pd.DataFrame, can_desc_answer, model, top_k=10):
    """Map each canonical description to an FSW sign via embedding search.

    For every description in `can_desc_answer`, the `top_k` most similar
    vocabulary entries are retrieved; an exact canonical-label match is
    preferred (its similarity is treated as 1), otherwise the nearest
    neighbour is used.

    Args:
        vocabulary: DataFrame with 'embedding', 'word' and 'fsw' columns.
        can_desc_answer: Iterable of canonical-description strings.
        model: Sentence-embedding model exposing `.encode(...)`.
        top_k: Number of nearest vocabulary entries to inspect per description.

    Returns:
        Tuple of (space-joined FSW sequence, ' # '-joined canonical
        associations, geometric-mean similarity rounded to 3 decimals).
        Returns ('', '', 0.0) when `can_desc_answer` is empty.
    """
    vocabulary_embeddings = np.vstack(vocabulary["embedding"].values)
    vocabulary_words = vocabulary["word"].values
    vocabulary_fsw = vocabulary["fsw"].values

    # L2-normalize rows so dot products equal cosine similarity.
    vocabulary_embeddings = normalize(vocabulary_embeddings, axis=1)

    fsw_seq = []
    can_desc_association_seq = []
    joint_prob = 1

    for can_d in can_desc_answer:

        can_d_emb = model.encode(can_d, normalize_embeddings=True).reshape(1, -1)

        similarities = np.dot(vocabulary_embeddings, can_d_emb.T).flatten()

        # Highest-similarity candidates, best first.
        top_k_indices = np.argsort(similarities)[-top_k:][::-1]
        top_k_words = vocabulary_words[top_k_indices]
        top_k_fsws = vocabulary_fsw[top_k_indices]
        top_k_similarities = similarities[top_k_indices]

        # Prefer an exact canonical-label match among the candidates.
        exact_match_index = next((i for i, word in enumerate(top_k_words) if get_most_freq(word) == can_d.strip()), None)

        if exact_match_index is not None:
            most_similar_word = get_most_freq(top_k_words[exact_match_index])
            fsw = top_k_fsws[exact_match_index]
            max_similarity = 1  # exact matches are treated as certain
        else:
            # Fall back to the single nearest neighbour.
            max_index = 0
            most_similar_word = get_most_freq(top_k_words[max_index])
            fsw = top_k_fsws[max_index]
            max_similarity = top_k_similarities[max_index]

        logging.info(fsw)
        fsw_seq.append(get_most_freq_fsw(fsw))
        joint_prob *= max_similarity
        can_desc_association_seq.append(most_similar_word)

        logging.debug(f"Word: {can_d}")
        logging.debug(f"Most similar word in vocabulary: {most_similar_word}")
        logging.debug(f"Similarity: {max_similarity}")
        logging.debug(f"Fsw_seq: {' '.join(fsw_seq)}")
        logging.debug("---")

    # BUG FIX: guard against an empty answer, which previously caused a
    # ZeroDivisionError in the geometric-mean computation below.
    if not can_desc_association_seq:
        return '', '', 0.0

    # Geometric mean of the per-sign similarities.
    joint_prob = pow(joint_prob, 1 / len(can_desc_association_seq))

    return ' '.join(fsw_seq), ' # '.join(can_desc_association_seq), np.round(joint_prob, 3)
|
|
|
|
|
|
|
|
def AulSign(input:str, rules_prompt_path:str, train_sentences:pd.DataFrame, vocabulary:pd.DataFrame, model, ollama:bool, modality:str):
    """
    AulSign: A function for translating between text and Formal SignWriting (FSW) or vice versa.

    This function leverages embeddings, similarity matching, and language models to facilitate
    translations based on the specified modality (`text2sign` or `sign2text`).

    Args:
        input (str):
            The sentence or sign sequence to be analyzed and translated.
        rules_prompt_path (str):
            Path to a file containing predefined prompts and rules to guide the language model.
        train_sentences (pd.DataFrame):
            A dataset containing sentences and their embeddings for training or similarity matching.
        vocabulary (pd.DataFrame):
            A table of vocabulary entries with canonical descriptions and embeddings, used for matching.
        model:
            The embedding model used to convert sentences or sign sequences into vector representations.
        ollama (bool):
            Specifies whether to use the `query_ollama` method for querying the language model.
        modality (str):
            The translation mode:
            - `'text2sign'`: Converts text to Formal SignWriting sequences.
            - `'sign2text'`: Converts Formal SignWriting to textual sentences.

    Returns:
        For `modality == "text2sign"`:
            tuple:
                - answer (str):
                    The translated text or decomposition provided by the language model.
                - fsw (list):
                    A list of Formal SignWriting sequences associated with the translation.
                - can_desc_association_seq (list):
                    A list of canonical descriptions associated with the FSW sequences.
                - joint_prob (float):
                    The joint probability of the most likely translation path.

        For `modality == "sign2text"`:
            str:
                The reconstructed textual sentence translated from the input sign sequence.

        If an invalid modality is provided:
            str:
                Returns 'error' to indicate invalid input.

    Raises:
        Exception:
            Logs and raises errors encountered during API calls or message construction.
    """

    # Embed the input once; reused for both vocabulary and example retrieval.
    sent_embedding = model.encode(input, normalize_embeddings=True)

    if modality =='text2sign':

        # Retrieve a large pool of candidate canonical entries to expose to the LLM.
        similar_canonical = find_most_similar_canonical_entry(sent_embedding, vocabulary, n=100)

        similar_canonical_str = ' # '.join(similar_canonical)

        # The rules-prompt template embeds the candidate list via str.format.
        with open(rules_prompt_path, 'r') as file:
            rules_prompt = file.read().format(similar_canonical=similar_canonical_str)

        # Few-shot examples: most similar training sentences and their decompositions.
        decomposition, sentences = find_most_similar_sentence(
            user_embedding=sent_embedding,
            train_sentences=train_sentences,
            n=20
        )

        messages = [{"role": "system", "content": rules_prompt}]
        # NOTE(review): the loop variable rebinds `decomposition` (the list)
        # on each iteration; this works only because zip() captured the list
        # up front, but the shadowing is fragile.
        for sentence, decomposition in zip(sentences, decomposition):

            if sentence and decomposition:
                # Few-shot pair: sentence in, decomposition out.
                messages.append({"role": "user", "content": sentence})
                messages.append({"role": "assistant", "content": decomposition})
            else:
                logging.warning("Missing 'sentence' or 'decomposition' in messages.")

        messages.append({"role": "user", "content": "decompose the following sentence as shown in the previous examples"})
        messages.append({"role": "user", "content": input})

        # Structural sanity-check of the message list.
        # NOTE(review): `valid_messages` is built but `messages` (unfiltered)
        # is what actually gets sent to the model.
        valid_messages = []
        for message in messages:
            if 'role' in message and 'content' in message:
                valid_messages.append(message)
                logging.debug(message)
            else:
                logging.error(f"Invalid message format detected: {message}")

        if ollama:

            answer = query_ollama(messages)

            logging.info("\n[LOG] MISTRAL Answer:")
            logging.info(answer)

            # The model answers with '#'-separated canonical descriptions.
            can_description_answer = answer.split('#')
        else:
            try:

                # temperature=0 for reproducible completions.
                completion = client.chat.completions.create(
                    model="gpt-3.5-turbo",
                    messages=messages,
                    temperature=0
                )
                answer = completion.choices[0].message.content

                # Retry once with a presence penalty if the answer degenerated
                # into repetition (or produced '<unk>').
                if check_repetition(answer):

                    presence_penalty = 0.6
                    completion = client.chat.completions.create(
                        model="gpt-3.5-turbo",
                        messages=messages,
                        presence_penalty=presence_penalty,
                        temperature=0
                    )
                    logging.info(f"presence_penalty: {presence_penalty}")
                    answer = completion.choices[0].message.content
                    logging.info('ANSWER: GPT')
                    logging.info(answer + '\n\n')

                    can_description_answer = answer.split('#')

                else:
                    logging.info('ANSWER: GPT')
                    logging.info(answer + '\n\n')

                    can_description_answer = answer.split('#')

            except Exception as e:
                # NOTE(review): this only logs — if the API call fails,
                # `answer` and `can_description_answer` remain undefined and
                # the code below raises NameError.
                logging.error(f"Error during GPT API call: {e}")

        # Resolve each canonical description to its FSW sign.
        fsw, can_desc_association_seq, joint_prob = get_fsw_exact(
            vocabulary=vocabulary,
            can_desc_answer=can_description_answer,
            model=model
        )

        return answer, fsw, can_desc_association_seq, joint_prob

    elif modality =='sign2text':

        # sign2text uses the prompt file verbatim (no format placeholders).
        with open(rules_prompt_path, 'r') as file:
            rules_prompt = file.read()

        decomposition, sentences = find_most_similar_sentence(
            user_embedding=sent_embedding,
            train_sentences=train_sentences,
            n=30
        )

        messages = [{"role": "system", "content": rules_prompt}]
        # Few-shot pairs are reversed here: decomposition in, sentence out.
        for sentence, decomposition in zip(sentences, decomposition):

            if sentence and decomposition:
                messages.append({"role": "user", "content": decomposition})
                messages.append({"role": "assistant", "content": sentence})
            else:
                logging.warning("Missing 'sentence' or 'decomposition' in messages.")

        messages.append({"role": "user", "content": "reconstruct the sentence as shown on the examples above"})
        messages.append({"role": "user", "content": input})

        # Same structural check as the text2sign branch (result also unused).
        valid_messages = []
        for message in messages:
            if 'role' in message and 'content' in message:
                valid_messages.append(message)
                logging.debug(message)
            else:
                logging.error(f"Invalid message format detected: {message}")

        if ollama:

            answer = query_ollama(messages)

            logging.info("\n[LOG] MISTRAL Answer:")
            logging.info(answer)

            # NOTE(review): computed but unused in this branch.
            can_description_answer = answer.split('#')
        else:
            try:

                completion = client.chat.completions.create(
                    model="gpt-3.5-turbo",
                    messages=messages,
                    temperature=0
                )
                answer = completion.choices[0].message.content
                logging.info('ANSWER: GPT')
                logging.info(answer + '\n\n')

            except Exception as e:
                # NOTE(review): same caveat as above — a failed call leaves
                # `answer` undefined and the return below raises NameError.
                logging.error(f"Error during GPT API call: {e}")

        return answer

    else:
        return 'error'
|
|
|
|
|
|
|
|
def main(modality, setup, input=None):
    """Entry point: run AulSign on a single input or over the test split.

    Args:
        modality: 'text2sign' or 'sign2text'.
        setup: Dataset-configuration suffix used to locate preprocessed
            files; None selects the default filtered embeddings file.
        input: Optional single sentence / sign sequence. When omitted,
            the test CSV is translated in batch mode and results are
            written to result/<modality>_<timestamp>/.
    """
    np.random.seed(42)
    current_time = datetime.now().strftime("%Y_%m_%d_%H_%M")
    # NOTE(review): when setup is None this resolves to
    # 'data/preprocess_output_None/...'; it is only used in batch mode,
    # but the None case looks unintended — confirm against preprocessing.
    data_path = f"data/preprocess_output_{setup}/file_comparison"
    corpus_embeddings_path = 'tools/corpus_embeddings.json'
    if setup is None:
        sentences_train_embeddings_path = f"tools/sentences_train_embeddings_filtered_01.json"
    else:
        sentences_train_embeddings_path = f"tools/sentences_train_embeddings_{setup}.json"
    rules_prompt_path_text2sign = 'tools/rules_prompt_text2sign.txt'
    rules_prompt_path_sign2text = 'tools/rules_prompt_sign2text.txt'

    # Embedding model shared by all similarity lookups (downloaded on first use).
    model_name = "mixedbread-ai/mxbai-embed-large-v1"
    model = SentenceTransformer(model_name)

    # Precomputed vocabulary embeddings ('embedding', 'word', 'fsw' columns).
    with open(corpus_embeddings_path, 'r') as file:
        corpus_embeddings = pd.DataFrame(json.load(file))

    # Precomputed training-sentence embeddings for few-shot retrieval.
    with open(sentences_train_embeddings_path, 'r') as file:
        sentences_train_embeddings = pd.DataFrame(json.load(file))

    if input:
        # --- single-input (interactive) mode ---
        if modality == 'text2sign':
            answer, fsw_seq, can_desc_association_seq, joint_prob = AulSign(
                input=input,
                rules_prompt_path=rules_prompt_path_text2sign,
                train_sentences=sentences_train_embeddings,
                vocabulary=corpus_embeddings,
                model=model,
                ollama=False,
                modality=modality
            )

            print(f"Canonical Descriptions: {can_desc_association_seq}")
            print(f"Translation (FSW): {fsw_seq}")

        elif modality == 'sign2text':
            # Map raw FSW input onto vocabulary words before prompting the LLM.
            mapped_input = sign2text(input,corpus_embeddings_path)
            logging.info(f"\nReconstructed Sentence via Vocaboulary: {mapped_input}")
            answer= AulSign(
                input=mapped_input,
                rules_prompt_path=rules_prompt_path_sign2text,
                train_sentences=sentences_train_embeddings,
                vocabulary=corpus_embeddings,
                model=model,
                ollama=False,
                modality=modality
            )
            print(f"Input Sign Voucaboualry Mapping: {input}")
            print(f"Translation (Text): {answer}")

    else:
        # --- batch (evaluation) mode over the test split ---
        test_path = os.path.join(data_path, f"test.csv")
        test = pd.read_csv(test_path)
        # NOTE(review): only the first test row is processed — this looks
        # like a debugging leftover; drop head(1) to run the full test set.
        test = test.head(1)

        if modality == 'text2sign':
            list_sentence = []
            list_answer = []
            list_fsw_seq = []
            can_desc_association_list = []
            prob_of_association_list = []

            for index, row in test.iterrows():
                sentence = row['sentence']
                answer, fsw_seq, can_desc_association_seq, joint_prob = AulSign(
                    input=sentence,
                    rules_prompt_path=rules_prompt_path_text2sign,
                    train_sentences=sentences_train_embeddings,
                    vocabulary=corpus_embeddings,
                    model=model,
                    ollama=False,
                    modality=modality
                )

                list_sentence.append(sentence)
                list_answer.append(answer)
                list_fsw_seq.append(fsw_seq)
                can_desc_association_list.append(can_desc_association_seq)
                prob_of_association_list.append(joint_prob)

            df_pred = pd.DataFrame({
                'sentence': list_sentence,
                'pseudo_cd': list_answer,
                'pred_cd': can_desc_association_list,
                'joint_prob': prob_of_association_list,
                'pred_fsw_seq': list_fsw_seq
            })
            output_path = os.path.join('result', f"{modality}_{current_time}")
            os.makedirs(output_path, exist_ok=True)
            # Attach gold references before writing results to disk.
            df_pred = prepare_dataset(df_pred,test,modality)
            df_pred.to_csv(os.path.join(output_path, f'result_{current_time}.csv'), index=False)

        elif modality == 'sign2text':

            list_answer = []
            list_gold_cd = []

            for index, row in test.iterrows():
                # 'word' holds the gold canonical decomposition for this row.
                dec_sentence = row['word']
                answer = AulSign(
                    input=dec_sentence,
                    rules_prompt_path=rules_prompt_path_sign2text,
                    train_sentences=sentences_train_embeddings,
                    vocabulary=corpus_embeddings,
                    model=model,
                    ollama=False,
                    modality=modality
                )
                list_gold_cd.append(dec_sentence)
                list_answer.append(answer)

            df_pred = pd.DataFrame({
                'pseudo_sentence': list_answer,
                'gold_cd': list_gold_cd,
            })
            output_path = os.path.join('result', f"{modality}_{current_time}")
            os.makedirs(output_path, exist_ok=True)
            df_pred = prepare_dataset(df_pred,test,modality)
            df_pred.to_csv(os.path.join(output_path, f'result_{current_time}.csv'), index=False)
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: choose a direction and optionally pass a
    # single input; without --input the test split is processed in batch.
    cli = argparse.ArgumentParser()
    cli.add_argument("--mode", required=True, help="Mode of operation: text2sign or sign2text")
    cli.add_argument("--input", help="Input text or sign sequence")
    parsed = cli.parse_args()

    main(parsed.mode, setup=None, input=parsed.input)