"""
Netlistify training on Hugging Face Spaces with ZeroGPU.

This file is the main entry point of the Hugging Face Space.
"""

import importlib
import os
import re
import shutil
import subprocess
import sys
import traceback
import zipfile
from pathlib import Path
from typing import Optional

import gradio as gr
import spaces  # keep ahead of torch, as in the original import order
import torch


def get_hf_token():
    """Return a Hugging Face access token, or ``None`` if none can be found.

    Sources are tried in this order; the first non-empty value wins:

    1. Environment variable ``HF_TOKEN``, then ``HUGGING_FACE_HUB_TOKEN``.
    2. The token known to ``huggingface_hub`` (``get_token()``, falling back
       to the legacy ``HfFolder.get_token()``), if the package is importable.
    3. A token file: the legacy ``~/.huggingface/token`` first, then
       ``$HF_HOME/token`` (``~/.cache/huggingface/token`` when ``HF_HOME`` is
       unset).

    Every lookup is best effort: a source that fails is skipped, never raised.
    Surrounding whitespace is stripped, so a secret pasted with a trailing
    newline still yields a usable token.

    Returns:
        The token string, or ``None`` when no source provides one.
    """
    # 1. Environment variables (this is how Space secrets are exposed).
    for env_name in ("HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"):
        token = (os.getenv(env_name) or "").strip()
        if token:
            return token

    # 2. Ask huggingface_hub itself. Imported lazily so a missing or broken
    #    installation only disables this source.
    try:
        import huggingface_hub

        lookup = getattr(huggingface_hub, "get_token", None)
        if lookup is None:
            lookup = huggingface_hub.HfFolder.get_token
        token = (lookup() or "").strip()
        if token:
            return token
    except Exception:
        # Deliberate best effort: fall through to the token files.
        pass

    # 3. Token files on disk.
    try:
        home = Path.home()
    except RuntimeError:
        # The home directory cannot be resolved in this environment.
        return None

    hf_home = os.getenv("HF_HOME")
    candidates = (
        home / ".huggingface" / "token",
        (Path(hf_home) if hf_home else home / ".cache" / "huggingface") / "token",
    )
    for token_file in candidates:
        try:
            token = token_file.read_text(encoding="utf-8").strip()
        except (OSError, UnicodeDecodeError):
            continue
        if token:
            return token

    return None


# Token resolved once at import time. The training entry point resolves the
# token again on every run, so this value is only a module-level snapshot.
HF_TOKEN = get_hf_token()


# ---------------------------------------------------------------------------
# Training pipeline: private helpers first, public entry point at the end.
# ---------------------------------------------------------------------------

# Fixed working locations inside the Space container.
_DATASET_DOWNLOAD_DIR = "/tmp/netlistify_dataset"
_TRAIN_DIR = Path("/tmp/netlistify_train")
_NETLISTIFY_DIR = Path("/tmp/Netlistify")
_MODEL_OUTPUT_DIR = Path("/tmp/models")
_NETLISTIFY_REPO_URL = "https://github.com/NYCU-AI-EDA/Netlistify.git"

# Archives that may ship with the dataset, mapped to their target directory.
_DATASET_ARCHIVES = {
    "images.zip": "images",
    "components.zip": "components",
    "pkl.zip": "pkl",
}

# Label rows whose class id lies outside 0.._MAX_LABEL_CLASS are dropped.
_MAX_LABEL_CLASS = 11

_PROTOBUF_REQUIREMENT = "protobuf>=3.20.0,<5.0.0"

# (pip distribution name, importable module name) pairs checked before training.
_REQUIRED_PACKAGES = (
    ("einops", "einops"),
    ("transformers", "transformers"),
    ("timm", "timm"),
    ("aenum", "aenum"),
    ("ipyplot", "ipyplot"),
    ("ipython", "IPython"),
    ("networkx", "networkx"),
    ("pandas", "pandas"),
    ("p-tqdm", "p_tqdm"),
    ("plotly", "plotly"),
    ("natsort", "natsort"),
    ("numba", "numba"),
    ("rich", "rich"),
    ("scoping", "scoping"),
    ("tabulate", "tabulate"),
    ("torchinfo", "torchinfo"),
    ("torchmetrics", "torchmetrics"),
    ("scikit-learn", "sklearn"),
    ("wandb", "wandb"),
    ("seaborn", "seaborn"),
    ("pypalettes", "pypalettes"),
    ("tensorboard", "tensorboard"),
    ("tensorboardx", "tensorboardX"),
    ("protobuf", "google.protobuf"),
    ("aiohttp", "aiohttp"),
    ("bokeh", "bokeh"),
    ("diffusers", "diffusers"),
    ("levenshtein", "Levenshtein"),
    ("ortools", "ortools"),
    ("peft", "peft"),
    ("pytesseract", "pytesseract"),
    ("bytecode", "bytecode"),
    ("ftfy", "ftfy"),
    ("imagesize", "imagesize"),
    ("importlib-resources", "importlib_resources"),
    ("typing-utils", "typing_utils"),
    # universal-pathlib installs the module "upath".
    ("universal-pathlib", "upath"),
    # ipython-genutils installs the module "ipython_genutils".
    ("ipython-genutils", "ipython_genutils"),
    ("cached-property", "cached_property"),
)

