File size: 23,689 Bytes
149cddd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
import os
import json
import numpy as np
import pandas as pd
import logging
from collections import Counter
from sentence_transformers import SentenceTransformer
import warnings
from datetime import datetime
from sklearn.preprocessing import normalize
import requests
import json
import argparse
from openai import OpenAI

from scripts.scripts.sign2text_mapping import sign2text

# Hide FutureWarning messages for the whole process.
warnings.filterwarnings("ignore", category=FutureWarning)


# Logging: every record (DEBUG and above) is written to AulSign.log; the file
# is truncated at the start of each run because filemode is 'w'.
_LOGGING_OPTIONS = {
    "filename": 'AulSign.log',
    "level": logging.DEBUG,
    "format": '%(asctime)s - %(levelname)s - %(message)s',
    "filemode": 'w',
}
logging.basicConfig(**_LOGGING_OPTIONS)


# OpenAI client used by the GPT code paths in AulSign. Credentials and account
# scoping come from the environment; unset variables are passed through as None.
_OPENAI_CREDENTIALS = {
    "organization": os.getenv("OPENAI_ORGANIZATION"),
    "project": os.getenv("OPENAI_PROJECT"),
    "api_key": os.getenv("OPENAI_API_KEY"),
}
client = OpenAI(**_OPENAI_CREDENTIALS)

print('Inference started...')

def query_ollama(messages, model="mistral:7b-instruct-fp16",
                 url="http://localhost:11434/api/chat",
                 temperature=0.1, seed=42, timeout=600):
    """Send a chat transcript to an Ollama server and return the reply text.

    Args:
        messages: List of ``{"role": ..., "content": ...}`` chat messages.
        model: Ollama model tag to query.
        url: Ollama ``/api/chat`` endpoint.
        temperature: Sampling temperature forwarded in ``options``.
        seed: Sampling seed forwarded in ``options`` (fixed for reproducibility).
        timeout: Seconds to wait for the HTTP request before giving up; without
            a timeout a stalled server would block the caller forever.

    Returns:
        str: The assistant message content on HTTP 200. For any other status
        code a string ``"Error: <status>, <body>"`` is returned instead of
        raising, so callers must not assume the result is a model answer.

    Raises:
        requests.exceptions.RequestException: on connection failures or when
            ``timeout`` elapses.
    """
    payload = {
        "model": model,
        "messages": messages,
        "options": {"seed": seed, "temperature": temperature},
        # Ask for a single JSON response instead of a token stream.
        "stream": False,
    }

    response = requests.post(url, json=payload, timeout=timeout)

    if response.status_code == 200:
        return response.json()["message"]["content"]
    return f"Error: {response.status_code}, {response.text}"

def check_repetition(text, threshold=0.2):
    """Return True when an LLM answer looks degenerate and should be regenerated.

    The answer is treated as a '#'-separated sequence of canonical
    descriptions. It is flagged when it contains the literal token ``<unk>``
    or when the number of distinct tokens is below ``threshold`` times the
    total number of tokens.

    Args:
        text: Raw answer string. Falsy values (``None``, ``""``) are never flagged.
        threshold: Minimum unique/total ratio for an answer to count as
            non-repetitive.

    Returns:
        bool: True if the answer is repetitive or contains ``<unk>``.
    """
    if not text:
        return False

    # strip() must actually be called: comparing un-called bound methods would
    # make every token look unique and hide '<unk>'. Empty tokens produced by
    # leading/trailing/doubled '#' are not descriptions, so they are dropped.
    words = [word.strip() for word in text.split('#')]
    words = [word for word in words if word]

    unique_words = len(set(words))
    total_words = len(words)

    if "<unk>" in words:
        logging.debug(f"Check repetition: '<unk>' was generated in the answer")
        return True

    is_repetitive = unique_words < total_words * threshold
    logging.debug(f"Check repetition: {is_repetitive} (Unique: {unique_words}, Total: {total_words})")
    return is_repetitive


