Spaces:
Sleeping
Sleeping
File size: 6,466 Bytes
9d76e23 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
"""Utility functions for model fine-tuning."""
import json
import logging
import shutil
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional
logger = logging.getLogger(__name__)
# Constants
# Root directory for all fine-tuning artifacts (relative to the working directory).
DATA_DIR = Path("src/fine_tuning/data")
TRAINING_DATA_FILE = DATA_DIR / "reranker_training_data.jsonl"  # JSONL: one training sample per line
MODEL_METADATA_FILE = DATA_DIR / "model_metadata.json"  # single JSON document with model metadata
USER_FEEDBACK_FILE = DATA_DIR / "user_feedback.jsonl"  # JSONL: one feedback record per line
MODEL_DIR = DATA_DIR / "models/fine_tuned"  # versioned models stored as "reranker_vN" entries
MAX_OLD_MODELS = 3 # Maximum number of old model versions to keep
def save_training_data(training_samples: List[Dict], append: bool = True) -> None:
    """Write training samples to the JSONL training-data file.

    Args:
        training_samples: Records to serialize, one JSON object per line.
        append: Append to the existing file when True, otherwise overwrite it.

    Errors are logged rather than raised (best-effort persistence).
    """
    try:
        DATA_DIR.mkdir(parents=True, exist_ok=True)
        file_mode = 'a' if append else 'w'
        with open(TRAINING_DATA_FILE, file_mode, encoding='utf-8') as fh:
            for record in training_samples:
                fh.write(json.dumps(record, ensure_ascii=False) + '\n')
        logger.info(f"Saved {len(training_samples)} training samples to {TRAINING_DATA_FILE}")
    except Exception as e:
        logger.error(f"Error saving training data: {e}")
def load_training_data() -> List[Dict]:
    """Read all training samples from the JSONL training-data file.

    Returns:
        The parsed samples; an empty list when the file is missing, and
        whatever was parsed so far if a read/parse error occurs mid-file.
    """
    loaded: List[Dict] = []
    try:
        if TRAINING_DATA_FILE.exists():
            with open(TRAINING_DATA_FILE, 'r', encoding='utf-8') as fh:
                for raw_line in fh:
                    # Skip blank lines between records.
                    if not raw_line.strip():
                        continue
                    loaded.append(json.loads(raw_line))
        logger.info(f"Loaded {len(loaded)} training samples from {TRAINING_DATA_FILE}")
    except Exception as e:
        logger.error(f"Error loading training data: {e}")
    return loaded
def save_model_metadata(metadata: Dict) -> None:
    """Persist model metadata as pretty-printed JSON.

    NOTE: sets ``metadata['last_updated']`` in place before writing, so the
    caller's dict is mutated as a side effect.

    Args:
        metadata: Metadata to write; receives a fresh ISO-format timestamp.
    """
    try:
        DATA_DIR.mkdir(parents=True, exist_ok=True)
        # Stamp the write time so consumers can tell how fresh the metadata is.
        metadata['last_updated'] = datetime.now().isoformat()
        with open(MODEL_METADATA_FILE, 'w', encoding='utf-8') as fh:
            fh.write(json.dumps(metadata, ensure_ascii=False, indent=2))
        logger.info(f"Saved model metadata to {MODEL_METADATA_FILE}")
    except Exception as e:
        logger.error(f"Error saving model metadata: {e}")
def load_model_metadata() -> Optional[Dict]:
    """Read model metadata from the JSON file on disk.

    Returns:
        The metadata dict, or None when the file is absent or unreadable.
    """
    try:
        # Guard clause: nothing to load yet.
        if not MODEL_METADATA_FILE.exists():
            return None
        with open(MODEL_METADATA_FILE, 'r', encoding='utf-8') as fh:
            return json.load(fh)
    except Exception as e:
        logger.error(f"Error loading model metadata: {e}")
    return None
def get_model_path(version: str) -> Path:
    """Build the filesystem path for a specific fine-tuned model version.

    Ensures the model directory exists as a side effect.

    Args:
        version: Version tag, e.g. ``"v3"``.

    Returns:
        ``MODEL_DIR / "reranker_<version>"``.
    """
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    model_name = f"reranker_{version}"
    return MODEL_DIR / model_name