# Text used to recognise an already patched slice.py and to build the patch.
_SLICE_PATCH_MARKER = "if cls not in class_label_real:"
_SLICE_GUARD_COMMENT = "# Überspringe ungültige Klassen"

# `if config == DatasetConfig.REAL:` directly followed by the "text" lookup.
_SLICE_REAL_TEXT_CHECK = re.compile(
    r'^(?P<outer>[ \t]*)if config == DatasetConfig\.REAL:[ \t]*\n'
    r'(?P<inner>[ \t]+)if class_label_real\[cls\] == "text":',
    re.MULTILINE,
)

# Any `if config == DatasetConfig.REAL:` line plus the indent of its body.
_SLICE_REAL_CHECK = re.compile(
    r'^(?P<head>[ \t]*if config == DatasetConfig\.REAL:[ \t]*\n)'
    r'(?P<inner>[ \t]+)(?=\S)',
    re.MULTILINE,
)


def _missing_token_message() -> str:
    """Build the help text shown when no Hugging Face token is available."""
    debug_info = "\n".join(
        "- {}: {}".format(
            name,
            "✅ gesetzt" if (os.getenv(name) or "").strip() else "❌ nicht gesetzt",
        )
        for name in ("HF_TOKEN", "HUGGING_FACE_HUB_TOKEN")
    )
    return "\n".join((
        "❌ Fehler: HF_TOKEN nicht gefunden.",
        "",
        "**Prüfe folgendes:**",
        "",
        "1. **Space Settings → Secrets:**",
        "   - Name muss exakt sein: `HF_TOKEN` (großgeschrieben, kein Leerzeichen)",
        "   - Value: Dein Hugging Face Token (beginnt mit `hf_...`)",
        '   - Klicke auf "Save" nach dem Hinzufügen',
        "",
        "2. **Space neu starten:**",
        "   - Nach dem Hinzufügen des Secrets: Settings → Restart Space",
        '   - Warte bis Status "Running" ist',
        "",
        "3. **Alternative Secret-Namen:**",
        "   - Falls `HF_TOKEN` nicht funktioniert, versuche: `HUGGING_FACE_HUB_TOKEN`",
        "",
        "**Debug-Info (verfügbare Environment Variables):**",
        debug_info,
        "",
        "**Hinweis:** Secrets sind erst nach einem Space-Neustart verfügbar!",
    ))


def _access_denied_message(dataset_repo_id: str, error_msg: str) -> str:
    """Build the help text for a gated / unauthorised dataset repository."""
    return "\n".join((
        "❌ Dataset-Zugriff verweigert (401 / Gated Repository)",
        "",
        "Das Dataset ist zugriffsbeschränkt. Bitte folge diesen Schritten:",
        "",
        # Link to the repository that was actually requested.
        f"1. Gehe zu: https://huggingface.co/datasets/{dataset_repo_id}",
        '2. Klicke auf: "Agree and access repository" oder "Accept terms"',
        "3. Warte bis Zugriff gewährt wird (einige Sekunden)",
        "4. Prüfe Token in Space Settings → Secrets → HF_TOKEN",
        "5. Starte Space neu (Settings → Restart Space)",
        "6. Versuche Training erneut",
        "",
        f"Fehlerdetails: {error_msg}",
        "",
        "Hinweis: Du musst eingeloggt sein und die Terms akzeptieren!",
    ))


def _download_dataset(dataset_repo_id: str, progress) -> tuple:
    """Download the dataset snapshot from the Hugging Face Hub.

    Returns:
        ``(dataset_path, None)`` on success or ``(None, message)`` where
        ``message`` is the user-facing error text.
    """
    try:
        from huggingface_hub import snapshot_download

        progress(0.1, desc="📥 Lade Dataset von Hugging Face...")

        hf_token = get_hf_token()
        if not hf_token:
            return None, _missing_token_message()

        progress(0.12, desc="Authentifiziere mit Token...")
        dataset_path = snapshot_download(
            repo_id=dataset_repo_id,
            repo_type="dataset",
            local_dir=_DATASET_DOWNLOAD_DIR,
            token=hf_token,
        )
        progress(0.15, desc=f"✅ Dataset geladen: {dataset_path}")
        return dataset_path, None
    except Exception as e:
        error_msg = str(e)
        lowered = error_msg.lower()
        if "401" in error_msg or "gated" in lowered or "restricted" in lowered:
            return None, _access_denied_message(dataset_repo_id, error_msg)
        return None, (
            f"❌ Fehler beim Laden des Datasets: {error_msg}\n\n"
            "Stelle sicher, dass:\n"
            "- Das Dataset auf Hugging Face hochgeladen ist\n"
            "- Die Repository-ID korrekt ist\n"
            "- Du Zugriff auf das Dataset hast"
        )


