|
|
|
|
|
""" |
|
|
Script to upload AdaFortiTran repository to Hugging Face. |
|
|
This script prepares the repository for Hugging Face upload with minimal changes. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import shutil |
|
|
import subprocess |
|
|
import sys |
|
|
from pathlib import Path |
|
|
from typing import List, Optional |
|
|
|
|
|
def check_huggingface_hub_installed() -> bool:
    """Report whether the ``huggingface_hub`` package can be imported.

    Returns:
        True when ``import huggingface_hub`` succeeds, False when it raises
        ImportError.
    """
    try:
        import huggingface_hub  # noqa: F401  (imported only to probe availability)
    except ImportError:
        return False
    else:
        return True
|
|
|
|
|
|
|
|
def create_huggingface_files(repo_path: Path):
    """Write the auxiliary files used for the Hugging Face upload.

    Two files are written into *repo_path*, each only when it does not exist
    yet (an existing file is never overwritten):

    * ``.gitattributes`` -- routes ``*.mat``, ``*.pth``, ``*.ckpt``, ``*.pt``
      and ``*.bin`` files through Git LFS.
    * ``.huggingfaceignore`` -- lists data splits, checkpoints, logs, caches,
      IDE and virtual-environment files to exclude from the upload.
      NOTE(review): git itself only honours .gitignore, so this file has no
      effect on a plain ``git add``/``git push`` -- confirm which tool reads it.

    Args:
        repo_path: Root directory of the repository.
    """
    lfs_rules = """*.mat filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
"""

    ignore_rules = """# Ignore large data files during upload
data/train/
data/val/
data/test/

# Ignore model checkpoints and logs
*.ckpt
*.pth
*.pt
logs/
runs/
checkpoints/

# Ignore temporary files
__pycache__/
*.pyc
*.pyo
*.pyd
.Python
*.so
.DS_Store
Thumbs.db

# Ignore IDE files
.vscode/
.idea/
*.swp
*.swo

# Ignore environment files
.env
.venv/
venv/
env/
"""

    pending = (
        (repo_path / ".gitattributes", lfs_rules),
        (repo_path / ".huggingfaceignore", ignore_rules),
    )
    for target, text in pending:
        if target.exists():
            continue  # never clobber a file the user already has
        with open(target, 'w') as handle:
            handle.write(text)
        print(f" Created {target}")
|
|
|
|
|
def create_model_card(repo_path: Path):
    """Replace ``README.md`` in *repo_path* with a Hugging Face model card.

    If a README already exists it is first copied to ``README_original.md``,
    but only when that backup does not exist yet, so re-running the script
    never overwrites the real original with a previously generated card.
    The card starts with a YAML metadata block (language, tags, license,
    datasets, metrics) followed by the markdown description.

    Args:
        repo_path: Root directory of the repository.
    """
    model_card_content = """---
language:
- en
tags:
- pytorch
- transformer
- channel-estimation
- ofdm
- wireless
- adaptive
license: mit
datasets:
- custom
metrics:
- mse
---

# AdaFortiTran: Adaptive Transformer Model for Robust OFDM Channel Estimation

## Model Description

AdaFortiTran is a novel adaptive transformer-based model for OFDM channel estimation that dynamically adapts to varying channel conditions (SNR, delay spread, Doppler shift). The model combines the power of transformer architectures with channel-aware adaptation mechanisms to achieve robust performance across diverse wireless environments.

## Key Features

- **🔄 Adaptive Architecture**: Dynamically adapts to channel conditions using meta-information
- **⚡ High Performance**: State-of-the-art results on OFDM channel estimation tasks
- **🧠 Transformer-Based**: Leverages attention mechanisms for long-range dependencies
- **🎯 Robust**: Maintains performance across varying SNR, delay spread, and Doppler conditions
- **🚀 Production Ready**: Comprehensive training pipeline with advanced features

## Architecture

The project implements three model variants:

1. **Linear Estimator**: Simple learned linear transformation baseline
2. **FortiTran**: Fixed transformer-based channel estimator
3. **AdaFortiTran**: Adaptive transformer with channel condition awareness

## Usage

### Installation

```bash
pip install -r requirements.txt
```

### Training

```bash
python src/main.py \\
    --model_name adafortitran \\
    --system_config_path config/system_config.yaml \\
    --model_config_path config/adafortitran.yaml \\
    --train_set data/train \\
    --val_set data/val \\
    --test_set data/test \\
    --exp_id my_experiment
```

## Citation

If you use this model in your research, please cite:

