Spaces:
Sleeping
Sleeping
| """Utility functions for model fine-tuning.""" | |
import json
import logging
import shutil
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional
# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)
# Constants
# Root directory for all fine-tuning artifacts (data, feedback, models).
DATA_DIR = Path("src/fine_tuning/data")
# JSONL file: one JSON-encoded training sample per line.
TRAINING_DATA_FILE = DATA_DIR / "reranker_training_data.jsonl"
# Single JSON document describing the current model (stamped with 'last_updated').
MODEL_METADATA_FILE = DATA_DIR / "model_metadata.json"
# JSONL file: one JSON-encoded user-feedback entry per line.
USER_FEEDBACK_FILE = DATA_DIR / "user_feedback.jsonl"
# Versioned fine-tuned models live here as "reranker_v<N>" entries.
MODEL_DIR = DATA_DIR / "models/fine_tuned"
MAX_OLD_MODELS = 3  # Maximum number of old model versions to keep
def save_training_data(training_samples: List[Dict], append: bool = True) -> None:
    """Write training samples to the JSONL training file.

    Args:
        training_samples: JSON-serializable sample dicts, one per output line.
        append: When True (default), add to the existing file; when False,
            replace its contents.

    Errors are logged rather than raised (best-effort persistence).
    """
    try:
        DATA_DIR.mkdir(parents=True, exist_ok=True)
        file_mode = 'a' if append else 'w'
        with open(TRAINING_DATA_FILE, file_mode, encoding='utf-8') as fh:
            for item in training_samples:
                fh.write(json.dumps(item, ensure_ascii=False))
                fh.write('\n')
        logger.info(f"Saved {len(training_samples)} training samples to {TRAINING_DATA_FILE}")
    except Exception as e:
        logger.error(f"Error saving training data: {e}")
def load_training_data() -> List[Dict]:
    """Read all training samples from the JSONL training file.

    Returns:
        A list of sample dicts; empty when the file is missing or a read
        error occurs (errors are logged, not raised).
    """
    loaded: List[Dict] = []
    try:
        if TRAINING_DATA_FILE.exists():
            with open(TRAINING_DATA_FILE, 'r', encoding='utf-8') as fh:
                # Skip blank lines; every other line is one JSON sample.
                loaded = [json.loads(raw) for raw in fh if raw.strip()]
        logger.info(f"Loaded {len(loaded)} training samples from {TRAINING_DATA_FILE}")
    except Exception as e:
        logger.error(f"Error loading training data: {e}")
    return loaded
def save_model_metadata(metadata: Dict) -> None:
    """Save model metadata to MODEL_METADATA_FILE as pretty-printed JSON.

    A 'last_updated' ISO-8601 timestamp is stamped into the stored record.
    Fix over the original: the caller's dict is no longer mutated — the
    timestamp is added to a shallow copy only, avoiding a surprising side
    effect on the argument.

    Args:
        metadata: JSON-serializable metadata describing the model.

    Errors are logged rather than raised (best-effort persistence).
    """
    try:
        DATA_DIR.mkdir(parents=True, exist_ok=True)
        # Stamp a copy so the caller's dict stays untouched.
        record = {**metadata, 'last_updated': datetime.now().isoformat()}
        with open(MODEL_METADATA_FILE, 'w', encoding='utf-8') as f:
            json.dump(record, f, ensure_ascii=False, indent=2)
        logger.info(f"Saved model metadata to {MODEL_METADATA_FILE}")
    except Exception as e:
        logger.error(f"Error saving model metadata: {e}")
def load_model_metadata() -> Optional[Dict]:
    """Return the stored model metadata dict, or None when absent or unreadable."""
    try:
        if not MODEL_METADATA_FILE.exists():
            return None
        with open(MODEL_METADATA_FILE, 'r', encoding='utf-8') as fh:
            return json.load(fh)
    except Exception as e:
        logger.error(f"Error loading model metadata: {e}")
        return None
def get_model_path(version: str) -> Path:
    """Return the on-disk path for a model version, ensuring MODEL_DIR exists.

    Args:
        version: Version tag, e.g. "v3"; the path is MODEL_DIR/reranker_<version>.
    """
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    return MODEL_DIR / ('reranker_' + version)
def get_latest_model_version() -> str:
    """Scan MODEL_DIR for "reranker_v<N>" entries and return the highest tag.

    Returns:
        The latest version tag (e.g. "v7"), or "v0" when no valid versions
        exist or scanning fails.
    """
    try:
        found = []
        if MODEL_DIR.exists():
            for entry in MODEL_DIR.glob("reranker_v*"):
                # The version tag is everything after the final underscore.
                tag = entry.name.rsplit('_', 1)[-1]
                if tag[:1] == 'v' and tag[1:].isdigit():
                    found.append(tag)
        # Compare numerically ("v10" > "v9"), not lexically.
        return max(found, key=lambda tag: int(tag[1:]), default="v0")
    except Exception as e:
        logger.error(f"Error getting latest model version: {e}")
    return "v0"