def _extract_dataset_archives(dataset_base: Path, progress) -> None:
    """Unpack images.zip / components.zip / pkl.zip if the dataset ships them.

    An archive whose entries already start with the target directory name is
    extracted into ``dataset_base``; otherwise into ``dataset_base/<target>``.
    A failing archive is reported through ``progress`` and skipped.
    """
    for zip_name, target_name in _DATASET_ARCHIVES.items():
        zip_path = dataset_base / zip_name
        if not zip_path.exists():
            continue

        progress(0.21, desc=f"📦 Entpacke {zip_name}...")
        extract_to = dataset_base / target_name
        extract_to.mkdir(exist_ok=True)
        try:
            with zipfile.ZipFile(zip_path, "r") as archive:
                # Only the first entries are inspected to detect the layout.
                has_nested = any(
                    "/" in member and member.split("/")[0] == target_name
                    for member in archive.namelist()[:10]
                )
                archive.extractall(dataset_base if has_nested else extract_to)
            progress(0.22, desc=f"✅ {zip_name} entpackt")
        except Exception as e:
            progress(0.22, desc=f"⚠️ Fehler beim Entpacken von {zip_name}: {e}")


def _describe_dataset_dir(dataset_base: Path) -> list:
    """List the top-level dataset entries for the 'structure not recognised' text."""
    lines = [f"Dataset-Pfad: {dataset_base}", "Verfügbare Einträge:"]
    try:
        for item in sorted(dataset_base.iterdir()):
            if item.is_dir():
                lines.append(f"  📁 {item.name}/")
                try:
                    children = list(item.iterdir())
                except OSError:
                    continue
                lines.extend(f"    - {child.name}" for child in children[:3])
                if len(children) > 3:
                    lines.append(f"    ... ({len(children) - 3} weitere)")
            elif item.is_file():
                size_mb = item.stat().st_size / 1024 / 1024
                lines.append(f"  📄 {item.name} ({size_mb:.1f} MB)")
    except Exception as e:
        lines.append(f"  Fehler beim Auflisten: {e}")
    return lines


def _unknown_structure_message(dataset_base: Path) -> str:
    """Build the error text for a dataset whose layout is not recognised."""
    return "\n".join((
        "❌ Dataset-Struktur nicht erkannt.",
        "",
        "**Erwartet:** images/, labels/ (oder components/), pkl/",
        "",
        "**Gefunden:**",
        "\n".join(_describe_dataset_dir(dataset_base)),
        "",
        "**Mögliche Lösungen:**",
        "1. Dataset muss entpackt sein oder ZIP-Dateien "
        "(images.zip, components.zip, pkl.zip) enthalten",
        "2. Struktur sollte sein:",
        "   - images/ (oder images/images/)",
        "   - labels/ oder components/ (oder components/components/)",
        "   - pkl/",
        "3. Prüfe ob ZIP-Dateien automatisch entpackt wurden",
    ))


def _first_existing(*candidates: Path) -> Optional[Path]:
    """Return the first candidate path that exists, or ``None``."""
    for candidate in candidates:
        if candidate.exists():
            return candidate
    return None


def _locate_dataset_dirs(dataset_base: Path) -> tuple:
    """Find the image, label and pkl directories of the downloaded dataset.

    Supported layouts: ``images/images`` (nested archive), ``images`` and a
    flat directory of ``*.jpg`` files.

    Returns:
        ``(images_dir, labels_dir, pkl_dir)``; each entry is ``None`` when
        the corresponding directory does not exist.
    """
    nested_images = dataset_base / "images" / "images"
    plain_images = dataset_base / "images"

    if nested_images.exists():
        images_dir = nested_images
        labels_dir = _first_existing(
            dataset_base / "components" / "components",
            dataset_base / "components",
        )
        pkl_dir = _first_existing(dataset_base / "pkl")
    elif plain_images.exists():
        images_dir = plain_images
        labels_dir = _first_existing(
            dataset_base / "labels",
            dataset_base / "components",
        )
        pkl_dir = _first_existing(dataset_base / "pkl")
    elif any(dataset_base.glob("*.jpg")):
        # Flat layout: images, labels and pkl files share one directory.
        images_dir = labels_dir = pkl_dir = dataset_base
    else:
        images_dir = labels_dir = pkl_dir = None

    return images_dir, labels_dir, pkl_dir


def _filter_invalid_labels(label_path: Path, max_class: int = _MAX_LABEL_CLASS) -> bool:
    """Drop label rows with an invalid class id and rewrite the file in place.

    A row is kept only when it has at least five whitespace-separated fields
    and its first field is an integer in ``0..max_class``.

    Returns:
        ``True`` if the file was rewritten, ``False`` if it was left untouched
        (nothing to remove, or the file could not be read or written).
    """
    try:
        with open(label_path, "r", encoding="utf-8") as f:
            lines = f.readlines()

        kept = []
        for raw in lines:
            row = raw.strip()
            if not row:
                continue
            parts = row.split()
            if len(parts) < 5:
                continue
            try:
                cls = int(parts[0])
            except ValueError:
                # An unparsable class id invalidates this row only.
                continue
            if 0 <= cls <= max_class:
                kept.append(row + "\n")

        if len(kept) == len(lines):
            return False

        with open(label_path, "w", encoding="utf-8") as f:
            f.writelines(kept)
        return True
    except (OSError, UnicodeError):
        return False