# Function to merge predictions with gold data and compute metrics
def prepare_dataset(prediction: pd.DataFrame, validation: pd.DataFrame, modality:str):
    if modality=='text2sign':
        validation = validation.rename(columns={'fsw':'gold_fsw_seq','symbol': 'gold_symbol_seq', 'word': 'gold_cd'}) 
        metrics = prediction.merge(validation[['gold_symbol_seq','gold_cd', 'sentence','gold_fsw_seq']], on=['sentence'])
    elif modality=='sign2text':
        validation = validation.rename(columns={'word': 'gold_cd'}) 
        metrics = prediction.merge(validation[['sentence','gold_cd']], on=['gold_cd'])
    return metrics

# Define cosine similarity function if it's missing
def cos_sim(a, b):
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

def find_most_similar_sentence(user_embedding, train_sentences: pd.DataFrame, n=3, unk_threshold=7):
    """Retrieve the ``n`` training examples closest to ``user_embedding``.

    Similarity is cosine similarity: both the stored sentence embeddings and
    the query are L2-normalised, then compared with one matrix product.

    Args:
        user_embedding: Query embedding. It is reshaped to a single row, so a
            1-D vector is expected (TODO confirm for other callers).
        train_sentences: Frame with columns ``embedding_sentence`` (one vector
            per row), ``decomposition`` and ``sentence``. Decompositions are
            assumed to be strings, since ``str.count`` is called on them.
        n: Number of examples to return. ``n <= 0`` returns no examples.
        unk_threshold: Examples whose decomposition contains more than this
            many ``<unk>`` tokens are ranked below every other example.

    Returns:
        tuple[list, list]: ``(decompositions, sentences)`` of the top-``n``
        examples, ordered by decreasing similarity.
    """
    # Without this guard, the slice [-0:] below would select every row.
    if n <= 0:
        return [], []

    # Extract embeddings, decompositions and sentences from the DataFrame.
    sentence_embeddings = np.vstack(train_sentences["embedding_sentence"].values)
    decompositions = train_sentences["decomposition"].values
    sentences = train_sentences["sentence"].values

    # L2-normalise the sentence embeddings and the query so that the dot
    # product equals cosine similarity.
    sentence_embeddings = normalize(sentence_embeddings, axis=1)
    user_embedding = normalize(np.asarray(user_embedding).reshape(1, -1), axis=1)

    # One matrix-vector product gives the similarity to every sentence.
    similarities = np.dot(sentence_embeddings, user_embedding.T).flatten()  # Shape (num_sentences,)

    # Push examples with too many "<unk>" tokens to the bottom. -inf is used
    # instead of 0 because cosine similarity can be negative, and 0 would still
    # outrank those examples.
    unk_counts = np.array([d.count("<unk>") for d in decompositions])
    similarities[unk_counts > unk_threshold] = -np.inf

    # Indices of the top-n similarities, highest first.
    top_n_indices = np.argsort(similarities)[-n:][::-1]

    return [decompositions[i] for i in top_n_indices], [sentences[i] for i in top_n_indices]


def find_most_similar_canonical_entry(user_embedding, vocabulary: pd.DataFrame, n=30):
    """Return up to ``n`` distinct canonical descriptions closest to ``user_embedding``.

    Vocabulary rows are ranked by cosine similarity (both sides L2-normalised).
    Each row's ``word`` entry is collapsed with ``get_most_freq``, and a
    canonical string already collected is skipped. The walk stops once ``n``
    distinct entries have been gathered.

    Args:
        user_embedding: Query embedding, reshaped to a single row.
        vocabulary: Frame with ``embedding`` and ``word`` columns.
        n: Maximum number of canonical descriptions to return.

    Returns:
        list[str]: Canonical descriptions in order of decreasing similarity.
    """
    vocab_matrix = normalize(np.vstack(vocabulary["embedding"].values), axis=1)
    vocab_words = vocabulary["word"].values
    query = normalize(user_embedding.reshape(1, -1), axis=1)

    scores = np.dot(vocab_matrix, query.T).flatten()  # one score per vocabulary row

    picked = []
    for row in np.argsort(scores)[::-1]:  # best match first
        if len(picked) >= n:
            break
        candidate = get_most_freq(vocab_words[row])
        if candidate in picked:
            continue
        picked.append(candidate)

    return picked


