In [33]:
import dspy
from dspy.teleprompt import MIPROv2
from typing import List, Dict
import json
import numpy as np
import os
import random
from tqdm import tqdm

In [31]:
# Sample labelled PLDL examples from every parsed JSON file and reshape them into
# {"input": {...}, "evaluation": ...} datapoints used by the DSPy cells below.
SAMPLES_PER_FILE = 20
SEED = 42
rng = random.Random(SEED)  # dedicated RNG so the sample is reproducible across restarts

data_path = "/root/notebooks/MT_TQ/TQ/DataPrep_Prompting_Experiments/labeled_data/parsed/"
# os.walk order is filesystem-dependent; sort so the seeded sample is stable.
json_files = sorted(
    os.path.join(root, file)
    for root, _, files in os.walk(data_path)
    for file in files
    if file.endswith('.json') and 'PLDL' in file
)

training_samples = []
for json_file in tqdm(json_files):
    # Explicit UTF-8: source/target texts are multilingual.
    with open(json_file, 'r', encoding='utf-8') as file:
        data = json.load(file)
    items = data["data"]
    # random.sample raises ValueError when asked for more items than exist.
    training_samples.extend(rng.sample(items, min(SAMPLES_PER_FILE, len(items))))

# Datapoint input key -> field name in the parsed JSON records (order is preserved).
INPUT_FIELDS = {
    "src_text": "main_src_text",
    "tgt_text": "tgt_text",
    "src_prev": "tt_src_prev",
    "src_next": "tt_src_next",
    "tgt_prev": "tt_tgt_prev",
    "tgt_next": "tt_tgt_next",
    "src_lang": "src_lang",
    "tgt_lang": "tgt_lang",
}

datapoints = [
    {
        "input": {key: sample[field] for key, field in INPUT_FIELDS.items()},
        # Only the first labeler's annotation is used as the reference label.
        "evaluation": sample["labelers"][0]["annotation"],
    }
    for sample in training_samples
]

# Show one example datapoint (None when nothing was loaded).
datapoint = datapoints[-1] if datapoints else None
datapoint

100%|██████████| 4/4 [00:00<00:00, 77.75it/s]