def _copy_training_files(img_files, labels_dir, pkl_dir, train_dir: Path, progress) -> None:
    """Copy images plus matching label / pkl files into the training layout.

    Creates ``images/``, ``labels/`` and ``pkl/`` below ``train_dir``. Copied
    label files are filtered with :func:`_filter_invalid_labels`.
    ``img_files`` must not be empty.
    """
    train_images = train_dir / "images"
    train_labels = train_dir / "labels"
    train_pkl = train_dir / "pkl"
    for directory in (train_images, train_labels, train_pkl):
        directory.mkdir(parents=True, exist_ok=True)

    total = len(img_files)
    for index, img_file in enumerate(img_files):
        if index % 100 == 0:
            progress(
                0.25 + (index / total) * 0.1,
                desc=f"Kopiere Bilder: {index}/{total}",
            )
        shutil.copy2(img_file, train_images / img_file.name)

        if labels_dir is not None:
            label_file = labels_dir / (img_file.stem + ".txt")
            if label_file.exists():
                dest_label = train_labels / label_file.name
                shutil.copy2(label_file, dest_label)
                _filter_invalid_labels(dest_label, max_class=_MAX_LABEL_CLASS)

        if pkl_dir is not None:
            pkl_file = pkl_dir / (img_file.stem + ".pkl")
            if pkl_file.exists():
                shutil.copy2(pkl_file, train_pkl / pkl_file.name)


def _pip_install(requirement: str) -> bool:
    """Install or upgrade one requirement with pip; return ``True`` on success."""
    try:
        result = subprocess.run(
            [sys.executable, "-m", "pip", "install", requirement, "--quiet", "--upgrade"],
            capture_output=True,
            text=True,
            timeout=120,
        )
    except (OSError, subprocess.SubprocessError):
        return False
    # Let the running interpreter see freshly installed packages.
    importlib.invalidate_caches()
    return result.returncode == 0


def _is_importable(module_name: str) -> bool:
    """Return ``True`` if ``module_name`` can be imported in this process."""
    try:
        importlib.import_module(module_name)
    except ImportError:
        return False
    return True


def _apply_protobuf_workaround() -> None:
    """Ensure ``api_implementation`` has a ``_c_module`` attribute (best effort).

    NOTE(review): the original comment explaining this shim is missing;
    presumably a dependency reads ``_c_module`` on protobuf builds that do not
    define it — confirm before removing.
    """
    try:
        import google.protobuf.internal.api_implementation as api_impl
    except Exception:
        return
    if not hasattr(api_impl, "_c_module"):
        api_impl._c_module = None


def _ensure_protobuf(progress) -> None:
    """Install protobuf if missing; reinstall if outside ``>=3.20,<5``."""
    try:
        import google.protobuf
    except ImportError:
        progress(0.455, desc="📦 Installiere protobuf...")
        _pip_install(_PROTOBUF_REQUIREMENT)
        return

    try:
        major, minor = map(int, google.protobuf.__version__.split(".")[:2])
    except (AttributeError, ValueError):
        # Version cannot be determined; leave the installation alone.
        return
    if major < 3 or (major == 3 and minor < 20) or major >= 5:
        progress(0.455, desc="📦 Upgrade protobuf für Kompatibilität...")
        _pip_install(_PROTOBUF_REQUIREMENT)


def _ensure_python_packages(progress) -> None:
    """Install every package from ``_REQUIRED_PACKAGES`` that cannot be imported."""
    missing = [
        pip_name
        for pip_name, import_name in _REQUIRED_PACKAGES
        if not _is_importable(import_name)
    ]
    if not missing:
        return

    progress(0.46, desc=f"📦 Installiere {len(missing)} fehlende Pakete...")
    failed = []
    for index, package in enumerate(missing):
        progress(
            0.46 + (index / len(missing)) * 0.02,
            desc=f"📦 Installiere {package} ({index + 1}/{len(missing)})...",
        )
        if not _pip_install(package):
            failed.append(package)

    if failed:
        progress(
            0.48,
            desc=f"⚠️ {len(failed)} Pakete konnten nicht installiert werden: "
                 f"{', '.join(failed)}",
        )
    else:
        progress(0.48, desc=f"✅ Alle {len(missing)} Pakete installiert")


def _clone_failed_message(error_msg: str) -> str:
    """Build the error text for a failed ``git clone`` of Netlistify."""
    return "\n".join((
        "❌ Fehler beim Klonen von Netlistify:",
        "",
        "**Git-Output:**",
        error_msg,
        "",
        "**Mögliche Lösungen:**",
        "1. Prüfe Internet-Verbindung",
        "2. Prüfe ob GitHub erreichbar ist",
        "3. Versuche Training erneut (Repository wird beim nächsten Versuch geklont)",
    ))