def get_most_freq(lista:list):
    """Collapse a list of strings (e.g. a vocabulary ``word`` entry) into one canonical description.

    Every entry is lower-cased and stripped. Duplicates are removed while
    keeping first-seen order. The first two distinct values are joined with
    '|'; a single distinct value is returned on its own.

    NOTE(review): despite the name, no frequency ranking takes place, because
    duplicates are dropped before anything could be counted. The result is
    therefore the first two distinct entries in input order. Exact matching
    elsewhere in this module compares against this output, so the behavior is
    kept; confirm whether true frequency ranking was intended.

    Raises:
        IndexError: If ``lista`` is empty.
    """
    distinct = list(dict.fromkeys(entry.lower().strip() for entry in lista))
    if len(distinct) >= 2:
        return distinct[0] + '|' + distinct[1]
    return distinct[0]

def get_most_freq_fsw(lista_fsw):
    """Return the most common FSW string in ``lista_fsw``.

    A plain string is returned unchanged. For any other iterable, the element
    with the highest count wins, and ties go to the element seen first. An
    empty iterable raises IndexError.
    """
    if not isinstance(lista_fsw, str):
        ranking = Counter(lista_fsw).most_common(1)
        return ranking[0][0]
    return lista_fsw


def get_fsw_exact(vocabulary: pd.DataFrame, can_desc_answer, model, top_k=10):
    """Map each canonical description to an FSW sign from the vocabulary.

    For every description, the ``top_k`` vocabulary rows nearest in embedding
    space are inspected:

    * If one of them has a canonical form (``get_most_freq``) equal to the
      description, ignoring case and surrounding whitespace, that row is used
      with similarity 1.
    * Otherwise the nearest row is used with its cosine similarity.

    Args:
        vocabulary: Frame with ``embedding``, ``word`` and ``fsw`` columns.
        can_desc_answer: Iterable of canonical-description strings (e.g. the
            LLM answer split on '#').
        model: Embedding model exposing ``encode(text, normalize_embeddings=True)``.
        top_k: Number of nearest vocabulary rows searched for an exact match.

    Returns:
        tuple[str, str, float]: Three values:
        - the space-joined FSW sequence;
        - the ' # '-joined matched canonical descriptions;
        - the geometric mean of the per-token similarities, rounded to 3 decimals.
        Empty input returns ``('', '', 0.0)`` instead of dividing by zero.
    """
    descriptions = list(can_desc_answer)
    if not descriptions:
        logging.warning("get_fsw_exact received no canonical descriptions; returning an empty result.")
        return '', '', 0.0

    # L2-normalise the vocabulary so the dot product below is cosine similarity.
    vocabulary_embeddings = normalize(np.vstack(vocabulary["embedding"].values), axis=1)
    vocabulary_words = vocabulary["word"].values
    vocabulary_fsw = vocabulary["fsw"].values

    fsw_seq = []
    can_desc_association_seq = []
    joint_prob = 1

    for can_d in descriptions:
        can_d_emb = model.encode(can_d, normalize_embeddings=True).reshape(1, -1)  # Shape (1, embedding_dim)
        similarities = np.dot(vocabulary_embeddings, can_d_emb.T).flatten()  # Shape (vocabulary_size,)

        top_k_indices = np.argsort(similarities)[-top_k:][::-1]  # highest similarity first
        top_k_words = vocabulary_words[top_k_indices]
        top_k_fsws = vocabulary_fsw[top_k_indices]
        top_k_similarities = similarities[top_k_indices]

        # get_most_freq lower-cases vocabulary entries, so the LLM token must be
        # lower-cased too or capitalised answers could never match exactly.
        target = can_d.strip().lower()
        match_index = None
        most_similar_word = None
        for i, word in enumerate(top_k_words):
            canonical = get_most_freq(word)
            if canonical == target:
                match_index, most_similar_word = i, canonical
                break

        if match_index is not None:
            fsw = top_k_fsws[match_index]
            max_similarity = 1  # an exact match counts as full confidence
        else:
            # Fall back to the semantically closest vocabulary row.
            most_similar_word = get_most_freq(top_k_words[0])
            fsw = top_k_fsws[0]
            max_similarity = top_k_similarities[0]

        logging.info(fsw)
        fsw_seq.append(get_most_freq_fsw(fsw))
        joint_prob *= max_similarity
        can_desc_association_seq.append(most_similar_word)

        logging.debug(f"Word: {can_d}")
        logging.debug(f"Most similar word in vocabulary: {most_similar_word}")
        logging.debug(f"Similarity: {max_similarity}")
        logging.debug(f"Fsw_seq: {' '.join(fsw_seq)}")
        logging.debug("---")

    # Geometric mean of the per-token similarities.
    joint_prob = pow(joint_prob, 1 / len(can_desc_association_seq))

    return ' '.join(fsw_seq), ' # '.join(can_desc_association_seq), np.round(joint_prob, 3)

