Spaces:
Runtime error
Runtime error
| import json | |
| import pathlib | |
| from copy import deepcopy | |
| from typing import Callable | |
| from functools import partial | |
| import click | |
| import pandas as pd | |
| import pandera.pandas as pa | |
| from tqdm.auto import tqdm | |
| from langchain_core.runnables import Runnable | |
| from src.common.data import load_dataset | |
| from src.common.schema import DatasetSchema | |
| from src.generate.config import GenerationConfig | |
| from src.generate.schema import GeneratedDatasetSchema | |
| from src.generate.answer import make_root_model, matches_type, string_to_type | |
| from src.generate.generators import GenerationAnswer, GENERATORS_NAME_TO_FACTORY | |
def _save_temp_file(
    row: dict,
    result: GenerationAnswer,
    temp_path: pathlib.Path,
) -> None:
    """Cache one generated answer as ``<temp_path>/<id>.json``.

    The file stores the row id and ``result.model_dump()`` and is read back by
    ``_generate_single_answer`` so an interrupted run can resume.

    Args:
        row: Mapping that contains at least the dataset id column.
        result: The generated answer to persist.
        temp_path: Existing cache directory.
    """
    temp_file = temp_path / f"{row[DatasetSchema.id_]}.json"
    # Write to a sibling file first and rename it over the target, so a crash
    # mid-write never leaves a truncated JSON file that would break resuming.
    staging_file = temp_file.with_name(temp_file.name + ".tmp")
    payload = {
        DatasetSchema.id_: row[DatasetSchema.id_],
        GeneratedDatasetSchema.generated_answer: result.model_dump(),
    }
    # ensure_ascii=False writes raw non-ASCII characters, so pin UTF-8 instead
    # of depending on the platform default encoding.
    with open(staging_file, "w", encoding="utf-8") as fh:
        json.dump(payload, fh, ensure_ascii=False)
    staging_file.replace(temp_file)


def _generate_single_answer(
    row: dict,
    build_chain: Callable[[type], Runnable],
    temp_path: pathlib.Path = None,
) -> GenerationAnswer:
    """Produce the model's answer for one dataset row, reusing the cache if present.

    Args:
        row: One dataset record (a mapping, or the row ``Series`` passed by
            ``DataFrame.apply(axis=1)``) holding the id and answer-type columns.
        build_chain: Factory taking the expected answer model class and
            returning a runnable whose ``invoke`` yields a ``GenerationAnswer``.
        temp_path: Optional cache directory. When given, a previously saved
            answer for this row id is returned without calling the chain, and
            newly generated answers are saved there.

    Returns:
        The generated (or cached) ``GenerationAnswer``.
    """
    cache_file = temp_path / f"{row[DatasetSchema.id_]}.json" if temp_path else None
    if cache_file is not None and cache_file.exists():
        with open(cache_file, encoding="utf-8") as fh:
            cached = json.load(fh)
        return GenerationAnswer.model_validate(
            cached[GeneratedDatasetSchema.generated_answer]
        )

    answer_type = make_root_model(row[DatasetSchema.answer_type])
    chain = build_chain(answer_type)
    # Copy before dropping the gold answer: the caller's row stays intact and
    # the chain never sees the correct answer.
    inputs = dict(row)
    inputs.pop(DatasetSchema.correct_answer, None)
    result: GenerationAnswer = chain.invoke(inputs)
    if temp_path:
        _save_temp_file(inputs, result, temp_path)
    return result
def _generate_answers(
    df: pd.DataFrame,
    build_chain: Callable[[type], Runnable],
    use_tqdm: bool = True,
    temp_path: pathlib.Path = None,
) -> pd.DataFrame:
    """Generate an answer for every row and return the generated-dataset columns.

    Args:
        df: Dataset rows; it is not modified.
        build_chain: Factory passed through to ``_generate_single_answer``.
        use_tqdm: Show a tqdm progress bar while generating.
        temp_path: Optional per-row answer cache directory.

    Returns:
        A new frame restricted to ``GeneratedDatasetSchema``'s columns, whose
        answer column holds ``GenerationAnswer`` objects.
    """
    answer_col = GeneratedDatasetSchema.generated_answer
    if df.empty:
        # pandas' apply on an empty frame probes the callback with a placeholder
        # row and may return a DataFrame instead of a Series; skip it entirely.
        answers = pd.Series(index=df.index, dtype=object)
    else:
        generate_one = partial(
            _generate_single_answer,
            build_chain=build_chain,
            temp_path=temp_path,
        )
        if use_tqdm:
            # Registers ``progress_apply`` on pandas objects.
            tqdm.pandas()
            answers = df.progress_apply(generate_one, axis=1)
        else:
            answers = df.apply(generate_one, axis=1)
    result = df.assign(**{answer_col: answers})
    return result[list(GeneratedDatasetSchema._collect_fields().keys())]
