Spaces:
Runtime error
Runtime error
| import pandas as pd | |
| from tqdm import tqdm | |
| import config | |
| import dataset_statistics | |
| from api_wrappers import grazie_wrapper | |
| from generation_steps import examples | |
# Number of improved messages synthesized for every input row (loop count in transform()).
GENERATION_MULTIPLIER = 3
# A candidate is accepted immediately when its "deletions" statistic is strictly
# below this value (see generate_end_msg()).
# NOTE(review): presumably a relative share in [0, 1] -- confirm against dataset_statistics.
REL_DELETIONS_THRESHOLD = 0.75
# Maximum number of LLM calls made per message in generate_end_msg() before it
# falls back to the candidate with the lowest "deletions" value.
GENERATION_ATTEMPTS = 3
def build_prompt(prediction, diff) -> str:
    """Build the LLM prompt that asks for a rewritten commit message.

    The prompt embeds the source code changes, the previously generated
    message, and the few-shot examples from ``examples.EXAMPLES_START_TO_END``,
    and ends with the ``OUTPUT`` token after which the model is told to print
    its answer.

    Args:
        prediction: The LLM-generated commit message to be improved;
            interpolated into the prompt as-is.
        diff: The source code changes the message describes; interpolated
            into the prompt as-is.

    Returns:
        The complete prompt text.
    """
    return f"""A LLM generated a commit message for the following source code changes:
START OF THE SOURCE CODE CHANGES
{diff}
END OF THE SOURCE CODE CHANGES
Here is the message the LLM generated:
START OF THE COMMIT MESSAGE
{prediction}
END OF THE COMMIT MESSAGE
This generated message is not perfect. Your task is to rewrite and improve it.
You have to simulate a human software developer who manually rewrites the LLM-generated commit message,
so the message you print must share some fragments with the generated message.
Your message should be concise.
Follow the Conventional Commits guidelines.
Here are some examples of what you should output:
START OF THE EXAMPLES LIST
{examples.EXAMPLES_START_TO_END}
END OF THE EXAMPLES LIST
Print only the improved commit message's text after the
token "OUTPUT".
OUTPUT"""
def generate_end_msg(start_msg, diff):
    """Ask the LLM for an improved version of ``start_msg``, retrying if needed.

    Up to ``GENERATION_ATTEMPTS`` candidates are requested for the same prompt.
    The first candidate whose ``"deletions"`` statistic, as reported by
    ``dataset_statistics.get_statistics_for_sample``, is strictly below
    ``REL_DELETIONS_THRESHOLD`` is returned right away. If no candidate
    qualifies, the one with the smallest ``"deletions"`` value is returned;
    ties are broken by ordinary tuple ordering, i.e. by comparing the messages.

    Args:
        start_msg: The commit message to be rewritten.
        diff: The source code changes the message describes.

    Returns:
        The selected candidate, as produced by
        ``grazie_wrapper.generate_for_prompt``.
    """
    request = build_prompt(prediction=start_msg, diff=diff)
    rejected = []

    for _ in range(GENERATION_ATTEMPTS):
        candidate = grazie_wrapper.generate_for_prompt(request)
        sample_stats = dataset_statistics.get_statistics_for_sample(
            start_msg=start_msg,
            end_msg=candidate,
        )

        if sample_stats["deletions"] < REL_DELETIONS_THRESHOLD:
            return candidate

        # Keep the score next to the message so the fallback can rank them.
        rejected.append((sample_stats["deletions"], candidate))

    # No attempt got under the threshold: fall back to the lowest-ranked tuple.
    return sorted(rejected)[0][1]
# Columns copied unchanged from each source row into every row generated by transform().
COLS_TO_KEEP = ["hash", "repo", "commit_msg_start", "mods", "session", "end_to_start"]
def print_config():
    """Print the generation settings of this step to stdout, one per line."""
    report_lines = (
        f"NUMBER OF EXAMPLES PER PROMPT = {examples.N_EXAMPLES}",
        f"GENERATION_MULTIPLIER = {GENERATION_MULTIPLIER}",
        f"REL_DELETIONS_THRESHOLD = {REL_DELETIONS_THRESHOLD}",
        f"GENERATION_ATTEMPTS = {GENERATION_ATTEMPTS}",
    )
    # A single call with a newline separator emits the same four lines.
    print(*report_lines, sep="\n")
def transform(df):
    """Append synthesized start-to-end rows to ``df`` and save the result.

    For every row of ``df``, ``GENERATION_MULTIPLIER`` improved commit messages
    are produced with ``generate_end_msg`` from the row's ``commit_msg_start``
    and ``mods`` values. Each generated row carries the new message in
    ``commit_msg_end`` plus the ``COLS_TO_KEEP`` values of its source row and
    is flagged ``start_to_end = True``; the incoming rows are flagged ``False``.

    Side effects:
        Adds the ``start_to_end`` column to the ``df`` object passed in, prints
        progress information, and writes the combined frame to
        ``config.START_TO_END_ARTIFACT`` as CSV.

    Args:
        df: Frame containing at least the ``COLS_TO_KEEP`` columns.

    Returns:
        The concatenation of ``df`` and the generated rows with a fresh index.
    """
    print("Start -> send synthesis:")
    print_config()

    # Flag the incoming rows; this writes into the caller's frame.
    df["start_to_end"] = False

    # Column-wise accumulator; key order fixes the generated frame's column order.
    synthesized = {"commit_msg_end": [], **{name: [] for name in COLS_TO_KEEP}}

    for _, source_row in tqdm(df.iterrows(), total=len(df)):
        for _ in range(GENERATION_MULTIPLIER):
            improved_msg = generate_end_msg(
                start_msg=source_row["commit_msg_start"],
                diff=source_row["mods"],
            )
            synthesized["commit_msg_end"].append(improved_msg)
            for name in COLS_TO_KEEP:
                synthesized[name].append(source_row[name])

    generated_df = pd.DataFrame.from_dict(synthesized)
    generated_df["start_to_end"] = True

    combined = pd.concat([df, generated_df], ignore_index=True)
    combined.to_csv(config.START_TO_END_ARTIFACT)
    print("Done")
    return combined
def main():
    """Load the end-to-start artifact and run the start-to-end synthesis on it."""
    # The first CSV column is used as the frame's index.
    transform(pd.read_csv(config.END_TO_START_ARTIFACT, index_col=[0]))
# Run the synthesis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()