def cleanup_old_models() -> None:
    """Delete model versions beyond the MAX_OLD_MODELS most recent ones.

    Fix over the original: model artifacts may be directories (fine-tuned
    models are frequently saved as a directory of files) as well as single
    files; `Path.unlink()` raises on directories, so directories are removed
    with `shutil.rmtree` instead. Per-version failures are logged and do not
    stop the rest of the cleanup.
    """
    try:
        if not MODEL_DIR.exists():
            return
        versioned = []
        for path in MODEL_DIR.glob("reranker_v*"):
            version = path.name.split('_')[-1]
            if version.startswith('v') and version[1:].isdigit():
                # Keep the numeric value first so sorting is numeric, not lexical.
                versioned.append((int(version[1:]), version, path))
        # Newest first; everything past the retention limit gets removed.
        versioned.sort(reverse=True)
        for _, version, path in versioned[MAX_OLD_MODELS:]:
            try:
                if path.is_dir():
                    shutil.rmtree(path)  # directory-style model artifact
                else:
                    path.unlink()  # single-file artifact
                logger.info(f"Removed old model version: {version}")
            except Exception as e:
                logger.error(f"Error removing model version {version}: {e}")
    except Exception as e:
        logger.error(f"Error during model cleanup: {e}")
def load_user_feedback() -> Dict[str, Dict]:
    """
    Load user feedback data from the feedback tracking database.

    Returns a dictionary mapping query-candidate pairs to feedback information.

    Each JSONL line must carry 'query_text', 'candidate_text' and 'rating';
    the remaining fields are optional and defaulted. Fix over the original:
    a single malformed line (bad JSON or missing required keys) is skipped
    with a warning instead of aborting the entire load via the outer except.
    """
    feedback_data: Dict[str, Dict] = {}
    try:
        if USER_FEEDBACK_FILE.exists():
            with open(USER_FEEDBACK_FILE, 'r', encoding='utf-8') as f:
                for line_no, line in enumerate(f, start=1):
                    if not line.strip():
                        continue
                    try:
                        feedback = json.loads(line)
                        # Create a unique key for the query-candidate pair
                        key = f"{feedback['query_text']}_{feedback['candidate_text']}"
                        # Store feedback with confidence and timestamp
                        feedback_data[key] = {
                            'rating': feedback['rating'],
                            'confidence': feedback.get('confidence', 1.0),
                            'timestamp': feedback.get('timestamp', datetime.now().isoformat()),
                            'user_id': feedback.get('user_id', 'anonymous'),
                            'interaction_type': feedback.get('interaction_type', 'explicit'),
                            'session_id': feedback.get('session_id', None)
                        }
                    except (json.JSONDecodeError, KeyError) as e:
                        # Bad line: skip it, keep loading the rest.
                        logger.warning(f"Skipping malformed feedback line {line_no}: {e}")
        logger.info(f"Loaded {len(feedback_data)} user feedback entries")
    except Exception as e:
        logger.error(f"Error loading user feedback data: {e}")
    return feedback_data
def save_user_feedback(feedback: Dict) -> None:
    """
    Save a user feedback entry to the feedback tracking database.

    A 'timestamp' (ISO-8601) is stamped into the stored record. Fix over the
    original: the caller's dict is no longer mutated — the timestamp is added
    to a shallow copy only.

    Args:
        feedback: JSON-serializable feedback entry; 'query_text' is used for
            the success log message when present.

    Errors are logged rather than raised (best-effort persistence).
    """
    try:
        DATA_DIR.mkdir(parents=True, exist_ok=True)
        # Stamp a copy so the caller's dict stays untouched.
        record = {**feedback, 'timestamp': datetime.now().isoformat()}
        with open(USER_FEEDBACK_FILE, 'a', encoding='utf-8') as f:
            json.dump(record, f, ensure_ascii=False)
            f.write('\n')
        logger.info(f"Saved user feedback for query: {feedback.get('query_text', 'unknown')}")
    except Exception as e:
        logger.error(f"Error saving user feedback: {e}")