def get_latest_model_version() -> str:
    """Return the highest model version tag found on disk (e.g. ``"v3"``).

    Falls back to ``"v0"`` when no versions exist or an error occurs.
    """
    try:
        tags: List[str] = []
        if MODEL_DIR.exists():
            for entry in MODEL_DIR.glob("reranker_v*"):
                # Filenames look like "reranker_vN"; the tag is the last segment.
                suffix = entry.name.split('_')[-1]
                if suffix.startswith('v') and suffix[1:].isdigit():
                    tags.append(suffix)
        # Compare numerically so "v10" beats "v9".
        return max(tags, default="v0", key=lambda tag: int(tag[1:]))
    except Exception as e:
        logger.error(f"Error getting latest model version: {e}")
    return "v0"
def cleanup_old_models() -> None:
    """Delete stale model versions, keeping only the MAX_OLD_MODELS newest.

    Fix: the previous implementation called ``Path.unlink()`` unconditionally,
    which raises ``IsADirectoryError`` for directory-based model artifacts, so
    directory checkpoints were never actually removed (the error was only
    logged on every run). Removal now branches on ``is_dir()`` and uses
    ``shutil.rmtree`` for directories.

    All failures are logged rather than raised (best-effort cleanup).
    """
    try:
        if not MODEL_DIR.exists():
            return
        versioned = []
        for path in MODEL_DIR.glob("reranker_v*"):
            version = path.name.split('_')[-1]
            if version.startswith('v') and version[1:].isdigit():
                versioned.append((int(version[1:]), version, path))
        # Newest first; everything past the retention limit gets removed.
        versioned.sort(key=lambda item: item[0], reverse=True)
        for _, version, path in versioned[MAX_OLD_MODELS:]:
            try:
                # Model artifacts may be single files or whole directories.
                if path.is_dir():
                    shutil.rmtree(path)
                else:
                    path.unlink()
                logger.info(f"Removed old model version: {version}")
            except Exception as e:
                logger.error(f"Error removing model version {version}: {e}")
    except Exception as e:
        logger.error(f"Error during model cleanup: {e}")
def load_user_feedback() -> Dict[str, Dict]:
    """
    Load user feedback entries recorded in the feedback JSONL file.

    Returns:
        Mapping from "<query_text>_<candidate_text>" keys to normalized
        feedback dicts (rating, confidence, timestamp, user_id,
        interaction_type, session_id). Later lines overwrite earlier ones
        for the same key.
    """
    entries: Dict[str, Dict] = {}
    try:
        if USER_FEEDBACK_FILE.exists():
            with open(USER_FEEDBACK_FILE, 'r', encoding='utf-8') as fh:
                for raw_line in fh:
                    if not raw_line.strip():
                        continue
                    record = json.loads(raw_line)
                    # One entry per query-candidate pair.
                    pair_key = f"{record['query_text']}_{record['candidate_text']}"
                    entries[pair_key] = {
                        'rating': record['rating'],
                        'confidence': record.get('confidence', 1.0),
                        'timestamp': record.get('timestamp', datetime.now().isoformat()),
                        'user_id': record.get('user_id', 'anonymous'),
                        'interaction_type': record.get('interaction_type', 'explicit'),
                        'session_id': record.get('session_id', None),
                    }
        logger.info(f"Loaded {len(entries)} user feedback entries")
    except Exception as e:
        logger.error(f"Error loading user feedback data: {e}")
    return entries
def save_user_feedback(feedback: Dict) -> None:
    """
    Append a single user feedback entry to the feedback JSONL file.

    NOTE: stamps ``feedback['timestamp']`` in place before writing, so the
    caller's dict is mutated as a side effect.

    Args:
        feedback: Feedback record; expected to carry at least 'query_text'.
    """
    try:
        DATA_DIR.mkdir(parents=True, exist_ok=True)
        with open(USER_FEEDBACK_FILE, 'a', encoding='utf-8') as fh:
            # Record the save time on the entry itself.
            feedback['timestamp'] = datetime.now().isoformat()
            fh.write(json.dumps(feedback, ensure_ascii=False) + '\n')
        logger.info(f"Saved user feedback for query: {feedback.get('query_text', 'unknown')}")
    except Exception as e:
        logger.error(f"Error saving user feedback: {e}")