def _ensure_netlistify_checkout(netlistify_dir: Path, progress) -> Optional[str]:
    """Clone Netlistify unless ``main_config.py`` is already present.

    Returns:
        ``None`` on success, otherwise the user-facing error text.
    """
    if (netlistify_dir / "main_config.py").exists():
        return None

    progress(0.485, desc="📥 Klone Netlistify von GitHub...")

    # A leftover directory (empty or a partial clone) makes `git clone` fail,
    # so remove it first; otherwise a retry could never succeed.
    if netlistify_dir.exists():
        shutil.rmtree(netlistify_dir, ignore_errors=True)

    try:
        result = subprocess.run(
            ["git", "clone", _NETLISTIFY_REPO_URL, str(netlistify_dir)],
            capture_output=True,
            text=True,
            timeout=300,
        )
    except (OSError, subprocess.SubprocessError) as e:
        # git not installed, or the clone timed out.
        return _clone_failed_message(str(e))

    if result.returncode != 0:
        return _clone_failed_message(
            result.stderr or result.stdout or "Unbekannter Fehler"
        )

    progress(0.49, desc="✅ Netlistify geklont")
    return None


def _guard_before_text_check(match) -> str:
    """Insert the class-id guard in front of the ``== "text"`` lookup."""
    outer = match.group("outer")
    inner = match.group("inner")
    return (
        f"{outer}if config == DatasetConfig.REAL:\n"
        f"{inner}{_SLICE_PATCH_MARKER}\n"
        f"{inner}    continue  {_SLICE_GUARD_COMMENT}\n"
        f'{inner}if class_label_real[cls] == "text":'
    )


def _guard_after_real_check(match) -> str:
    """Insert the class-id guard as first statement of a REAL-config branch."""
    head = match.group("head")
    inner = match.group("inner")
    return (
        f"{head}"
        f"{inner}{_SLICE_PATCH_MARKER}\n"
        f"{inner}    continue  {_SLICE_GUARD_COMMENT}\n"
        f"{inner}"
    )


def _patch_slice_module(slice_file: Path, progress) -> None:
    """Make Netlistify's ``slice.py`` skip class ids missing from ``class_label_real``.

    The patch is applied at most once (detected via ``_SLICE_PATCH_MARKER``)
    and is written only if the patched source still compiles. The upstream
    file is never left in a broken state; failures are reported through
    ``progress``.
    """
    if not slice_file.exists():
        return

    try:
        original = slice_file.read_text(encoding="utf-8")
        if _SLICE_PATCH_MARKER in original:
            return

        patched, hits = _SLICE_REAL_TEXT_CHECK.subn(_guard_before_text_check, original)
        if hits == 0:
            # Fallback for a different upstream layout: tolerate unknown ids in
            # every lookup and guard each REAL-config branch.
            patched = original.replace(
                "class_label_real[cls]", "class_label_real.get(cls, None)"
            )
            patched = _SLICE_REAL_CHECK.sub(_guard_after_real_check, patched)

        if patched == original:
            return

        # Raises SyntaxError if, for example, `continue` ended up outside a loop.
        compile(patched, str(slice_file), "exec")
        slice_file.write_text(patched, encoding="utf-8")
    except (OSError, UnicodeError, SyntaxError, ValueError) as e:
        progress(0.492, desc=f"⚠️ slice.py konnte nicht gepatcht werden: {e}")
        return

    progress(0.492, desc="🔧 slice.py gepatcht (ungültige Klassen werden übersprungen)")


def _collect_checkpoints(runs_dir: Path, model_dir: Path) -> Optional[Path]:
    """Copy the checkpoints of the most recently modified run into ``model_dir``.

    ``best_train.pth`` is copied to ``best_model.pth`` and ``latest.pth`` to
    ``latest_model.pth``.

    Returns:
        Path of the copied best model, or ``None`` if none was found.
    """
    best_copy = None
    if not runs_dir.is_dir():
        return best_copy

    try:
        run_dirs = [entry for entry in runs_dir.iterdir() if entry.is_dir()]
        if run_dirs:
            latest_run = max(run_dirs, key=lambda entry: entry.stat().st_mtime)

            best_model = latest_run / "best_train.pth"
            if best_model.exists():
                target = model_dir / "best_model.pth"
                shutil.copy2(best_model, target)
                best_copy = target

            latest_model = latest_run / "latest.pth"
            if latest_model.exists():
                shutil.copy2(latest_model, model_dir / "latest_model.pth")
    except OSError:
        # Best effort: the report states whether a model was saved.
        pass
    return best_copy


