File size: 3,833 Bytes
a80f6e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# type: ignore
"""Script to generate migrations for the migration script."""

import json
import os
import pkgutil

import click

from langchain_cli.namespaces.migrate.generate.generic import (
    generate_simplified_migrations,
)
from langchain_cli.namespaces.migrate.generate.grit import (
    dump_migrations_as_grit,
)
from langchain_cli.namespaces.migrate.generate.partner import (
    get_migrations_for_partner_package,
)


@click.group()
def cli():
    """Migration scripts management."""
    # Root command group: it does no work itself, the subcommands in this
    # file register on it through ``@cli.command()``.


@cli.command()
@click.option(
    "--pkg1",
    default="langchain",
)
@click.option(
    "--pkg2",
    default="langchain_community",
)
@click.option(
    "--output",
    default=None,
    help="Output file for the migration script.",
)
@click.option(
    "--filter-by-all/--no-filter-by-all",
    default=True,
    help="Toggle the filter_by_all setting passed to the migration generator.",
)
@click.option(
    "--format",
    type=click.Choice(["json", "grit"], case_sensitive=False),
    default="json",
    help="The output format for the migration script (json or grit).",
)
def generic(
    pkg1: str, pkg2: str, output: str, filter_by_all: bool, format: str
) -> None:
    """Generate a migration script."""
    # The docstring above doubles as the command's ``--help`` text, so the
    # longer description is kept in comments.
    #
    # pkg1 / pkg2:   packages handed to ``generate_simplified_migrations``;
    #                also used to build the default name ``<pkg1>_to_<pkg2>``.
    # output:        destination path, or None to derive it from the name
    #                and the chosen format.
    # filter_by_all: forwarded unchanged to the generator.
    # format:        "json" dumps the migrations with ``json.dumps``; any
    #                other accepted choice renders them as Grit.
    #
    # NOTE: this message is echoed before the migrations are generated or
    # written, so it is printed even if a later step fails.
    click.echo("Migration script generated.")
    migrations = generate_simplified_migrations(pkg1, pkg2, filter_by_all=filter_by_all)

    if output is None:
        # No explicit path: name the migration after both packages and pick
        # the file extension that matches the requested format.
        name = f"{pkg1}_to_{pkg2}"
        output = f"{name}.json" if format == "json" else f"{name}.grit"
    else:
        # Explicit path: the migration name is that path minus a trailing
        # ".json" / ".grit"; the path itself is used as given.
        name = output.removesuffix(".json").removesuffix(".grit")

    if format == "json":
        dumped = json.dumps(migrations, indent=2, sort_keys=True)
    else:
        dumped = dump_migrations_as_grit(name, migrations)

    # Explicit encoding so the written file does not depend on the locale.
    with open(output, "w", encoding="utf-8") as f:
        f.write(dumped)


def handle_partner(pkg: str, output: str = None) -> None:
    """Write the Grit migration file for one package, if it has migrations.

    Args:
        pkg: Package name passed to ``get_migrations_for_partner_package``.
            A leading ``langchain_`` is stripped to form the migration name.
        output: Path of the file to write. When ``None``, ``<name>.grit``
            (a relative path) is used.
    """
    migrations = get_migrations_for_partner_package(pkg)
    if not migrations:
        # Nothing to write, so skip rendering the Grit output altogether.
        click.secho(f"No migrations found for {pkg}", fg="yellow")
        return

    # str.removeprefix requires Python 3.9+.
    name = pkg.removeprefix("langchain_")
    # Render before opening the file so a rendering failure cannot leave a
    # truncated output file behind.
    data = dump_migrations_as_grit(name, migrations)
    output_name = f"{name}.grit" if output is None else output
    # Explicit encoding so the written file does not depend on the locale.
    with open(output_name, "w", encoding="utf-8") as f:
        f.write(data)
    click.secho(f"LangChain migration script saved to {output_name}")


@cli.command()
@click.argument("pkg")
@click.option(
    "--output",
    default=None,
    help="Output file for the migration script.",
)
def partner(pkg: str, output: str) -> None:
    """Generate migration scripts specifically for LangChain modules."""
    # Print the banner, then hand the lookup and file writing to the shared
    # helper; ``output`` may be None, in which case the helper picks a name.
    click.echo("Migration script for LangChain generated.")
    handle_partner(pkg, output=output)


@cli.command()
@click.argument("json_file")
def json_to_grit(json_file: str) -> None:
    """Generate a Grit migration from an old JSON migration file."""
    # Read as UTF-8 explicitly rather than relying on the locale's default
    # encoding, which is not UTF-8 on every platform.
    with open(json_file, encoding="utf-8") as f:
        migrations = json.load(f)

    # The migration name is the input's base name without a trailing
    # ".json" / ".grit"; the result is written to "<name>.grit" as a
    # relative path, not next to the input file.
    name = os.path.basename(json_file).removesuffix(".json").removesuffix(".grit")
    data = dump_migrations_as_grit(name, migrations)
    output_name = f"{name}.grit"
    with open(output_name, "w", encoding="utf-8") as f:
        f.write(data)
    click.secho(f"GritQL migration script saved to {output_name}")


@cli.command()
def all_installed_partner_pkgs() -> None:
    """Generate migration scripts for all LangChain modules."""
    # Partner packages are taken to be every importable top-level module
    # named "langchain_<partner_name>", minus the names listed here.
    skipped = {"langchain_core", "langchain_cli", "langchain_community"}

    # Collect the complete list of candidates first, then generate a
    # migration for each one.
    partner_pkgs = []
    for module in pkgutil.iter_modules():
        if not module.name.startswith("langchain_"):
            continue
        if module.name in skipped:
            continue
        partner_pkgs.append(module.name)

    for partner_pkg in partner_pkgs:
        handle_partner(partner_pkg)


# Invoke the click command group only when this file is run as a script.
if __name__ == "__main__":
    cli()