def generate(
    config_path: pathlib.Path = pathlib.Path("configs/ollama.yaml"),
    output_path: pathlib.Path = pathlib.Path("./gemma3:4b.jsonl"),
    temp_path: pathlib.Path = pathlib.Path("./tmp_gemma3:4b/"),
    use_tqdm: bool = True,
):
    """Generate answers for the whole dataset and write them as JSON Lines.

    The chain factory named by the config's ``build_function`` is bound to the
    configured LLM class, structured-output method and extra kwargs. Per-row
    answers are cached in ``temp_path``, so rerunning resumes where it stopped.
    """
    config_path = pathlib.Path(config_path)
    output_path = pathlib.Path(output_path)
    temp_path = pathlib.Path(temp_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    temp_path.mkdir(parents=True, exist_ok=True)

    config = GenerationConfig.from_file(config_path)
    df = load_dataset()

    build_chain_function = partial(
        GENERATORS_NAME_TO_FACTORY[config.build_function],
        llm_class=config.llm_class,
        structured_output_method=config.structured_output_method,
        **config.kwargs,
    )
    df = _generate_answers(
        df, build_chain_function, use_tqdm=use_tqdm, temp_path=temp_path
    )

    answer_col = GeneratedDatasetSchema.generated_answer
    # Serialize pydantic answers to plain dicts for to_json; assign() avoids
    # writing into a column-subset copy.
    df = df.assign(**{answer_col: df[answer_col].apply(lambda answer: answer.model_dump())})
    df.to_json(
        output_path,
        lines=True,
        orient="records",
        force_ascii=False,
    )
def _type_sanitycheck(
    generated_df: pd.DataFrame,
) -> tuple[bool, str]:
    """Check that every generated answer matches its row's expected answer type.

    Args:
        generated_df: Frame with the generated-dataset id and answer columns;
            answers may be ``GenerationAnswer`` objects or their dict dumps.
            The frame is not modified.

    Returns:
        ``(ok, message)``. ``ok`` is False when no prediction joins to the
        dataset or when any answer's type does not match.
    """
    answer_col = GeneratedDatasetSchema.generated_answer

    def _to_answer(value):
        if isinstance(value, GenerationAnswer):
            return value
        # Keep missing predictions as None so the dropna after the join can
        # discard them instead of failing validation here.
        if value is None or (isinstance(value, float) and pd.isna(value)):
            return None
        return GenerationAnswer.model_validate(deepcopy(value))

    generated_df = generated_df.assign(
        **{answer_col: generated_df[answer_col].apply(_to_answer)}
    )

    dataset_df = load_dataset()
    predicted_df = dataset_df.join(
        generated_df.set_index(GeneratedDatasetSchema.id_),
        on=DatasetSchema.id_,
        rsuffix="_generated",
    ).dropna(subset=[answer_col])
    if predicted_df.empty:
        return False, "No valid predictions found."

    # astype(bool) so "~" is a logical NOT (on object dtype it would be bitwise).
    type_match = predicted_df.apply(
        lambda row: matches_type(
            row[answer_col].answer,
            string_to_type(row[DatasetSchema.answer_type]),
        ),
        axis=1,
    ).astype(bool)
    if not type_match.all():
        mismatched_ids = predicted_df.loc[~type_match, DatasetSchema.id_].tolist()
        return (
            False,
            f"Type mismatch found for {mismatched_ids}.",
        )
    return (
        True,
        f"All matched. Predicted count: {len(predicted_df)} of {len(dataset_df)}",
    )
def type_sanitycheck(
    file: pathlib.Path = pathlib.Path("./gemma3:4b.jsonl"),
):
    """Verify that the generated answers in a JSON Lines file have the expected types.

    Prints a pass/fail line and exits with status 1 on failure.
    """
    # dtype=False keeps values exactly as stored in the JSON; the default
    # coerces numeric-looking string ids to numbers, which breaks the join
    # with the dataset.
    df = pd.read_json(file, lines=True, dtype=False)
    types_correct, message = _type_sanitycheck(df)
    if not types_correct:
        click.echo(f"❌ Type sanity check failed: {message}")
        # The bare ``exit`` builtin is injected by ``site`` and is not always
        # present (e.g. ``python -S``, frozen apps); SystemExit always is.
        raise SystemExit(1)
    click.echo(f"✅ Type sanity check passed: {message}")
@click.group()
def cli():
    """Generate LLM answers for the dataset and sanity-check their types."""


def _as_click_command(func: Callable) -> click.Command:
    """Return ``func`` as a click command with one ``--option`` per parameter.

    Objects that already are click commands are returned unchanged. Otherwise,
    bool defaults become ``--flag/--no-flag`` switches, ``Path`` defaults use
    ``click.Path`` converting to ``pathlib.Path``, and parameters without a
    default become required options.
    """
    if isinstance(func, click.Command):
        return func
    import inspect

    options = []
    for name, parameter in inspect.signature(func).parameters.items():
        if parameter.kind in (parameter.VAR_POSITIONAL, parameter.VAR_KEYWORD):
            continue
        flag = name.replace("_", "-")
        default = parameter.default
        if default is inspect.Parameter.empty:
            option = click.Option([f"--{flag}"], required=True)
        elif isinstance(default, bool):
            option = click.Option(
                [f"--{flag}/--no-{flag}"], default=default, show_default=True
            )
        elif isinstance(default, pathlib.PurePath):
            option = click.Option(
                [f"--{flag}"],
                type=click.Path(path_type=pathlib.Path),
                default=default,
                show_default=True,
            )
        else:
            option = click.Option([f"--{flag}"], default=default, show_default=True)
        options.append(option)
    return click.Command(
        name=func.__name__.replace("_", "-"),
        callback=func,
        params=options,
        help=inspect.getdoc(func),
    )


# The functions above are plain callables; cli.add_command needs click
# Command objects, so wrap them (already-decorated commands pass through).
cli.add_command(_as_click_command(generate))
cli.add_command(_as_click_command(type_sanitycheck))

if __name__ == "__main__":
    cli()