def _build_status_report(
    *,
    training_error,
    training_completed,
    gpu_name,
    gpu_memory,
    epochs,
    completed_epochs,
    batch_size,
    learning_rate,
    image_count,
    model_dir,
    best_model_path,
    runs_dir,
) -> str:
    """Assemble the final Markdown-style status text shown in the UI."""
    model_saved = best_model_path is not None
    lines = []

    if training_error:
        lines.append("❌ **Training mit Fehler beendet:**")
        lines.append(f"```\n{training_error}\n```")
    elif training_completed:
        lines.append("✅ **Training erfolgreich abgeschlossen!**")
    else:
        lines.append("⚠️ **Training-Status unklar**")

    lines.extend((
        "",
        "📊 **Training-Details:**",
        f"- GPU: {gpu_name} ({gpu_memory:.1f} GB)",
        f"- Geplante Epochs: {epochs}",
        f"- Abgeschlossene Epochs: {completed_epochs}",
        f"- Batch Size: {batch_size}",
        f"- Learning Rate: {learning_rate}",
        f"- Dataset-Größe: {image_count} Bilder",
        "",
    ))

    if model_saved:
        lines.append("💾 **Modell gespeichert:**")
        lines.append(f"- Pfad: {model_dir}")
        lines.append("- Bestes Modell: best_model.pth")
        if best_model_path.exists():
            file_size = best_model_path.stat().st_size / (1024 * 1024)
            lines.append(f"- Dateigröße: {file_size:.2f} MB")
    else:
        lines.append("⚠️ **Modell nicht gefunden:**")
        lines.append(f"- Erwarteter Pfad: {runs_dir}")
        lines.append("- Prüfe Logs für Details")

    lines.append("")

    if training_completed and model_saved:
        lines.append("📁 **Nächste Schritte:**")
        lines.append("1. Lade das trainierte Modell herunter")
        lines.append("2. Verwende es für Inference in deiner Anwendung")
    elif not training_completed:
        lines.append(
            "⚠️ **Hinweis:** Training wurde möglicherweise nicht vollständig abgeschlossen."
        )
        lines.append("- Prüfe die Logs für weitere Details")
        lines.append("- Versuche Training erneut zu starten")

    return "\n".join(lines)


def _train_and_report(
    train_dir: Path,
    image_count: int,
    epochs: int,
    batch_size: int,
    learning_rate: float,
    gpu_name: str,
    gpu_memory: float,
    progress,
) -> str:
    """Prepare the Netlistify code base, run the training, return the report.

    Exceptions raised outside ``model.fit`` propagate to the caller.
    """
    _apply_protobuf_workaround()
    _ensure_protobuf(progress)
    _ensure_python_packages(progress)

    netlistify_dir = _NETLISTIFY_DIR
    main_config_file = netlistify_dir / "main_config.py"

    clone_error = _ensure_netlistify_checkout(netlistify_dir, progress)
    if clone_error:
        return clone_error

    _patch_slice_module(netlistify_dir / "slice.py", progress)

    if not main_config_file.exists():
        return "\n".join((
            "❌ Netlistify-Repository unvollständig.",
            "",
            f"**Erwartet:** {main_config_file}",
            "**Gefunden:** Verzeichnis existiert, aber main_config.py fehlt",
            "",
            "**Mögliche Lösungen:**",
            "1. Prüfe ob Repository korrekt geklont wurde",
            "2. Prüfe ob main_config.py im Repository existiert",
            "3. Versuche Training erneut",
        ))

    # Netlistify is a plain source tree, so it is put on sys.path.
    netlistify_str = str(netlistify_dir)
    if netlistify_str not in sys.path:
        sys.path.insert(0, netlistify_str)

    progress(0.495, desc="📦 Importiere Netlistify-Module...")

    try:
        import main_config
    except ImportError as e:
        return "\n".join((
            "❌ Fehler beim Import von main_config:",
            "",
            f"**Fehler:** {e}",
            f"**Python-Pfad:** {sys.path[:3]}",
            f"**Netlistify-Verzeichnis:** {netlistify_dir}",
            f"**main_config.py existiert:** {main_config_file.exists()}",
            "",
            "**Mögliche Lösungen:**",
            "1. Prüfe ob Netlistify korrekt geklont wurde",
            "2. Prüfe ob alle Abhängigkeiten installiert sind",
            "3. Versuche Training erneut",
        ))

    # NOTE(review): as in the original, `main` is imported before the
    # main_config overrides below. If `main` copies config values at import
    # time, the overrides do not reach it — verify against Netlistify.
    from main import (
        FormalDatasetWindowedLinePair,
        create_model,
        criterion,
        eval_metrics,
        xtransform,
        ytransform,
    )
    from Model import Model

    # Fail early if the (possibly patched) slice.py cannot be imported.
    importlib.import_module("slice")

    main_config.REAL_DATA = True
    main_config.DATASET_PATH = str(train_dir)
    # Always the number of images actually copied, never the raw request.
    main_config.DATASET_SIZE = image_count
    main_config.EPOCHS = epochs
    main_config.BATCH_SIZE = batch_size
    main_config.LEARNING_RATE = learning_rate
    main_config.DEVICE_IDS = [0]
    main_config.EVAL = False
    main_config.SMALL_IMAGE = True

    progress(0.5, desc="🚀 Starte Training...")

    progress(0.55, desc="🏗️ Erstelle Modell...")
    network = create_model()

    progress(0.6, desc="📊 Lade Dataset...")
    dataset = FormalDatasetWindowedLinePair(
        main_config.DATASET_SIZE,
        main_config.DATASET_PATH,
        main_config.PICK,
        not main_config.SMALL_IMAGE,
        direction=main_config.DIRECTION,
    )

    progress(0.65, desc="🎯 Initialisiere Training...")
    model = Model(
        dataset,
        None,
        xtransform=xtransform,
        ytransform=ytransform,
        amp=False,
        batch_size=main_config.BATCH_SIZE,
        eval=False,
        shuffle=True,
    )

    progress(0.7, desc="🔥 Training läuft...")

    training_completed = False
    actual_epochs_completed = 0
    training_error = None

    def training_epoch_end_callback():
        """Report progress after each epoch and record the epoch counter."""
        nonlocal actual_epochs_completed, training_completed

        current_epoch = getattr(model, "ep", actual_epochs_completed)
        actual_epochs_completed = current_epoch
        # Clamp so the bar never leaves the 0.70-0.95 training band.
        fraction = min(max(current_epoch / epochs, 0.0), 1.0)
        progress(0.7 + fraction * 0.25, desc=f"🔥 Epoch {current_epoch}/{epochs}")

        if current_epoch >= epochs:
            training_completed = True

    try:
        model.fit(
            network,
            criterion,
            torch.optim.Adam(network.parameters(), lr=main_config.LEARNING_RATE),
            epochs,
            max_epochs=float("inf"),
            pretrained_path=main_config.PRETRAINED_PATH,
            keep=True,
            backprop_freq=main_config.BATCH_STEP,
            device_ids=main_config.DEVICE_IDS,
            eval_metrics=eval_metrics,
            keep_epoch=main_config.KEEP_EPOCH,
            keep_optimizer=main_config.KEEP_OPTIMIZER,
            config=None,
            upload=False,
            flush_cache_after_step=main_config.FLUSH_CACHE_AFTER_STEP,
            training_epoch_end=training_epoch_end_callback,
        )
        training_completed = True
    except Exception as e:
        # Keep going so checkpoints written before the failure are collected.
        training_error = f"{e}\n\n{traceback.format_exc()}"

    progress(0.95, desc="💾 Speichere Modell...")

    model_dir = _MODEL_OUTPUT_DIR
    model_dir.mkdir(parents=True, exist_ok=True)
    runs_dir = netlistify_dir / "runs" / "FormalDatasetWindowedLinePair"
    best_model_path = _collect_checkpoints(runs_dir, model_dir)

    if training_error:
        progress(1.0, desc="❌ Training mit Fehler beendet")
    else:
        progress(1.0, desc="✅ Training abgeschlossen!")

    return _build_status_report(
        training_error=training_error,
        training_completed=training_completed,
        gpu_name=gpu_name,
        gpu_memory=gpu_memory,
        epochs=epochs,
        completed_epochs=actual_epochs_completed,
        batch_size=batch_size,
        learning_rate=learning_rate,
        image_count=image_count,
        model_dir=model_dir,
        best_model_path=best_model_path,
        runs_dir=runs_dir,
    )


