"""
scripts/generate_test_data.py

Generates realistic test data for Sheami using your modules.db.SheamiDB API.

Behavior:
- Creates N users (default 100)
- Each user: 3-5 patients (enforced)
- Each patient: 2-6 reports
- Each report: 3-6 tests drawn from TEST_POOL
- For each patient we write trends (per test) using add_or_update_trend
- For each patient we write a final report using add_final_report

Usage:
    pip install faker pymongo python-dotenv
    MONGODB_URI="mongodb+srv://<user>:<pass>@cluster0.xxxxx.mongodb.net" \
    MONGODB_DB="sheami" \
    python scripts/generate_test_data.py --num-users 100

The script calls these exact methods on your SheamiDB:
- add_user(email, name)
- add_patient(user_id, name, dob, gender)
- add_report(patient_id, file_name, parsed_data)
- add_or_update_trend(patient_id, test_name, trend_data)
- add_final_report(patient_id, summary, recommendations, trend_snapshots)
"""
| | import argparse |
| | import random |
| | from collections import defaultdict |
| | from datetime import datetime, timedelta |
| | import os |
| |
|
| | from faker import Faker |
| | from dotenv import load_dotenv |
| |
|
| | |
| | load_dotenv() |
| |
|
| | |
| | from modules.db import SheamiDB |
| |
|
| | |
| | faker = Faker() |
| | TEST_POOL = { |
| | "Hemoglobin": (11.0, 17.5, "g/dL", "11.0-17.5"), |
| | "Glucose (Fasting)": (60, 130, "mg/dL", "70-99 fasting"), |
| | "Total Cholesterol": (120, 300, "mg/dL", "<200 desirable"), |
| | "Triglycerides": (40, 300, "mg/dL", "<150 normal"), |
| | "HDL": (30, 90, "mg/dL", ">40 desirable"), |
| | "LDL": (50, 200, "mg/dL", "<100 ideal"), |
| | "Creatinine": (0.5, 1.8, "mg/dL", "0.5-1.2"), |
| | "Urea (BUN)": (7, 30, "mg/dL", "7-20"), |
| | "Sodium": (130, 150, "mmol/L", "135-145"), |
| | "Potassium": (3.2, 5.2, "mmol/L", "3.5-5.0"), |
| | "ALT": (7, 55, "U/L", "<45"), |
| | "AST": (8, 48, "U/L", "<40"), |
| | } |
| |
|
| | def random_date_between(start_year=2019): |
| | start = datetime(start_year, 1, 1) |
| | end = datetime.now() |
| | days = (end - start).days |
| | return start + timedelta(days=random.randint(0, days)) |
| |
|
| | def make_test_values(k): |
| | """Return list of test dicts matching parsed_data.tests schema.""" |
| | chosen = random.sample(list(TEST_POOL.items()), k=k) |
| | tests = [] |
| | for name, (low, high, unit, ref) in chosen: |
| | |
| | if isinstance(low, float) or isinstance(high, float): |
| | value = round(random.uniform(low, high), 2) |
| | else: |
| | value = int(round(random.uniform(low, high))) |
| | tests.append({ |
| | "name": name, |
| | "value": value, |
| | "unit": unit, |
| | "reference_range": ref |
| | }) |
| | return tests |
| |
|
| | def compute_direction(points): |
| | if len(points) < 2: |
| | return "stable" |
| | if points[-1]["value"] > points[-2]["value"]: |
| | return "increasing" |
| | if points[-1]["value"] < points[-2]["value"]: |
| | return "decreasing" |
| | return "stable" |
| |
|
| | |
| | def generate_test_data(db_uri: str, db_name: str, num_users: int = 100, |
| | min_patients=3, max_patients=5, |
| | min_reports=2, max_reports=6, |
| | min_tests=3, max_tests=6, |
| | seed: int = None): |
| | if seed is not None: |
| | random.seed(seed) |
| | Faker.seed(seed) |
| |
|
| | db = SheamiDB(db_uri, db_name=db_name) |
| |
|
| | counters = {"users": 0, "patients": 0, "reports": 0, "trends": 0, "final_reports": 0} |
| |
|
| | for u_idx in range(num_users): |
| | |
| | user_name = faker.name() |
| | user_email = faker.unique.safe_email() |
| | user_id = db.add_user(email=user_email, name=user_name) |
| | counters["users"] += 1 |
| |
|
| | |
| | num_patients = random.randint(min_patients, max_patients) |
| | for _p in range(num_patients): |
| | patient_name = faker.name() |
| | |
| | age = random.randint(18, 85) |
| | dob_dt = datetime.now() - timedelta(days=365 * age + random.randint(0, 365)) |
| | dob_str = dob_dt.strftime("%Y-%m-%d") |
| | gender = random.choice(["male", "female", "other"]) |
| |
|
| | patient_id = db.add_patient(user_id=user_id, name=patient_name, dob=dob_str, gender=gender) |
| | counters["patients"] += 1 |
| |
|
| | |
| | trends_map = defaultdict(list) |
| |
|
| | |
| | num_reports = random.randint(min_reports, max_reports) |
| | for r_i in range(num_reports): |
| | report_date_dt = random_date_between() |
| | report_date = report_date_dt.strftime("%Y-%m-%d") |
| | num_tests = random.randint(min_tests, max_tests) |
| | tests = make_test_values(num_tests) |
| |
|
| | parsed_data = { |
| | "tests": tests, |
| | "report_date": report_date |
| | } |
| | file_name = f"report_{report_date.replace('-', '')}_{random.randint(1000,9999)}.pdf" |
| | report_id = db.add_report(patient_id=patient_id, file_name=file_name, parsed_data=parsed_data) |
| | counters["reports"] += 1 |
| |
|
| | |
| | for t in tests: |
| | trends_map[t["name"]].append({"date": report_date, "value": t["value"]}) |
| |
|
| | |
| | for test_name, points in trends_map.items(): |
| | |
| | pts_sorted = sorted(points, key=lambda x: x["date"]) |
| | db.add_or_update_trend(patient_id=patient_id, test_name=test_name, trend_data=pts_sorted) |
| | counters["trends"] += 1 |
| |
|
| | |
| | trend_snapshots = [] |
| | for test_name, points in trends_map.items(): |
| | pts_sorted = sorted(points, key=lambda x: x["date"]) |
| | latest_value = pts_sorted[-1]["value"] |
| | direction = compute_direction(pts_sorted) |
| | trend_snapshots.append({ |
| | "test_name": test_name, |
| | "latest_value": latest_value, |
| | "direction": direction |
| | }) |
| |
|
| | summary = f"Auto-generated summary for {patient_name} ({len(trend_snapshots)} tests)" |
| | recommendations = [] |
| | |
| | if any(ts["direction"] == "increasing" for ts in trend_snapshots): |
| | recommendations.append("Follow up for rising values") |
| | else: |
| | recommendations.append("Continue routine monitoring") |
| | db.add_final_report(patient_id=patient_id, |
| | summary=summary, |
| | recommendations=recommendations, |
| | trend_snapshots=trend_snapshots) |
| | counters["final_reports"] += 1 |
| |
|
| | |
| | if (u_idx + 1) % 10 == 0 or (u_idx + 1) == num_users: |
| | print(f"Created {u_idx+1}/{num_users} users so far...") |
| |
|
| | |
| | print("Generation complete. Summary:") |
| | for k, v in counters.items(): |
| | print(f" {k}: {v}") |
| |
|
| | |
| | if __name__ == "__main__": |
| | parser = argparse.ArgumentParser(description="Generate test data for Sheami (matches your db.py).") |
| | parser.add_argument("--num-users", type=int, default=100, help="Number of users to create") |
| | parser.add_argument("--db-uri", type=str, default=os.getenv("MONGODB_URI", "mongodb://localhost:27017"), |
| | help="MongoDB connection URI") |
| | parser.add_argument("--db-name", type=str, default=os.getenv("MONGODB_DB", "sheami"), |
| | help="Database name") |
| | parser.add_argument("--seed", type=int, default=None, help="Random seed (optional)") |
| | args = parser.parse_args() |
| |
|
| | generate_test_data(db_uri=args.db_uri, db_name=args.db_name, |
| | num_users=args.num_users, seed=args.seed) |
| |
|