# Source imported from Hugging Face repo (initial commit 9d76e23 by sundaram22verma)
"""Utility functions for model fine-tuning."""
import json
import logging
import shutil
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional
logger = logging.getLogger(__name__)
# Constants
# Base directory for all fine-tuning artifacts.
# NOTE(review): relative path — resolved against the process CWD; confirm the
# service is always started from the repo root.
DATA_DIR = Path("src/fine_tuning/data")
# Training samples, one JSON object per line (JSONL); appended to by default.
TRAINING_DATA_FILE = DATA_DIR / "reranker_training_data.jsonl"
# JSON metadata for the fine-tuned model; save_model_metadata() stamps 'last_updated'.
MODEL_METADATA_FILE = DATA_DIR / "model_metadata.json"
# Raw user feedback events, one JSON object per line (JSONL).
USER_FEEDBACK_FILE = DATA_DIR / "user_feedback.jsonl"
# Versioned fine-tuned models live here as "reranker_<version>" paths.
MODEL_DIR = DATA_DIR / "models/fine_tuned"
MAX_OLD_MODELS = 3 # Maximum number of old model versions to keep
def save_training_data(training_samples: List[Dict], append: bool = True) -> None:
    """Save training samples to a JSONL file.

    Args:
        training_samples: JSON-serializable sample dicts to persist.
        append: If True (default), append to the existing file; otherwise
            overwrite it.

    Errors are logged and swallowed (best-effort persistence); callers are
    not notified of failures.
    """
    try:
        DATA_DIR.mkdir(parents=True, exist_ok=True)
        mode = 'a' if append else 'w'
        with open(TRAINING_DATA_FILE, mode, encoding='utf-8') as f:
            # One JSON object per line (JSONL); ensure_ascii=False keeps
            # non-ASCII text human-readable in the file.
            for sample in training_samples:
                json.dump(sample, f, ensure_ascii=False)
                f.write('\n')
        logger.info("Saved %d training samples to %s", len(training_samples), TRAINING_DATA_FILE)
    except Exception:
        # logger.exception preserves the traceback that logger.error(f"...{e}") dropped.
        logger.exception("Error saving training data")
def load_training_data() -> List[Dict]:
    """Load training samples from the JSONL file.

    Returns:
        A list of sample dicts; empty when the file is missing or unreadable.
    """
    samples: List[Dict] = []
    try:
        if TRAINING_DATA_FILE.exists():
            with open(TRAINING_DATA_FILE, 'r', encoding='utf-8') as f:
                # Skip blank lines so trailing newlines don't break parsing.
                samples = [json.loads(line) for line in f if line.strip()]
        logger.info("Loaded %d training samples from %s", len(samples), TRAINING_DATA_FILE)
    except Exception:
        # logger.exception preserves the traceback that logger.error(f"...{e}") dropped.
        logger.exception("Error loading training data")
    return samples
def save_model_metadata(metadata: Dict) -> None:
    """Save model metadata to a JSON file.

    Args:
        metadata: JSON-serializable metadata dict. NOTE: mutated in place —
            a 'last_updated' ISO timestamp is stamped onto it before saving,
            and the caller sees that change.

    Errors are logged and swallowed (best-effort persistence).
    """
    try:
        DATA_DIR.mkdir(parents=True, exist_ok=True)
        # Stamp the save time so readers can tell how fresh the metadata is.
        metadata['last_updated'] = datetime.now().isoformat()
        with open(MODEL_METADATA_FILE, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, ensure_ascii=False, indent=2)
        logger.info("Saved model metadata to %s", MODEL_METADATA_FILE)
    except Exception:
        # logger.exception preserves the traceback that logger.error(f"...{e}") dropped.
        logger.exception("Error saving model metadata")
def load_model_metadata() -> Optional[Dict]:
    """Load model metadata from the JSON file.

    Returns:
        The metadata dict, or None when the file is absent or unreadable.
    """
    try:
        # Guard clause: nothing to load if the file was never written.
        if not MODEL_METADATA_FILE.exists():
            return None
        with open(MODEL_METADATA_FILE, 'r', encoding='utf-8') as f:
            return json.load(f)
    except Exception as e:
        logger.error(f"Error loading model metadata: {e}")
        return None