{'input': {'src_text': 'Ma io che ne so, comandà? Io stavo a casa di mia madre, lo sapete.\n\nLo so.',
  'tgt_text': "What do I know, Commander? I was at my mom's house, you know it.\n\nI knows.",
  'src_prev': "Questa è una linea. Qua faccio quello che voglio, è terra mia, la legge è mia. Dall'altro lato c'è un mondo fatto di spazzatura. Questa linea non l'ho mai oltrepassata. Impara chi è tua madre una volta per tutte. Tieni, questo era per te. Mà… Mà! Secondo me non è stata lei. Come al solito ti sei fatto prendere per il culo. Comandà, credo che non è stata lei. Carmine, sei uno stronzo. Robè, portalo via. Andiamocene.",
  'src_next': 'E allora che altro vi devo dire? Tu non devi dire niente. Devi tenere la bocca chiusa. E non dire a nessuno quello che ti ho detto. Ma a nessuno però. Ho capito. Però devi tenere le orecchie aperte e ascoltare tutto quello che si dice qua dentro. Perché prima o poi, chi fa queste cose parla. Si deve atteggiare, si deve fare grosso. Che si è divertito

In [35]:
class TranslationQualityChecker(dspy.Signature):
    """Evaluate the quality of translation."""
    # NOTE(review): DSPy presumably renders this docstring and the field `desc`
    # strings into the LM prompt, so editing them likely changes model output.
    
    # Single input; TranslationQualityModule.forward passes a dict holding the
    # source/target segment, their previous/next context and both languages.
    context = dspy.InputField(desc="Source and target text with context")
    # The LM's answer, read back by TranslationQualityModule as `prediction.evaluation`.
    evaluation = dspy.OutputField(desc="Detailed evaluation of the translation quality")

class TranslationQualityModule(dspy.Module):
    """Run a ``dspy.Predict`` over ``TranslationQualityChecker`` for one segment.

    ``forward`` bundles the segment, its neighbouring context and the language
    pair into a single ``context`` dict, and returns only the predicted
    ``evaluation`` field (not the full prediction object).
    """

    def __init__(self):
        super().__init__()
        self.checker = dspy.Predict(TranslationQualityChecker)

    def forward(self, src_text, tgt_text, src_prev, tgt_prev, src_next, tgt_next, src_lang, tgt_lang):
        # Keep the key order stable: the dict is presumably stringified into the prompt.
        bundled = dict(
            source_text=src_text,
            target_text=tgt_text,
            source_previous=src_prev,
            target_previous=tgt_prev,
            source_next=src_next,
            target_next=tgt_next,
            source_language=src_lang,
            target_language=tgt_lang,
        )
        return self.checker(context=bundled).evaluation

# Custom DSPy backend wrapping the model object passed in (the "Netflix model").
class NetflixBackend(dspy.BackendBase):
    """Adapter that forwards each prompt to ``model.generate`` as one user message.

    NOTE(review): this cell's recorded output is an AttributeError raised on
    ``dspy``. Confirm that the installed DSPy version really exposes
    ``BackendBase`` and calls ``complete``/``completions``. Recent DSPy releases
    appear to expect custom LMs to subclass ``dspy.LM``/``dspy.BaseLM`` instead.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def complete(self, prompt, **kwargs):
        """Return ``model.generate``'s raw response for one prompt; ``kwargs`` are ignored."""
        return self.model.generate([{"role": "user", "content": prompt}])

    def completions(self, prompts, **kwargs):
        """Call ``complete`` on each prompt in order and collect the responses."""
        results = []
        for prompt in prompts:
            results.append(self.complete(prompt, **kwargs))
        return results

# Prepare training data
def prepare_training_data(data_points):
    """Convert notebook datapoints into ``dspy.Example`` objects.

    Args:
        data_points: iterable of ``{"input": {...}, "evaluation": ...}`` dicts, as
            built by the data-loading cell. ``input`` must contain src_text,
            tgt_text, src_prev, tgt_prev, src_next, tgt_next, src_lang and tgt_lang.

    Returns:
        List of ``dspy.Example`` with a ``context`` dict and an ``evaluation``
        label. ``context`` is declared as the input field so that DSPy
        optimizers/evaluators can call ``example.inputs()`` (which raises when no
        inputs have been declared).

    NOTE(review): TranslationQualityModule.forward takes src_text/tgt_text/...
    keyword arguments rather than ``context``, so these examples cannot be passed
    to it directly as ``**example.inputs()``. Confirm how they are consumed.
    """
    # (context key, datapoint input key), in the order the context dict is built.
    field_map = (
        ("source_text", "src_text"),
        ("target_text", "tgt_text"),
        ("source_previous", "src_prev"),
        ("target_previous", "tgt_prev"),
        ("source_next", "src_next"),
        ("target_next", "tgt_next"),
        ("source_language", "src_lang"),
        ("target_language", "tgt_lang"),
    )
    compiled_data = []
    for dp in data_points:
        input_data = dp['input']
        context = {ctx_key: input_data[in_key] for ctx_key, in_key in field_map}
        example = dspy.Example(context=context, evaluation=dp['evaluation'])
        compiled_data.append(example.with_inputs("context"))
    return compiled_data

def optimize_prompt(model, training_data, validation_data):
    """Optimise ``TranslationQualityModule``'s prompt with MIPROv2.

    Args:
        model: object exposing ``generate(messages)``. It is wrapped in
            ``NetflixBackend`` and installed as the global DSPy LM (side effect on
            ``dspy.settings``).
        training_data: list of ``dspy.Example`` used for bootstrapping and proposals.
        validation_data: list of ``dspy.Example`` used to score candidate prompts.

    Returns:
        The compiled (optimised) module.
    """
    backend = NetflixBackend(model)
    dspy.settings.configure(lm=backend)

    def evaluation_exact_match(example, pred, trace=None):
        """DSPy metric: exact (whitespace-trimmed) match of the evaluation text."""
        # TranslationQualityModule returns the evaluation itself, not a Prediction,
        # so accept either form.
        predicted = getattr(pred, "evaluation", pred)
        return str(predicted).strip() == str(example.evaluation).strip()

    # MIPROv2 requires a callable metric(example, pred, trace=None). Examples here
    # carry `evaluation`, not `answer`, so dspy.evaluate.answer_exact_match does not apply.
    # NOTE(review): the search budget is set through MIPROv2's own parameters
    # (e.g. `auto`, `num_candidates`, `num_trials`), which differ across DSPy
    # versions; tune them for the installed release.
    optimizer = MIPROv2(metric=evaluation_exact_match, init_temperature=0.7)

    translation_module = TranslationQualityModule()

    # DSPy teleprompters optimise via compile(student, trainset=..., valset=...).
    return optimizer.compile(
        translation_module,
        trainset=training_data,
        valset=validation_data,
    )

AttributeError: module 'dspy' has no attribute 'Predictor'

In [None]:
class TranslationQualityAssessor(dspy.Module):
    """Chain-of-thought assessor for a single translated segment.

    ``forward`` accepts either raw fields (languages, source text, translation
    and optional neighbouring segments) or the already formatted ``context`` /
    ``source`` / ``translation`` strings produced by ``prepare_dataset``. The
    latter is how DSPy optimizers call the module (``**example.inputs()``).

    NOTE(review): ``TranslationQualitySignature`` is not defined in this section.
    It is assumed to declare ``context``, ``source`` and ``translation`` inputs
    and an ``evaluation`` output — confirm.
    """

    def __init__(self):
        super().__init__()
        self.assess = dspy.ChainOfThought(TranslationQualitySignature)

    @staticmethod
    def _format_context(src_prev, tgt_prev, src_next, tgt_next):
        """Render the neighbouring segments as a plain-text block (no stray indentation)."""
        return (
            "Previous Context:\n"
            f"Source: {src_prev}\n"
            f"Translation: {tgt_prev}\n"
            "\n"
            "Next Context:\n"
            f"Source: {src_next}\n"
            f"Translation: {tgt_next}"
        )

    def forward(self, src_lang=None, tgt_lang=None, src_text=None, translation=None,
                src_prev="", tgt_prev="", src_next="", tgt_next="",
                context=None, source=None):
        """Assess one translation and return the predicted ``evaluation``.

        Raw mode (``source`` is None): ``src_lang``, ``tgt_lang``, ``src_text`` and
        ``translation`` are required. Context, source and translation strings are
        built from them.

        Pre-formatted mode (``source`` given): ``context``, ``source`` and
        ``translation`` are passed through unchanged. A missing ``context`` is
        treated as "".

        Raises:
            TypeError: if required arguments for the chosen mode are missing.
        """
        if source is None:
            required = (("src_lang", src_lang), ("tgt_lang", tgt_lang),
                        ("src_text", src_text), ("translation", translation))
            missing = [name for name, value in required if value is None]
            if missing:
                raise TypeError(f"forward() missing required argument(s): {', '.join(missing)}")
            context = self._format_context(src_prev, tgt_prev, src_next, tgt_next)
            source = f"Source ({src_lang}): {src_text}"
            translation = f"Translation ({tgt_lang}): {translation}"
        else:
            if translation is None:
                raise TypeError("forward() missing required argument: translation")
            if context is None:
                context = ""

        result = self.assess(context=context, source=source, translation=translation)
        return result.evaluation

class TranslationMetrics:
    """Score a predicted translation-quality evaluation against a reference.

    ``pred`` and ``gold`` may each be a JSON string (as stored by
    ``prepare_dataset`` via ``json.dumps``) or an already-parsed dict. Compared
    fields: ``Accuracy Score``, ``Readability Score``, ``Accuracy Issues`` and
    ``Readability Issues``.
    """

    # Normaliser for the summed absolute score differences in partial_match_score.
    # NOTE(review): 8 presumably means each score spans a range of 4 — confirm scale.
    MAX_TOTAL_SCORE_DIFF = 8
    SCORE_WEIGHT = 0.6
    ISSUES_WEIGHT = 0.4

    @staticmethod
    def _as_dict(value):
        """Parse ``value`` as a JSON object if it is text; return dicts unchanged.

        Raises:
            ValueError: malformed JSON.
            TypeError: the value is not (or does not decode to) a dict.
        """
        if isinstance(value, (str, bytes, bytearray)):
            value = json.loads(value)
        if not isinstance(value, dict):
            raise TypeError(f"expected a JSON object, got {type(value).__name__}")
        return value

    @staticmethod
    def _issue_set(evaluation, key):
        """Stringified issues listed under ``key`` (missing/None -> empty set)."""
        return {str(issue) for issue in (evaluation.get(key) or [])}

    @staticmethod
    def _jaccard(a, b):
        """Jaccard similarity; two empty sets agree perfectly and score 1.0."""
        union = a | b
        if not union:
            return 1.0
        return len(a & b) / len(union)

    @staticmethod
    def exact_match_score(pred, gold):
        """Return True iff both the accuracy and readability scores match.

        Scores are compared as strings. Returns False when either side cannot be
        parsed as a JSON object.
        """
        try:
            pred_json = TranslationMetrics._as_dict(pred)
            gold_json = TranslationMetrics._as_dict(gold)
        except (ValueError, TypeError):
            return False
        return all(
            str(pred_json.get(key)) == str(gold_json.get(key))
            for key in ('Accuracy Score', 'Readability Score')
        )

    @staticmethod
    def partial_match_score(pred, gold):
        """Blend score closeness and issue overlap into a value in [0, 1].

        The blend is 0.6 * (1 - summed |score diff| / 8) plus 0.4 * (mean Jaccard
        similarity of the accuracy and readability issue sets), floored at 0.
        Missing scores default to 0. Returns 0 when either side is unparseable or a
        score is not numeric.
        """
        try:
            pred_json = TranslationMetrics._as_dict(pred)
            gold_json = TranslationMetrics._as_dict(gold)

            accuracy_diff = abs(float(pred_json.get('Accuracy Score', 0)) - float(gold_json.get('Accuracy Score', 0)))
            readability_diff = abs(float(pred_json.get('Readability Score', 0)) - float(gold_json.get('Readability Score', 0)))

            accuracy_issues_sim = TranslationMetrics._jaccard(
                TranslationMetrics._issue_set(pred_json, 'Accuracy Issues'),
                TranslationMetrics._issue_set(gold_json, 'Accuracy Issues'),
            )
            readability_issues_sim = TranslationMetrics._jaccard(
                TranslationMetrics._issue_set(pred_json, 'Readability Issues'),
                TranslationMetrics._issue_set(gold_json, 'Readability Issues'),
            )
        except (ValueError, TypeError):
            return 0

        score_component = 1 - (accuracy_diff + readability_diff) / TranslationMetrics.MAX_TOTAL_SCORE_DIFF
        issues_component = (accuracy_issues_sim + readability_issues_sim) / 2

        final_score = (TranslationMetrics.SCORE_WEIGHT * score_component
                       + TranslationMetrics.ISSUES_WEIGHT * issues_component)
        return max(0, final_score)

def _example_from_item(item):
    """Build one ``dspy.Example`` (inputs: context, source, translation) from a raw item."""
    # Built without source-code indentation so no stray whitespace leaks into the prompt.
    context = (
        "Previous Context:\n"
        f"Source: {item['src_prev']}\n"
        f"Translation: {item['tgt_prev']}\n"
        "\n"
        "Next Context:\n"
        f"Source: {item['src_next']}\n"
        f"Translation: {item['tgt_next']}"
    )
    return dspy.Example(
        context=context,
        source=f"Source ({item['src_lang']}): {item['src_text']}",
        translation=f"Translation ({item['tgt_lang']}): {item['main_text']}",
        evaluation=json.dumps(item['evaluation'], ensure_ascii=False)
    ).with_inputs("context", "source", "translation")


def prepare_dataset(file_path, train_frac=0.7, dev_frac=0.15):
    """Load a JSON list of labelled items and split it into train/dev/test examples.

    Each item must provide src_prev, tgt_prev, src_next, tgt_next, src_lang,
    tgt_lang, src_text, main_text and evaluation.
    NOTE(review): the data-loading cell names the translation field ``tgt_text``;
    confirm that this file really uses ``main_text``.

    Args:
        file_path: path to a UTF-8 JSON file containing a list of items.
        train_frac: fraction of items (in file order) used for training.
        dev_frac: fraction used for the dev split; the test split gets the rest.

    Returns:
        ``(train_data, dev_data, test_data)``: lists of ``dspy.Example``, where
        ``evaluation`` is a JSON string. Items are not shuffled before splitting.

    Raises:
        ValueError: if the fractions are negative or sum to more than 1.
    """
    if train_frac < 0 or dev_frac < 0 or train_frac + dev_frac > 1:
        raise ValueError("train_frac and dev_frac must be non-negative and sum to at most 1")

    # Explicit UTF-8: the texts are multilingual and the platform default may differ.
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    prepared_data = [_example_from_item(item) for item in data]

    train_size = int(train_frac * len(prepared_data))
    dev_size = int(dev_frac * len(prepared_data))

    train_data = prepared_data[:train_size]
    dev_data = prepared_data[train_size:train_size + dev_size]
    test_data = prepared_data[train_size + dev_size:]

    return train_data, dev_data, test_data

def optimize_translation_quality_assessment(lm=None, dataset_path='translation_quality_dataset.json'):
    """Optimise ``TranslationQualityAssessor`` with MIPROv2 and report test-set scores.

    Args:
        lm: DSPy LM to install globally via ``dspy.settings.configure``. Defaults
            to ``TranslationQualityLM()``.
            NOTE(review): that class is not defined in this section.
        dataset_path: JSON file passed to ``prepare_dataset``.

    Returns:
        The compiled assessor module. Prints average exact- and partial-match
        scores over the test split (NaN when the split is empty).
    """
    if lm is None:
        lm = TranslationQualityLM()
    dspy.settings.configure(lm=lm)

    train_data, dev_data, test_data = prepare_dataset(dataset_path)

    def partial_match_metric(example, pred, trace=None):
        """DSPy-style metric: partial-match score of the predicted evaluation."""
        # The assessor returns the evaluation itself; also accept a Prediction.
        predicted = getattr(pred, 'evaluation', pred)
        return TranslationMetrics.partial_match_score(predicted, example.evaluation)

    assessor = TranslationQualityAssessor()

    # MIPROv2 takes a callable metric(example, pred, trace=None).
    # NOTE(review): the search budget (rounds/candidates/trials) is configured via
    # MIPROv2's own parameters (e.g. `auto`, `num_candidates`, `num_trials`),
    # which vary across DSPy versions; tune them for the installed release.
    optimizer = MIPROv2(
        metric=partial_match_metric,
        init_temperature=0.7,
        num_threads=4,
        verbose=True,
    )

    # An empty dev split is passed as None so MIPROv2 can fall back to its default handling.
    # NOTE(review): some DSPy versions ask for confirmation via input() in compile();
    # check `requires_permission_to_run` for the installed release.
    compiled_assessor = optimizer.compile(
        assessor,
        trainset=train_data,
        valset=dev_data or None,
    )

    exact_scores = []
    partial_scores = []
    for example in test_data:
        pred = compiled_assessor(
            context=example.context,
            source=example.source,
            translation=example.translation
        )
        predicted = getattr(pred, 'evaluation', pred)
        exact_scores.append(float(TranslationMetrics.exact_match_score(predicted, example.evaluation)))
        partial_scores.append(TranslationMetrics.partial_match_score(predicted, example.evaluation))

    # np.mean([]) warns and returns NaN; make the empty case explicit.
    avg_exact_match = float(np.mean(exact_scores)) if exact_scores else float('nan')
    avg_partial_match = float(np.mean(partial_scores)) if partial_scores else float('nan')

    print(f"Average Exact Match Score: {avg_exact_match:.3f}")
    print(f"Average Partial Match Score: {avg_partial_match:.3f}")

    return compiled_assessor

# A notebook kernel executes cells with __name__ == "__main__", so running this
# cell launches the full optimisation and binds the result to optimized_assessor.
if __name__ == "__main__":
    optimized_assessor = optimize_translation_quality_assessment()