| """ | |
| File processor class for handling dataset file operations | |
| """ | |
| import os | |
| import json | |
| import time | |
| import asyncio | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple | |
| from datetime import datetime | |
| import aiohttp | |
| from sqlalchemy.orm import Session | |
| from models import ProcessingState, ProcessingStatusEnum | |
| import logging | |
| logger = logging.getLogger(__name__) | |


class FileProcessor:
    """Object-oriented file processor for dataset integration."""

    def __init__(self, processed_files_dir: str = "processed_files"):
        """
        Initialize the file processor.

        Args:
            processed_files_dir: Directory to store processed files
        """
        self.processed_files_dir = Path(processed_files_dir)
        self.all_raw_dir = self.processed_files_dir / "all_raw"
        self.ato_raw_dir = self.processed_files_dir / "ato_raw"

        # Create working directories up front so later writes can assume they exist
        self.processed_files_dir.mkdir(parents=True, exist_ok=True)
        self.all_raw_dir.mkdir(parents=True, exist_ok=True)
        self.ato_raw_dir.mkdir(parents=True, exist_ok=True)
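
    # Resulting on-disk layout (relative to processed_files_dir), inferred from
    # the paths used below; integrated output files are written at the top
    # level under their ALL filenames:
    #   all_raw/  - raw downloads from the ALL repository
    #   ato_raw/  - raw downloads from the ATO repository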

    async def download_file(
        self,
        repo_id: str,
        filename: str,
        local_dir: Path,
        token: Optional[str] = None,
    ) -> Optional[str]:
        """
        Download a single file from Hugging Face.

        Args:
            repo_id: Repository ID (e.g., "samfred2/ALL")
            filename: File to download
            local_dir: Local directory to save the file

            token: Optional HF token for authentication

        Returns:
            Path to the downloaded file, or None on failure
        """
        try:
            logger.info(f"Downloading {filename} from {repo_id}...")
            await asyncio.sleep(1)  # Crude rate limiting between requests

            local_path = local_dir / filename
            local_path.parent.mkdir(parents=True, exist_ok=True)

            # Raw files are served from the dataset's /resolve/ endpoint;
            # the /api/ prefix is only for metadata endpoints.
            url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/{filename}"
            headers = {"Authorization": f"Bearer {token}"} if token else {}

            async with aiohttp.ClientSession() as session:
                async with session.get(url, headers=headers) as response:
                    if response.status != 200:
                        logger.error(
                            f"Failed to download {filename}: HTTP {response.status}"
                        )
                        return None
                    content = await response.read()

            local_path.write_bytes(content)
            logger.info(f"Downloaded to {local_path}")
            return str(local_path)
        except Exception as e:
            logger.error(f"Error downloading {filename}: {e}")
            return None

    def load_json_file(self, file_path: str) -> Optional[Dict]:
        """
        Load and parse a JSON file.

        Args:
            file_path: Path to the JSON file

        Returns:
            Parsed JSON data, or None on failure
        """
        try:
            # JSON is UTF-8 by specification, so don't rely on the platform default
            with open(file_path, "r", encoding="utf-8") as f:
                return json.load(f)
        except Exception as e:
            logger.error(f"Error loading JSON from {file_path}: {e}")
            return None

    def find_matching_all_file(
        self, ato_filename: str, all_filenames: List[str]
    ) -> Optional[str]:
        """
        Find the matching ALL file for an ATO file using suffix matching.

        Args:
            ato_filename: ATO filename to match
            all_filenames: List of ALL filenames

        Returns:
            Matching ALL filename, or None if no match is found
        """
        for all_name in all_filenames:
            if all_name.endswith(ato_filename):
                return all_name
        return None
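
    # Example of the suffix rule above (filenames are illustrative):
    #   find_matching_all_file("intro.json", ["CS101_intro.json"])
    #   -> "CS101_intro.json", because the ALL name ends with the ATO name.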

    async def list_json_files(
        self, repo_id: str, token: Optional[str] = None
    ) -> List[str]:
        """
        List all JSON files in a repository.

        Args:
            repo_id: Repository ID
            token: Optional HF token

        Returns:
            List of JSON filenames (empty on failure)
        """
        try:
            # The dataset metadata endpoint lists the repo's files under "siblings"
            url = f"https://huggingface.co/api/datasets/{repo_id}"
            headers = {"Authorization": f"Bearer {token}"} if token else {}
            async with aiohttp.ClientSession() as session:
                async with session.get(url, headers=headers) as response:
                    if response.status != 200:
                        logger.error(
                            f"Failed to list files from {repo_id}: "
                            f"HTTP {response.status}"
                        )
                        return []
                    data = await response.json()
            siblings = data.get("siblings", [])
            return [
                f["rfilename"]
                for f in siblings
                if f.get("rfilename", "").endswith(".json")
            ]
        except Exception as e:
            logger.error(f"Error listing files from {repo_id}: {e}")
            return []

    async def process_datasets(
        self,
        all_repo_id: str,
        ato_repo_id: str,
        hf_token: Optional[str] = None,
        max_files: int = 0,
        db: Optional[Session] = None,
    ) -> Dict:
        """
        Process datasets: download, match, integrate, and save.

        When a database session is supplied, the ProcessingState row walks
        through DOWNLOADING -> MATCHING -> INTEGRATING -> COMPLETED (or ERROR).

        Args:
            all_repo_id: Source ALL repository ID
            ato_repo_id: Source ATO repository ID
            hf_token: Optional HF token
            max_files: Maximum number of files to process (0 = all)
            db: Database session for state tracking

        Returns:
            Processing result dictionary
        """
        try:
            # Mark the run as started
            if db:
                state = db.query(ProcessingState).first()
                if not state:
                    state = ProcessingState(status=ProcessingStatusEnum.DOWNLOADING)
                    db.add(state)
                else:
                    state.status = ProcessingStatusEnum.DOWNLOADING
                state.started_at = datetime.utcnow()
                db.commit()

            logger.info("Listing repository files...")
            all_files = await self.list_json_files(all_repo_id, hf_token)
            ato_files = await self.list_json_files(ato_repo_id, hf_token)
            logger.info(f"Found {len(all_files)} files in {all_repo_id}")
            logger.info(f"Found {len(ato_files)} files in {ato_repo_id}")

            # Match each ATO file to its ALL counterpart by filename suffix
            logger.info("Matching ATO to ALL files...")
            match_map: Dict[str, str] = {}
            for ato_file in ato_files:
                matching_all = self.find_matching_all_file(ato_file, all_files)
                if matching_all:
                    match_map[ato_file] = matching_all
            matched_count = len(match_map)
            logger.info(f"Found {matched_count} matching pairs")

            if db:
                state = db.query(ProcessingState).first()
                if state:
                    state.status = ProcessingStatusEnum.MATCHING
                    state.total_files = len(ato_files)
                    state.matched_pairs = matched_count
                    db.commit()

            # Download each matched pair, merge, and write the result
            logger.info("Processing matched files...")
            processed_count = 0
            for ato_filename, all_filename in match_map.items():
                if max_files > 0 and processed_count >= max_files:
                    logger.info(f"Reached limit of {max_files} files")
                    break

                logger.info(f"Processing: {ato_filename} <-> {all_filename}")

                # Download and parse the ATO file; skip the pair on any failure
                ato_path = await self.download_file(
                    ato_repo_id, ato_filename, self.ato_raw_dir, hf_token
                )
                if not ato_path:
                    continue
                ato_data = self.load_json_file(ato_path)
                if not ato_data:
                    continue

                # Download and parse the ALL file
                all_path = await self.download_file(
                    all_repo_id, all_filename, self.all_raw_dir, hf_token
                )
                if not all_path:
                    continue
                all_data = self.load_json_file(all_path)
                if not all_data:
                    continue

                # Embed the ATO transcription inside the ALL record
                logger.info("Integrating transcription...")
                all_data["transcription_content"] = ato_data
                all_data["transcription_content"]["full_course_name"] = all_filename

                # Save the integrated file under its ALL filename
                output_path = self.processed_files_dir / all_filename
                output_path.parent.mkdir(parents=True, exist_ok=True)
                with open(output_path, "w", encoding="utf-8") as f:
                    json.dump(all_data, f, indent=4)
                logger.info(f"Saved integrated file to {output_path}")

                processed_count += 1
                if db:
                    state = db.query(ProcessingState).first()
                    if state:
                        state.status = ProcessingStatusEnum.INTEGRATING
                        state.processed_files = processed_count
                        db.commit()

            logger.info("Processing complete")
            if db:
                state = db.query(ProcessingState).first()
                if state:
                    state.status = ProcessingStatusEnum.COMPLETED
                    state.completed_at = datetime.utcnow()
                    db.commit()

            return {
                "success": True,
                "total_files": len(ato_files),
                "matched_pairs": matched_count,
                "processed_files": processed_count,
            }
        except Exception as e:
            logger.error(f"Processing error: {e}")
            if db:
                state = db.query(ProcessingState).first()
                if state:
                    state.status = ProcessingStatusEnum.ERROR
                    state.error_message = str(e)
                    state.completed_at = datetime.utcnow()
                    db.commit()
            return {
                "success": False,
                "total_files": 0,
                "matched_pairs": 0,
                "processed_files": 0,
                "error": str(e),
            }

    def get_processed_files(self) -> List[str]:
        """
        Get the list of processed files ready for upload.

        Returns:
            List of relative file paths
        """
        files = []
        # Raw downloads live under these directories and are not upload-ready
        skip_dirs = {self.all_raw_dir.name, self.ato_raw_dir.name}

        def walk_dir(directory: Path, prefix: str = ""):
            for item in directory.iterdir():
                relative_path = f"{prefix}/{item.name}" if prefix else item.name
                if item.is_dir():
                    if not prefix and item.name in skip_dirs:
                        continue  # Skip the top-level raw-download directories
                    walk_dir(item, relative_path)
                elif item.suffix == ".json":
                    files.append(relative_path)

        walk_dir(self.processed_files_dir)
        return files

    def get_file_content(self, filename: str) -> Optional[Dict]:
        """
        Get file content for preview.

        Args:
            filename: Relative filename

        Returns:
            File content, or None if missing or outside the processed directory
        """
        file_path = self.processed_files_dir / filename
        # Security: reject paths that escape the processed-files directory
        try:
            file_path.resolve().relative_to(self.processed_files_dir.resolve())
        except ValueError:
            logger.warning(f"Directory traversal attempt: {filename}")
            return None
        if not file_path.exists():
            return None
        return self.load_json_file(str(file_path))

    def get_file_size(self, filename: str) -> Optional[int]:
        """
        Get file size in bytes.

        Args:
            filename: Relative filename

        Returns:
            File size, or None if missing or outside the processed directory
        """
        file_path = self.processed_files_dir / filename
        # Same directory-traversal guard as get_file_content
        try:
            file_path.resolve().relative_to(self.processed_files_dir.resolve())
        except ValueError:
            return None
        if not file_path.exists():
            return None
        return file_path.stat().st_size
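

# Minimal usage sketch (illustrative only: "samfred2/ALL" comes from the
# docstring above, while the ATO repo ID is a hypothetical placeholder):
#
#   import asyncio
#
#   processor = FileProcessor(processed_files_dir="processed_files")
#   result = asyncio.run(
#       processor.process_datasets(
#           all_repo_id="samfred2/ALL",
#           ato_repo_id="samfred2/ATO",  # hypothetical ATO repo ID
#           hf_token=None,               # or an HF access token string
#           max_files=5,                 # 0 processes every matched pair
#       )
#   )
#   print(result)                        # {"success": True, ...}
#   print(processor.get_processed_files())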