def get_model_path(version: str) -> Path:
    """Return the path for a specific model version.

    Side effect: ensures the model directory exists before returning.
    """
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    return MODEL_DIR / ("reranker_" + version)
def get_latest_model_version() -> str:
    """Return the highest existing model version tag (e.g. 'v3'), or 'v0'.

    Scans MODEL_DIR for "reranker_v<N>" entries and compares the numeric
    suffixes; falls back to "v0" when nothing valid is found or on error.
    """
    try:
        tags = []
        if MODEL_DIR.exists():
            # Take the trailing "_<tag>" component of each matching name and
            # keep only well-formed "v<digits>" tags.
            candidates = (p.name.split('_')[-1] for p in MODEL_DIR.glob("reranker_v*"))
            tags = [t for t in candidates if t.startswith('v') and t[1:].isdigit()]
        return max(tags, default="v0", key=lambda t: int(t[1:]))
    except Exception as e:
        logger.error(f"Error getting latest model version: {e}")
        return "v0"
def cleanup_old_models() -> None:
    """Remove old model versions, keeping only the MAX_OLD_MODELS most recent.

    Fine-tuned models saved under "reranker_<version>" may be directories
    (checkpoint folders); the previous `path.unlink()` call raised
    IsADirectoryError for those, so stale directory checkpoints were never
    actually deleted. Directories are now removed with `shutil.rmtree` and
    plain files with `unlink`. Errors are logged and swallowed (best-effort).
    """
    try:
        if not MODEL_DIR.exists():
            return
        versions = []
        for path in MODEL_DIR.glob("reranker_v*"):
            version = path.name.split('_')[-1]
            if version.startswith('v') and version[1:].isdigit():
                versions.append((version, path))
        # Sort newest first (highest numeric version).
        versions.sort(key=lambda item: int(item[0][1:]), reverse=True)
        # Everything beyond the retention limit gets deleted.
        for version, path in versions[MAX_OLD_MODELS:]:
            try:
                if path.is_dir():
                    shutil.rmtree(path)  # checkpoint directories need rmtree
                else:
                    path.unlink()
                logger.info(f"Removed old model version: {version}")
            except Exception as e:
                logger.error(f"Error removing model version {version}: {e}")
    except Exception as e:
        logger.error(f"Error during model cleanup: {e}")
def load_user_feedback() -> Dict[str, Dict]:
    """
    Load user feedback data from the feedback tracking database.

    Returns a dictionary mapping query-candidate pairs to feedback
    information. Later lines for the same pair overwrite earlier ones;
    the result is empty when the file is missing or unreadable.

    Previously a single malformed line (bad JSON, or a missing
    'query_text'/'candidate_text'/'rating' key) escaped to the outer except
    and discarded every entry parsed so far; malformed lines are now skipped
    individually so one bad entry cannot wipe the whole load.
    """
    feedback_data: Dict[str, Dict] = {}
    try:
        if USER_FEEDBACK_FILE.exists():
            with open(USER_FEEDBACK_FILE, 'r', encoding='utf-8') as f:
                for line in f:
                    if not line.strip():
                        continue
                    try:
                        feedback = json.loads(line)
                        # Create a unique key for the query-candidate pair
                        key = f"{feedback['query_text']}_{feedback['candidate_text']}"
                        # Store feedback with confidence and timestamp
                        feedback_data[key] = {
                            'rating': feedback['rating'],
                            'confidence': feedback.get('confidence', 1.0),
                            'timestamp': feedback.get('timestamp', datetime.now().isoformat()),
                            'user_id': feedback.get('user_id', 'anonymous'),
                            'interaction_type': feedback.get('interaction_type', 'explicit'),
                            'session_id': feedback.get('session_id', None)
                        }
                    except (json.JSONDecodeError, KeyError) as e:
                        # Skip the bad entry but keep everything else.
                        logger.warning(f"Skipping malformed feedback entry: {e}")
            logger.info(f"Loaded {len(feedback_data)} user feedback entries")
    except Exception as e:
        logger.error(f"Error loading user feedback data: {e}")
    return feedback_data
def save_user_feedback(feedback: Dict) -> None:
    """
    Save a user feedback entry to the feedback tracking database.

    Args:
        feedback: JSON-serializable feedback dict. NOTE: mutated in place —
            a 'timestamp' ISO string is stamped onto it before writing, and
            the caller sees that change.

    Errors are logged and swallowed (best-effort persistence).
    """
    try:
        DATA_DIR.mkdir(parents=True, exist_ok=True)
        # Stamp the write time before serializing the entry.
        feedback['timestamp'] = datetime.now().isoformat()
        with open(USER_FEEDBACK_FILE, 'a', encoding='utf-8') as f:
            json.dump(feedback, f, ensure_ascii=False)
            f.write('\n')
        logger.info("Saved user feedback for query: %s", feedback.get('query_text', 'unknown'))
    except Exception:
        # logger.exception preserves the traceback that logger.error(f"...{e}") dropped.
        logger.exception("Error saving user feedback")