```bibtex
@misc{guler2025adafortitranadaptivetransformermodel,
  title={AdaFortiTran: An Adaptive Transformer Model for Robust OFDM Channel Estimation},
  author={Berkay Guler and Hamid Jafarkhani},
  year={2025},
  eprint={2505.09076},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2505.09076},
}
```

## License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
"""

    model_card_path = repo_path / "README.md"
    if model_card_path.exists():
        backup_path = repo_path / "README_original.md"
        # Keep the very first backup: a later run would otherwise back up the
        # generated card and lose the user's real README.
        if not backup_path.exists():
            shutil.copy2(model_card_path, backup_path)
            print(f" Backed up original README to {backup_path}")

    # Explicit UTF-8: the card contains non-ASCII characters, which the
    # platform default encoding (e.g. cp1252 on Windows) cannot always encode.
    with open(model_card_path, 'w', encoding='utf-8') as f:
        f.write(model_card_content)
    print(f" Updated {model_card_path} for Hugging Face")
|
|
|
|
|
def cleanup_generated_files(repo_path: Path):
    """Undo the changes made to *repo_path* for the Hugging Face upload.

    * Deletes ``.gitattributes`` and ``.huggingfaceignore``, but only when
      their content starts with the header this script writes. A file the
      user already had (which create_huggingface_files() skips rather than
      overwrites) is therefore kept instead of being destroyed.
    * Restores ``README.md`` from ``README_original.md`` and removes the backup.
    * Removes the git ``origin`` remote if it points at huggingface.co.

    Args:
        repo_path: Root directory of the repository.
    """
    print("\nCleaning up generated files...")

    # First line of each file as written by create_huggingface_files(); used
    # to tell generated files apart from pre-existing user files.
    generated_markers = {
        ".gitattributes": "*.mat filter=lfs diff=lfs merge=lfs -text",
        ".huggingfaceignore": "# Ignore large data files during upload",
    }

    for file_name, marker in generated_markers.items():
        file_path = repo_path / file_name
        if not file_path.exists():
            continue
        try:
            text = file_path.read_text(encoding="utf-8", errors="replace")
            looks_generated = text.startswith(marker)
        except OSError:
            looks_generated = False
        if not looks_generated:
            print(f" Kept {file_path} (not generated by this script)")
            continue
        file_path.unlink()
        print(f" Removed {file_path}")

    backup_path = repo_path / "README_original.md"
    readme_path = repo_path / "README.md"
    if backup_path.exists():
        # Move the backup over the generated model card in one step.
        backup_path.replace(readme_path)
        print(" Restored original README.md")

    # Remove a Hugging Face `origin` remote left behind by an upload. Git runs
    # inside repo_path rather than whatever the process cwd happens to be.
    try:
        result = subprocess.run(["git", "remote", "get-url", "origin"],
                                capture_output=True, text=True, check=False,
                                cwd=repo_path)
        if "huggingface.co" in result.stdout:
            removal = subprocess.run(["git", "remote", "remove", "origin"],
                                     check=False, cwd=repo_path)
            if removal.returncode == 0:
                print(" Removed Hugging Face remote")
    except OSError as err:
        # git is missing or not executable: nothing to clean up, but say so.
        print(f" Skipped git remote cleanup: {err}")

    print("✅ Cleanup completed")
|
|
|
|
|
def upload_to_huggingface(repo_path: Path, repo_name: str, private: bool = False): |
|
|
"""Upload the repository to Hugging Face.""" |
|
|
|
|
|
try: |
|
|
from huggingface_hub import HfApi, create_repo |
|
|
|
|
|
api = HfApi() |
|
|
|
|
|
|
|
|
try: |
|
|
user_info = api.whoami() |
|
|
username = user_info['name'] |
|
|
print(f"β
Logged in as: {username}") |
|
|
except Exception: |
|
|
print("β Not logged in to Hugging Face") |
|
|
print("Please run: huggingface-cli login") |
|
|
return False |
|
|
|
|
|
|
|
|
repo_id = f"{username}/{repo_name}" |
|
|
print(f"Creating repository: {repo_id}") |
|
|
|
|
|
try: |
|
|
create_repo( |
|
|
repo_id=repo_id, |
|
|
repo_type="model", |
|
|
private=private, |
|
|
exist_ok=True |
|
|
) |
|
|
print(f"β
Repository created/updated: {repo_id}") |
|
|
except Exception as e: |
|
|
print(f"β Failed to create repository: {e}") |
|
|
return False |
|
|
|
|
|
|
|
|
print("Uploading files to Hugging Face...") |
|
|
|
|
|
|
|
|
os.chdir(repo_path) |
|
|
|
|
|
|
|
|
if not (repo_path / ".git").exists(): |
|
|
subprocess.run(["git", "init"], check=True) |
|
|
print(" Initialized git repository") |
|
|
|
|
|
|
|
|
subprocess.run(["git", "add", "."], check=True) |
|
|
print(" Added files to git") |
|
|
|
|
|
|
|
|
result = subprocess.run(["git", "status", "--porcelain"], |
|
|
capture_output=True, text=True, check=True) |
|
|
|
|
|
if result.stdout.strip(): |
|
|
|
|
|
subprocess.run(["git", "commit", "-m", "Initial commit for Hugging Face"], check=True) |
|
|
print(" Committed changes") |
|
|
else: |
|
|
|
|
|
print(" No changes to commit (working tree clean)") |
|
|
|
|
|
|
|
|
try: |
|
|
from huggingface_hub import HfFolder |
|
|
token = HfFolder.get_token() |
|
|
if not token: |
|
|
print("β No Hugging Face token found") |
|
|
print("Please run: huggingface-cli login") |
|
|
return False |
|
|
except Exception: |
|
|
print("β Failed to get Hugging Face token") |
|
|
print("Please run: huggingface-cli login") |
|
|
return False |
|
|
|
|
|
|
|
|
remote_url = f"https://{username}:{token}@huggingface.co/{repo_id}" |
|
|
|
|
|
|
|
|
result = subprocess.run(["git", "remote", "get-url", "origin"], |
|
|
capture_output=True, text=True, check=False) |
|
|
|
|
|
if result.returncode == 0: |
|
|
|
|
|
subprocess.run(["git", "remote", "set-url", "origin", remote_url], check=True) |
|
|
print(f" Updated remote: {repo_id}") |
|
|
else: |
|
|
|
|
|
subprocess.run(["git", "remote", "add", "origin", remote_url], check=True) |
|
|
print(f" Added remote: {repo_id}") |
|
|
|
|
|
|
|
|
result = subprocess.run(["git", "branch", "--show-current"], |
|
|
capture_output=True, text=True, check=True) |
|
|
current_branch = result.stdout.strip() |
|
|
|
|
|
|
|
|
subprocess.run(["git", "push", "-u", "origin", current_branch], check=True) |
|
|
print(f" Pushed to Hugging Face (branch: {current_branch})") |
|
|
|
|
|
print(f"\nπ Successfully uploaded to: https://huggingface.co/{repo_id}") |
|
|
return True |
|
|
|
|
|
except ImportError: |
|
|
print("β huggingface_hub not available") |
|
|
return False |
|
|
except Exception as e: |
|
|
print(f"β Upload failed: {e}") |
|
|
return False |
|
|
|
|
|
def main():
    """Interactively prepare, upload and optionally clean up the repository.

    Must be run from the AdaFortiTran root (detected via
    ``src/models/adafortitran.py``). Prompts for the Hub repository name and
    visibility, writes the upload helper files and the model card, uploads,
    and then offers to restore the original state. The offer is made after a
    failed upload too, because the README has already been replaced by then.
    """
    repo_path = Path.cwd()
    print("Preparing AdaFortiTran repository for Hugging Face upload")
    print(f"Repository path: {repo_path.absolute()}")
    print("=" * 60)

    if not (repo_path / "src" / "models" / "adafortitran.py").exists():
        print("❌ Error: Please run this script from the AdaFortiTran root directory")
        return

    if not check_huggingface_hub_installed():
        print("❌ huggingface_hub is not installed")
        print("Please install it manually: pip install huggingface_hub")
        return

    repo_name = input("Enter repository name for Hugging Face (default: adafortitran): ").strip()
    if not repo_name:
        repo_name = "adafortitran"

    private_input = input("Make repository private? (y/N): ").strip().lower()
    private = private_input in ('y', 'yes')

    print("\nPreparing repository...")
    create_huggingface_files(repo_path)
    create_model_card(repo_path)

    print("\nUploading to Hugging Face...")
    if upload_to_huggingface(repo_path, repo_name, private):
        # The full https://huggingface.co/<user>/<repo> URL has already been
        # printed by upload_to_huggingface(); the username is not known here,
        # so a URL built from repo_name alone would point to the wrong place.
        print("\n✅ Upload completed successfully!")
    else:
        print("\n❌ Upload failed. Please check the error messages above.")

    # Offer cleanup regardless of outcome: README.md was already overwritten.
    cleanup_input = input("\nRemove generated files and restore original state? (Y/n): ").strip().lower()
    if cleanup_input not in ('n', 'no'):
        cleanup_generated_files(repo_path)
    else:
        print("Generated files kept for future uploads")
|
|
|
|
|
if __name__ == '__main__':
    # Run the interactive upload only when executed as a script, not on import.
    main()