Spaces:
Running
Running
| """ | |
| import_projects.py — First-boot script. | |
| Reads projects/*.json and imports them into the fresh Label Studio instance. | |
| Runs once; guarded by /data/ls/.initialized flag in start.sh. | |
| """ | |
| import json, os, re, time, requests | |
# Base URL of the local Label Studio server that this script talks to.
LS = "http://localhost:8080"

# Directory containing manifest.json and the per-project JSON export files.
PROJ_DIR = "/app/projects"

# Admin login e-mail; overridable through the LABEL_STUDIO_USERNAME env var.
ADMIN_EMAIL = os.getenv("LABEL_STUDIO_USERNAME", "admin@saga.is")

# NOTE(review): hard-coded credential, kept deliberately in sync with the
# companion script named below; change both together.
ADMIN_PASSWORD = "12345"  # must match ensure_admin.py hardcoded value
def _login_once(session):
    """Run a single web-form login on *session*.

    On success the session's default headers are updated with the
    post-login CSRF token and a Referer so later API calls are accepted.

    Raises:
        RuntimeError: if the login page is not served with HTTP 200, or the
            login POST does not end on a non-login URL with HTTP 200.
        requests.RequestException: on transport-level failures.
    """
    login_url = f"{LS}/user/login/"

    page = session.get(login_url, timeout=10)
    if page.status_code != 200:
        raise RuntimeError(f"login page {page.status_code}")

    # Start from the CSRF cookie, but prefer the hidden form field when the
    # page contains one (either attribute order is accepted).
    csrf = session.cookies.get("csrftoken", "")
    field = (
        re.search(r'name="csrfmiddlewaretoken"[^>]*value="([^"]+)"', page.text)
        or re.search(r'value="([^"]+)"[^>]*name="csrfmiddlewaretoken"', page.text)
    )
    if field:
        csrf = field.group(1)

    resp = session.post(
        login_url,
        data={
            "email": ADMIN_EMAIL,
            "password": ADMIN_PASSWORD,
            "csrfmiddlewaretoken": csrf,
        },
        headers={"Referer": login_url},
        timeout=10,
        allow_redirects=True,
    )

    # Successful login redirects to / (200 after redirect); still being on
    # the login URL means the credentials or token were rejected.
    if resp.status_code != 200 or "/user/login/" in resp.url:
        raise RuntimeError(f"login returned {resp.status_code}, url={resp.url}")

    # Refresh CSRF cookie after login
    session.headers.update({
        "X-CSRFToken": session.cookies.get("csrftoken", csrf),
        "Referer": LS,
    })


def get_session(attempts=20, delay=5):
    """Login via web form and return an authenticated requests.Session.

    The login is retried because the server may still be starting when this
    script runs.

    Args:
        attempts: maximum number of login attempts (default 20).
        delay: seconds to wait between failed attempts (default 5).

    Returns:
        A requests.Session carrying the login cookies plus default
        X-CSRFToken and Referer headers.

    Raises:
        RuntimeError: if every attempt fails; the last underlying error is
            attached as the exception's cause.
    """
    last_error = None
    for attempt in range(1, attempts + 1):
        session = requests.Session()
        try:
            _login_once(session)
        except (requests.RequestException, RuntimeError) as exc:
            # Release the failed session's connections before retrying.
            session.close()
            last_error = exc
            print(f" attempt {attempt}/{attempts}: {exc}")
            # No point in waiting once the final attempt has failed.
            if attempt < attempts:
                time.sleep(delay)
        else:
            return session
    raise RuntimeError(
        f"Could not login to Label Studio after {attempts} attempts"
    ) from last_error
def main():
    """Create one Label Studio project per manifest entry and import its tasks.

    Reads ``manifest.json`` from PROJ_DIR, processes the entries in ascending
    ``id`` order, creates a project from each entry's JSON file (``title``
    and ``label_config``) and uploads its ``tasks`` in batches of 100.

    A project whose creation request is rejected is reported and skipped.
    A task batch whose import request is rejected is reported and the
    remaining batches are still attempted.

    Returns:
        The number of projects that were successfully created.
    """
    print("Logging in to Label Studio...")
    session = get_session()
    print(" Logged in.")

    # Explicit encoding so the manifest is read the same way as project files.
    with open(f"{PROJ_DIR}/manifest.json", encoding="utf-8") as f:
        manifest = json.load(f)

    batch_size = 100
    created = 0
    for entry in sorted(manifest, key=lambda x: x["id"]):
        fpath = f"{PROJ_DIR}/{entry['file']}"
        with open(fpath, encoding="utf-8") as f:
            data = json.load(f)
        title = data["title"]
        label_config = data["label_config"]
        tasks = data["tasks"]

        # Create project
        r = session.post(
            f"{LS}/api/projects/",
            json={"title": title, "label_config": label_config},
            timeout=15,
        )
        if r.status_code not in (200, 201):
            print(f" SKIP [{entry['id']:02d}] {title[:40]} — {r.status_code}: {r.text[:80]}")
            continue
        pid = r.json()["id"]
        created += 1

        # Import tasks in batches of 100
        imported = 0
        for start in range(0, len(tasks), batch_size):
            batch = tasks[start:start + batch_size]
            ri = session.post(
                f"{LS}/api/projects/{pid}/import",
                json=batch,
                timeout=30,
            )
            if ri.status_code in (200, 201):
                # Fall back to the batch length if the response has no count.
                imported += ri.json().get("task_count", len(batch))
            else:
                # Surface the failure instead of silently losing the batch.
                last = start + len(batch) - 1
                print(f" FAIL [{pid:02d}] tasks {start}-{last} — "
                      f"{ri.status_code}: {ri.text[:80]}")
        print(f" [{pid:02d}] {title[:55]:<55} {imported:>4} tasks imported")
    return created
if __name__ == "__main__":
    # Script entry point: run the import and exit with an error message when
    # not a single project could be created.
    created = main()
    if created == 0:
        raise SystemExit("No projects were created — check LS logs for auth errors")
    print(f"All done: {created} projects imported.")