@spaces.GPU(duration=3600)
def train_netlistify(
    dataset_repo_id: str,
    epochs: int = 10,
    batch_size: int = 64,
    learning_rate: float = 1e-4,
    dataset_size: int = -1,
    progress=gr.Progress()
):
    """Train the Netlistify DETR model for connection detection on ZeroGPU.

    Steps: download the dataset from the Hugging Face Hub, unpack and copy it
    into the training layout, install missing Python packages, clone and
    patch Netlistify, run the training and collect the checkpoints.

    Args:
        dataset_repo_id: Hugging Face dataset repository id.
        epochs: Number of training epochs (at least 1).
        batch_size: Batch size (at least 1).
        learning_rate: Learning rate for the Adam optimiser.
        dataset_size: Number of images to use; ``-1`` (or any value <= 0)
            uses all images.
        progress: Gradio progress tracker.

    Returns:
        A status or error text for the UI. Errors are reported in the
        returned text; the function does not raise.
    """
    try:
        # Gradio number fields may deliver floats (10.0); slicing and the
        # training code need real integers.
        try:
            epochs = int(epochs)
            batch_size = int(batch_size)
            dataset_size = int(dataset_size)
            learning_rate = float(learning_rate)
        except (TypeError, ValueError):
            return (
                "❌ Fehler: Epochs, Batch Size und Dataset-Größe müssen ganze "
                "Zahlen sein, Learning Rate eine Zahl."
            )
        if epochs < 1 or batch_size < 1:
            return "❌ Fehler: Epochs und Batch Size müssen mindestens 1 sein."

        dataset_repo_id = (dataset_repo_id or "").strip()
        if not dataset_repo_id:
            return "❌ Fehler: Bitte eine Dataset Repository-ID angeben."

        if not torch.cuda.is_available():
            return "❌ Fehler: Keine GPU verfügbar. Prüfe ZeroGPU-Konfiguration."

        gpu_name = torch.cuda.get_device_name(0)
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
        progress(0.05, desc=f"✅ GPU erkannt: {gpu_name} ({gpu_memory:.1f} GB)")

        dataset_path, download_error = _download_dataset(dataset_repo_id, progress)
        if download_error:
            return download_error

        progress(0.2, desc="📦 Bereite Dataset vor...")
        dataset_base = Path(dataset_path)
        _extract_dataset_archives(dataset_base, progress)

        images_dir, labels_dir, pkl_dir = _locate_dataset_dirs(dataset_base)
        if images_dir is None:
            return _unknown_structure_message(dataset_base)

        progress(0.25, desc="📋 Kopiere Dataset-Dateien...")

        # Sorted so a truncated subset is reproducible across runs.
        img_files = sorted(images_dir.glob("*.jpg"))
        if dataset_size > 0:
            img_files = img_files[:dataset_size]
        if not img_files:
            return f"❌ Fehler: Keine .jpg-Bilder in {images_dir} gefunden."

        # Start from an empty training directory so files from an earlier run
        # (possibly another dataset) cannot leak into this one.
        train_dir = _TRAIN_DIR
        shutil.rmtree(train_dir, ignore_errors=True)
        train_dir.mkdir(parents=True, exist_ok=True)
        _copy_training_files(img_files, labels_dir, pkl_dir, train_dir, progress)

        progress(0.4, desc=f"✅ Dataset vorbereitet: {len(img_files)} Bilder")
        progress(0.45, desc="🔧 Lade Netlistify-Module...")

        try:
            return _train_and_report(
                train_dir,
                len(img_files),
                epochs,
                batch_size,
                learning_rate,
                gpu_name,
                gpu_memory,
                progress,
            )
        except Exception as e:
            return f"❌ Fehler beim Training: {e}\n\n{traceback.format_exc()}"

    except Exception as e:
        return f"❌ Fehler: {e}\n\n{traceback.format_exc()}"