# Process input sentence through retrieval-augmented generation (RAG)
def AulSign(input:str, rules_prompt_path:str, train_sentences:pd.DataFrame, vocabulary:pd.DataFrame, model, ollama:bool, modality:str):
    """
AulSign: A function for translating between text and Formal SignWriting (FSW) or vice versa.

This function leverages embeddings, similarity matching, and language models to facilitate
translations based on the specified modality (`text2sign` or `sign2text`).

Args:
    input (str): 
        The sentence or sign sequence to be analyzed and translated.
    rules_prompt_path (str): 
        Path to a file containing predefined prompts and rules to guide the language model.
    train_sentences (pd.DataFrame): 
        A dataset containing sentences and their embeddings for training or similarity matching.
    vocabulary (pd.DataFrame): 
        A table of vocabulary entries with canonical descriptions and embeddings, used for matching.
    model: 
        The embedding model used to convert sentences or sign sequences into vector representations.
    ollama (bool): 
        Specifies whether to use the `query_ollama` method for querying the language model.
    modality (str): 
        The translation mode:
        - `'text2sign'`: Converts text to Formal SignWriting sequences.
        - `'sign2text'`: Converts Formal SignWriting to textual sentences.

Returns:
    For `modality == "text2sign"`:
        tuple:
            - answer (str): 
                The translated text or decomposition provided by the language model.
            - fsw (list): 
                A list of Formal SignWriting sequences associated with the translation.
            - can_desc_association_seq (list): 
                A list of canonical descriptions associated with the FSW sequences.
            - joint_prob (float): 
                The joint probability of the most likely translation path.

    For `modality == "sign2text"`:
        str: 
            The reconstructed textual sentence translated from the input sign sequence.

    If an invalid modality is provided:
        str: 
            Returns 'error' to indicate invalid input.

Raises:
    Exception: 
        Logs and raises errors encountered during API calls or message construction.
    """
   
    sent_embedding = model.encode(input, normalize_embeddings=True)

    if modality =='text2sign':
    
        similar_canonical = find_most_similar_canonical_entry(sent_embedding, vocabulary, n=100)
        #print(similar_canonical)

        
        similar_canonical_str = ' # '.join(similar_canonical)

        # Load the rules prompt from the file
        with open(rules_prompt_path, 'r') as file:
            rules_prompt = file.read().format(similar_canonical=similar_canonical_str)

        # Find the most similar sentences from training set
        decomposition, sentences = find_most_similar_sentence(
            user_embedding=sent_embedding, 
            train_sentences=train_sentences, 
            n=20
        )

        messages = [{"role": "system", "content": rules_prompt}]
        for sentence, decomposition in zip(sentences, decomposition):
            # Ensure each message has 'role' and 'content' keys
            if sentence and decomposition:
                messages.append({"role": "user", "content": sentence})
                messages.append({"role": "assistant", "content": decomposition})#.replace(' | ',' # ')})
            else:
                logging.warning("Missing 'sentence' or 'decomposition' in messages.")

        messages.append({"role": "user", "content": "decompose the following sentence as shown in the previous examples"})
        messages.append({"role": "user", "content": input})
        
        # Validate the constructed messages before converting to prompt text
        valid_messages = []
        for message in messages:
            if 'role' in message and 'content' in message:
                valid_messages.append(message)
                logging.debug(message)
            else:
                logging.error(f"Invalid message format detected: {message}")

        if ollama:
            # Query the LLM using query_ollama instead of llm_pipeline
            answer = query_ollama(messages)#, model="mistral:7b-instruct-fp16")

            logging.info("\n[LOG] MISTRAL Answer:")
            logging.info(answer)

            can_description_answer = answer.split('#')
        else:
            try:
                # Initial API call
                completion = client.chat.completions.create(
                    model="gpt-3.5-turbo",
                    messages=messages,
                    temperature=0
                )
                answer = completion.choices[0].message.content

                if check_repetition(answer):
                # Optional: Repetition check
                    presence_penalty = 0.6
                    completion = client.chat.completions.create(
                        model="gpt-3.5-turbo",
                        messages=messages,
                        presence_penalty=presence_penalty,
                        temperature=0
                    )
                    logging.info(f"presence_penalty: {presence_penalty}")
                    answer = completion.choices[0].message.content
                    logging.info('ANSWER: GPT')
                    logging.info(answer + '\n\n')

                    # Update parsed answer
                    can_description_answer = answer.split('#')
                    
                else:
                    logging.info('ANSWER: GPT')
                    logging.info(answer + '\n\n')

                    # Split for further processing
                    can_description_answer = answer.split('#')


            except Exception as e:
                logging.error(f"Error during GPT API call: {e}")

        # Map canonical descriptions to most similar words in vocabulary
        fsw, can_desc_association_seq, joint_prob = get_fsw_exact(
            vocabulary=vocabulary, 
            can_desc_answer=can_description_answer, 
            model=model
        )

        return answer, fsw, can_desc_association_seq, joint_prob
    
    elif modality =='sign2text':

       # Load the rules prompt from the file
        with open(rules_prompt_path, 'r') as file:
            rules_prompt = file.read()


        # Find the most similar sentences from training set
        decomposition, sentences = find_most_similar_sentence(
            user_embedding=sent_embedding, 
            train_sentences=train_sentences, 
            n=30
        )

        messages = [{"role": "system", "content": rules_prompt}]
        for sentence, decomposition in zip(sentences, decomposition):
            # Ensure each message has 'role' and 'content' keys
            if sentence and decomposition:
                messages.append({"role": "user", "content": decomposition})
                messages.append({"role": "assistant", "content": sentence}) # qui stiamo invertendo il task! dalla decomposition vogliamo che l'assistant ci dia la sentence
            else:
                logging.warning("Missing 'sentence' or 'decomposition' in messages.")

        messages.append({"role": "user", "content": "reconstruct the sentence as shown on the examples above"})
        messages.append({"role": "user", "content": input})
        
        # Validate the constructed messages before converting to prompt text
        valid_messages = []
        for message in messages:
            if 'role' in message and 'content' in message:
                valid_messages.append(message)
                logging.debug(message)
            else:
                logging.error(f"Invalid message format detected: {message}")

        if ollama:
            # Query the LLM using query_ollama instead of llm_pipeline
            answer = query_ollama(messages)#, model="mistral:7b-instruct-fp16")

            logging.info("\n[LOG] MISTRAL Answer:")
            logging.info(answer)

            can_description_answer = answer.split('#')
        else:
            try:
                # Initial API call
                completion = client.chat.completions.create(
                    model="gpt-3.5-turbo",
                    messages=messages,
                    temperature=0
                )
                answer = completion.choices[0].message.content
                logging.info('ANSWER: GPT')
                logging.info(answer + '\n\n')


            except Exception as e:
                logging.error(f"Error during GPT API call: {e}")

        return answer
    else:
        return 'error'
    

