File size: 6,466 Bytes
9d76e23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""Utility functions for model fine-tuning."""

import json
import logging
from typing import Dict, List, Optional
from pathlib import Path
from datetime import datetime

logger = logging.getLogger(__name__)

# Constants
# All paths below are relative, so they resolve against the current working
# directory of the running process.
DATA_DIR = Path("src/fine_tuning/data")
# JSONL file written by save_training_data() and read by load_training_data().
TRAINING_DATA_FILE = DATA_DIR / "reranker_training_data.jsonl"
# JSON file written by save_model_metadata() and read by load_model_metadata().
MODEL_METADATA_FILE = DATA_DIR / "model_metadata.json"
# JSONL file appended to by save_user_feedback() and read by load_user_feedback().
USER_FEEDBACK_FILE = DATA_DIR / "user_feedback.jsonl"
# Parent directory of the per-version "reranker_<version>" model paths.
MODEL_DIR = DATA_DIR / "models/fine_tuned"
MAX_OLD_MODELS = 3  # Number of most recent model versions cleanup_old_models() keeps

def save_training_data(training_samples: List[Dict], append: bool = True) -> None:
    """Write training samples to the training-data JSONL file.

    Each sample is serialised as one JSON object per line.

    Args:
        training_samples: Samples to write; each must be JSON-serialisable.
        append: When true (the default) the samples are added to the end of
            the existing file; otherwise the file is overwritten.

    Any failure is logged and swallowed; nothing is raised to the caller.
    """
    try:
        DATA_DIR.mkdir(parents=True, exist_ok=True)
        if append:
            open_mode = 'a'
        else:
            open_mode = 'w'
        with TRAINING_DATA_FILE.open(open_mode, encoding='utf-8') as out:
            for record in training_samples:
                json.dump(record, out, ensure_ascii=False)
                out.write('\n')
        logger.info(f"Saved {len(training_samples)} training samples to {TRAINING_DATA_FILE}")
    except Exception as err:
        logger.error(f"Error saving training data: {err}")

def load_training_data() -> List[Dict]:
    """Load training samples from the training-data JSONL file.

    Blank lines are ignored. A line that is not valid JSON is logged and
    skipped, so a single corrupt record (for example a line truncated by an
    interrupted write) does not discard every sample that follows it.

    Returns:
        The parsed samples in file order. The list is empty when the file
        does not exist; if reading fails part-way, the samples parsed up to
        that point are returned and the error is logged.
    """
    samples: List[Dict] = []
    try:
        if TRAINING_DATA_FILE.exists():
            with open(TRAINING_DATA_FILE, 'r', encoding='utf-8') as f:
                for line_number, line in enumerate(f, start=1):
                    if not line.strip():
                        continue
                    try:
                        samples.append(json.loads(line))
                    except json.JSONDecodeError as e:
                        # Keep going: later lines are independent records.
                        logger.warning(
                            f"Skipping malformed training sample on line {line_number} "
                            f"of {TRAINING_DATA_FILE}: {e}"
                        )
        logger.info(f"Loaded {len(samples)} training samples from {TRAINING_DATA_FILE}")
    except Exception as e:
        logger.error(f"Error loading training data: {e}")
    return samples

def save_model_metadata(metadata: Dict) -> None:
    """Persist model metadata to the metadata JSON file.

    The passed-in dictionary is modified in place: its 'last_updated' key is
    set to the current local time in ISO-8601 format before it is written.

    Args:
        metadata: JSON-serialisable metadata to store.

    Any failure is logged and swallowed; nothing is raised to the caller.
    """
    try:
        DATA_DIR.mkdir(parents=True, exist_ok=True)
        # Stamp the record so readers can tell when it was last written.
        stamp = datetime.now().isoformat()
        metadata['last_updated'] = stamp
        with MODEL_METADATA_FILE.open('w', encoding='utf-8') as out:
            json.dump(metadata, out, ensure_ascii=False, indent=2)
        logger.info(f"Saved model metadata to {MODEL_METADATA_FILE}")
    except Exception as err:
        logger.error(f"Error saving model metadata: {err}")

def load_model_metadata() -> Optional[Dict]:
    """Read model metadata back from the metadata JSON file.

    Returns:
        The parsed metadata, or None when the file does not exist or could
        not be read or parsed (the error is logged in that case).
    """
    try:
        if not MODEL_METADATA_FILE.exists():
            return None
        with MODEL_METADATA_FILE.open('r', encoding='utf-8') as handle:
            return json.load(handle)
    except Exception as err:
        logger.error(f"Error loading model metadata: {err}")
        return None

def get_model_path(version: str) -> Path:
    """Return the path used for the given model version.

    The model directory is created first if it does not exist yet.

    Args:
        version: Version label, appended to the "reranker_" prefix.

    Returns:
        MODEL_DIR joined with "reranker_<version>".
    """
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    model_name = f"reranker_{version}"
    return MODEL_DIR.joinpath(model_name)

def get_latest_model_version() -> str:
    """Return the highest model version found in the model directory.

    Entries matching "reranker_v*" are considered; the text after the last
    underscore must be 'v' followed by digits to count as a version.

    Returns:
        The version label with the largest number, or "v0" when the
        directory is missing, holds no versioned entries, or an error occurs
        (errors are logged).
    """
    try:
        if not MODEL_DIR.exists():
            return "v0"
        # The text after the final underscore is the candidate version tag.
        tags = [entry.name.rsplit('_', 1)[-1] for entry in MODEL_DIR.glob("reranker_v*")]
        valid_tags = [tag for tag in tags if tag.startswith('v') and tag[1:].isdigit()]
        return max(valid_tags, key=lambda tag: int(tag[1:]), default="v0")
    except Exception as err:
        logger.error(f"Error getting latest model version: {err}")
        return "v0"

def cleanup_old_models() -> None:
    """Remove old model versions, keeping only the most recent ones.

    Entries in MODEL_DIR matching "reranker_v*" whose suffix after the last
    underscore is 'v' followed by digits are ordered by version number, and
    everything beyond the newest MAX_OLD_MODELS is deleted. Both plain files
    and directories are removed.

    Failures are logged and never raised. A failure to remove one version
    does not stop the remaining versions from being processed.
    """
    # Local import so this function does not depend on the module's import block.
    import shutil

    try:
        if not MODEL_DIR.exists():
            return

        versions = []
        for path in MODEL_DIR.glob("reranker_v*"):
            version = path.name.split('_')[-1]
            if version.startswith('v') and version[1:].isdigit():
                versions.append((version, path))

        # Sort by version number (descending) so the newest come first.
        versions.sort(key=lambda item: int(item[0][1:]), reverse=True)

        # Remove old versions beyond the limit.
        for version, path in versions[MAX_OLD_MODELS:]:
            try:
                # Path.unlink() cannot delete a directory. Model paths have no
                # file extension, so they are presumably directories written
                # by the training code (NOTE(review): confirm); handle both.
                # A symlink is unlinked rather than followed.
                if path.is_dir() and not path.is_symlink():
                    shutil.rmtree(path)
                else:
                    path.unlink()
                logger.info(f"Removed old model version: {version}")
            except Exception as e:
                logger.error(f"Error removing model version {version}: {e}")
    except Exception as e:
        logger.error(f"Error during model cleanup: {e}")

def load_user_feedback() -> Dict[str, Dict]:
    """Load user feedback entries from the user-feedback JSONL file.

    Each non-blank line is parsed as a JSON object that must provide
    'query_text', 'candidate_text' and 'rating'. A line that is not valid
    JSON, or that lacks a required key, is logged and skipped so that one bad
    entry does not discard all the feedback that follows it.

    Returns:
        A dictionary keyed by "<query_text>_<candidate_text>". Each value
        holds 'rating', 'confidence' (default 1.0), 'timestamp' (default: the
        time of loading), 'user_id' (default 'anonymous'), 'interaction_type'
        (default 'explicit') and 'session_id' (default None). When the same
        key occurs more than once, the last entry in the file wins. The
        dictionary is empty when the file does not exist.
    """
    feedback_data: Dict[str, Dict] = {}
    try:
        if USER_FEEDBACK_FILE.exists():
            with open(USER_FEEDBACK_FILE, 'r', encoding='utf-8') as f:
                for line_number, line in enumerate(f, start=1):
                    if not line.strip():
                        continue
                    try:
                        feedback = json.loads(line)
                        # Create a unique key for the query-candidate pair
                        key = f"{feedback['query_text']}_{feedback['candidate_text']}"
                        # Store feedback with confidence and timestamp
                        entry = {
                            'rating': feedback['rating'],
                            'confidence': feedback.get('confidence', 1.0),
                            'timestamp': feedback.get('timestamp', datetime.now().isoformat()),
                            'user_id': feedback.get('user_id', 'anonymous'),
                            'interaction_type': feedback.get('interaction_type', 'explicit'),
                            'session_id': feedback.get('session_id', None)
                        }
                    except (json.JSONDecodeError, KeyError, TypeError) as e:
                        # TypeError covers lines holding valid JSON that is
                        # not an object (e.g. a list or a bare string).
                        logger.warning(
                            f"Skipping malformed user feedback entry on line {line_number}: {e!r}"
                        )
                        continue
                    feedback_data[key] = entry
        logger.info(f"Loaded {len(feedback_data)} user feedback entries")
    except Exception as e:
        logger.error(f"Error loading user feedback data: {e}")
    return feedback_data

def save_user_feedback(feedback: Dict) -> None:
    """Append one user feedback entry to the user-feedback JSONL file.

    The passed-in dictionary is modified in place: once the file has been
    opened, its 'timestamp' key is set to the current local time in ISO-8601
    format, and the entry is then written as a single JSON line.

    Args:
        feedback: JSON-serialisable feedback record.

    Any failure is logged and swallowed; nothing is raised to the caller.
    """
    try:
        DATA_DIR.mkdir(parents=True, exist_ok=True)
        with USER_FEEDBACK_FILE.open('a', encoding='utf-8') as sink:
            feedback['timestamp'] = datetime.now().isoformat()
            json.dump(feedback, sink, ensure_ascii=False)
            sink.write('\n')
        query = feedback.get('query_text', 'unknown')
        logger.info(f"Saved user feedback for query: {query}")
    except Exception as err:
        logger.error(f"Error saving user feedback: {err}")