def check_gpu_status():
    """Return a one-line, user-facing description of the current GPU state.

    Returns:
        A string starting with ``✅`` (GPU available, with name and memory in
        GB), ``❌`` (no GPU available) or ``⚠️`` (the CUDA query itself failed).
    """
    try:
        if not torch.cuda.is_available():
            return "❌ Keine GPU verfügbar. Prüfe ZeroGPU-Konfiguration."

        gpu_name = torch.cuda.get_device_name(0)
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
        return f"✅ GPU verfügbar: {gpu_name} ({gpu_memory:.1f} GB)"
    except Exception:
        # Querying CUDA can fail while no GPU is attached; report that
        # instead of breaking the UI.
        return "⚠️ GPU-Status kann nicht geprüft werden (normal wenn keine GPU aktiv)"


# Gradio user interface: GPU status row, training parameters, start button
# and status output.
with gr.Blocks(title="Netlistify Training mit ZeroGPU") as app:
    gr.Markdown("""
    # 🔥 Netlistify Training mit ZeroGPU

    Trainiert Netlistify DETR-Modell für Verbindungserkennung auf Hugging Face Spaces mit ZeroGPU.

    **Voraussetzungen:**
    - Dataset auf Hugging Face hochgeladen (als Dataset Repository)
    - ZeroGPU Hardware aktiviert
    - Repository-ID des Datasets
    """)

    # GPU status with manual refresh.
    with gr.Row():
        gpu_status = gr.Textbox(
            label="GPU-Status",
            value=check_gpu_status(),
            interactive=False
        )
        refresh_btn = gr.Button("🔄 Status aktualisieren")
        refresh_btn.click(fn=check_gpu_status, outputs=gpu_status)

    with gr.Row():
        # Left column: training parameters.
        with gr.Column():
            dataset_repo = gr.Textbox(
                label="Dataset Repository-ID",
                placeholder="username/netlistify-dataset",
                value="hanky2397/schematic_images",
                info="Hugging Face Dataset Repository (z.B. hanky2397/schematic_images)"
            )

            with gr.Row():
                # precision=0 makes Gradio hand an int to the callback.
                epochs = gr.Number(
                    label="Epochs",
                    value=10,
                    minimum=1,
                    maximum=1000,
                    precision=0,
                    info="Anzahl Training-Epochs"
                )
                batch_size = gr.Number(
                    label="Batch Size",
                    value=64,
                    minimum=1,
                    maximum=256,
                    precision=0,
                    info="Batch-Größe"
                )

            with gr.Row():
                learning_rate = gr.Number(
                    label="Learning Rate",
                    value=1e-4,
                    minimum=1e-6,
                    maximum=1e-1,
                    info="Learning Rate"
                )
                dataset_size = gr.Number(
                    label="Dataset-Größe",
                    value=-1,
                    minimum=-1,
                    maximum=100000,
                    precision=0,
                    info="-1 = alle Bilder, sonst Anzahl"
                )

        # Right column: start button and status output.
        with gr.Column():
            train_btn = gr.Button(
                "🚀 Training starten",
                variant="primary",
                size="lg"
            )
            output = gr.Textbox(
                label="Training-Status",
                lines=15,
                max_lines=30
            )

    train_btn.click(
        fn=train_netlistify,
        inputs=[dataset_repo, epochs, batch_size, learning_rate, dataset_size],
        outputs=output
    )

    gr.Markdown("""
    ## 📝 Hinweise

    - **ZeroGPU**: GPU wird automatisch zugewiesen wenn Training startet
    - **Dauer**: Standard-Limit ist 60 Sekunden, wurde auf 1 Stunde (3600 Sekunden) erhöht
    - **Checkpoints**: Modelle werden automatisch gespeichert
    - **Dataset**: Muss vorher auf Hugging Face hochgeladen werden

    ## 🔗 Links

    - [Netlistify GitHub](https://github.com/NYCU-AI-EDA/Netlistify)
    - [Dataset auf Hugging Face](https://huggingface.co/datasets/hanky2397/schematic_images)
    """)


if __name__ == "__main__":
    app.launch()