def main(modality, setup, input=None, ollama=False, max_rows=None):
    """Translate a single input, or the whole test split, and report the results.

    Args:
        modality: ``'text2sign'`` or ``'sign2text'``.
        setup: Dataset-variant suffix used for both the data directory and the
            training-embedding file. ``None`` selects ``'filtered_01'`` for both,
            so the two paths stay consistent.
        input: Optional custom sentence (text2sign) or FSW sequence
            (sign2text). When given, the translation is printed instead of
            processing the test split.
        ollama: Query the local Ollama model instead of the OpenAI API.
        max_rows: Optional cap on the number of test rows to translate;
            ``None`` processes the whole split.

    Raises:
        ValueError: If ``modality`` is not supported.
    """
    if modality not in ('text2sign', 'sign2text'):
        raise ValueError(f"Unsupported modality {modality!r}; expected 'text2sign' or 'sign2text'")

    np.random.seed(42)
    if setup is None:
        setup = "filtered_01"
    current_time = datetime.now().strftime("%Y_%m_%d_%H_%M")
    data_path = f"data/preprocess_output_{setup}/file_comparison"
    corpus_embeddings_path = 'tools/corpus_embeddings.json'
    sentences_train_embeddings_path = f"tools/sentences_train_embeddings_{setup}.json"
    rules_prompt_path = {
        'text2sign': 'tools/rules_prompt_text2sign.txt',
        'sign2text': 'tools/rules_prompt_sign2text.txt',
    }[modality]

    # Model used for sentence embeddings.
    model_name = "mixedbread-ai/mxbai-embed-large-v1"
    model = SentenceTransformer(model_name)

    with open(corpus_embeddings_path, 'r', encoding='utf-8') as file:
        corpus_embeddings = pd.DataFrame(json.load(file))
    with open(sentences_train_embeddings_path, 'r', encoding='utf-8') as file:
        sentences_train_embeddings = pd.DataFrame(json.load(file))

    def translate(text):
        # Run AulSign on one input with the resources loaded above.
        return AulSign(
            input=text,
            rules_prompt_path=rules_prompt_path,
            train_sentences=sentences_train_embeddings,
            vocabulary=corpus_embeddings,
            model=model,
            ollama=ollama,
            modality=modality
        )

    if input:  # A custom input was provided.
        if modality == 'text2sign':
            _answer, fsw_seq, can_desc_association_seq, _joint_prob = translate(input)
            print(f"Canonical Descriptions: {can_desc_association_seq}")
            print(f"Translation (FSW): {fsw_seq}")
        else:
            # The input is an FSW sequence; map it to canonical descriptions first.
            mapped_input = sign2text(input, corpus_embeddings_path)
            logging.info(f"\nReconstructed Sentence via Vocaboulary: {mapped_input}")
            answer = translate(mapped_input)
            print(f"Input Sign Vocabulary Mapping: {mapped_input}")
            print(f"Translation (Text): {answer}")
        return

    # Standard flow over the test split.
    test = pd.read_csv(os.path.join(data_path, "test.csv"))
    if max_rows is not None:
        test = test.head(max_rows)

    if modality == 'text2sign':
        sentences = list(test['sentence'])
        results = [translate(sentence) for sentence in sentences]
        df_pred = pd.DataFrame({
            'sentence': sentences,
            'pseudo_cd': [r[0] for r in results],
            'pred_cd': [r[2] for r in results],
            'joint_prob': [r[3] for r in results],
            'pred_fsw_seq': [r[1] for r in results]
        })
    else:
        gold_cds = list(test['word'])
        df_pred = pd.DataFrame({
            'pseudo_sentence': [translate(cd) for cd in gold_cds],
            'gold_cd': gold_cds,
        })

    output_path = os.path.join('result', f"{modality}_{current_time}")
    os.makedirs(output_path, exist_ok=True)
    df_pred = prepare_dataset(df_pred, test, modality)
    df_pred.to_csv(os.path.join(output_path, f'result_{current_time}.csv'), index=False)

if __name__ == "__main__":
    # Command-line entry point, e.g.:
    #   python <this_script> --mode text2sign --input "This is a new ASL translator"
    parser = argparse.ArgumentParser(
        description="AulSign: translate between text and Formal SignWriting (FSW)."
    )
    parser.add_argument(
        "--mode",
        required=True,
        choices=["text2sign", "sign2text"],  # reject typos up front
        help="Mode of operation: text2sign or sign2text",
    )
    parser.add_argument("--input", help="Input text or sign sequence")
    parser.add_argument(
        "--setup",
        default=None,
        help="Optional dataset setup suffix used to locate the data and embedding files",
    )
    args = parser.parse_args()

    main(args.mode, setup=args.setup, input=args.input)