#!/usr/bin/env python
# Interactively play with a fine-tuned Vertex AI model, giving it back the
# accumulated prompt as necessary so it's not stateless.

import os
import re
import requests
import time
import random
import sys
import argparse

sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)))

import vertexai
from vertexai.tuning import sft
from vertexai.generative_models import GenerativeModel
from google.cloud import aiplatform_v1
import google.auth
from google.auth.transport.requests import AuthorizedSession

from auth_setup import PROJECT_ID, REGION, ZONE, ensure_gcloud

ensure_gcloud()

BOT_NAME = "hotbot"  # Frobot/Hotbot/Coolbot/etc (not referenced in this script)

# Directory containing this script. Prompt paths are resolved against it so the
# script behaves the same regardless of the current working directory.
_SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))

# prompt_flag value -> prompt file (relative to _SCRIPT_DIR); "N" means no prompt.
PROMPT_FILES = {
    "f": "../prompts/experiment/frobot_prompt.txt",
    "c": "../prompts/experiment/coolbot_prompt.txt",
    "h": "../prompts/experiment/hotbot_prompt.txt",
    "N": None,
}

# Regex whose matches in the prompt file are rewritten to "B", the label used
# for model turns in the transcript built below.
# NOTE(review): the original code used an empty pattern (r""), which matches
# between every character and turns the prompt into "BxByBz...". The intended
# pattern appears to have been lost. TODO: set it to the bot's speaker tag as it
# appears in the prompt files. While it is empty, the prompt is left unchanged.
BOT_LABEL_PATTERN = r""

# Typing this at the input prompt ends the session.
DONE_COMMAND = "done"


def _parse_args():
    """Parse the command line; argparse rejects unknown prompt flags."""
    parser = argparse.ArgumentParser(description="Play with fine-tuned model.")
    parser.add_argument(
        "tuning_job_id",
        type=int,
        help="The tuning job ID, which can be obtained from running 'showtuningjobs succeeded' and reading carefully. Pirates can always be reached at 117775339060461568.",
    )
    parser.add_argument(
        "prompt_flag",
        type=str,
        choices=sorted(PROMPT_FILES),
        help="The prompt file to be used as instructions to the model c for coolbot, f for frobot, h for hotbot, N for none.",
    )
    return parser.parse_args()


def _load_prompt(flag):
    """Return the initial prompt text for *flag* ("" when flag is "N")."""
    relative_path = PROMPT_FILES[flag]
    if relative_path is None:
        return ""
    path = os.path.join(_SCRIPT_DIR, relative_path)
    with open(path, "r", encoding="utf-8") as f:
        prompt = f.read()
    # Skip the substitution while the pattern is empty: an empty regex would
    # insert "B" between every character.
    if BOT_LABEL_PATTERN:
        prompt = re.sub(BOT_LABEL_PATTERN, "B", prompt)
    return prompt


def _fetch_display_name(session, tuning_job_name):
    """Return the tuning job's ``tunedModelDisplayName`` (None if absent).

    Raises for a non-2xx HTTP status via ``raise_for_status``.
    """
    uri = f"https://{REGION}-aiplatform.googleapis.com/v1/{tuning_job_name}"
    resp = session.get(uri)
    resp.raise_for_status()
    return resp.json().get("tunedModelDisplayName")


def _response_text(response):
    """Return the text of the first part of the first candidate, or None.

    A response may carry no candidates or no text part (for example when it is
    blocked). This reports None instead of raising, so the session survives.
    """
    try:
        return response.candidates[0].content.parts[0].text
    except (IndexError, AttributeError):
        return None


def _read_input(display_name):
    """Prompt the user for the next message; end-of-input counts as 'done'."""
    try:
        return input(f"Type something to {display_name}> ")
    except EOFError:
        print("")
        return DONE_COMMAND


def main():
    """Run an interactive A/B chat loop against a tuned model's endpoint."""
    args = _parse_args()
    # Load the prompt before any network calls so a bad path fails fast.
    accumulated_content = _load_prompt(args.prompt_flag)

    # Point the Vertex AI SDK at the same project/region used to build the
    # resource names below.
    vertexai.init(project=PROJECT_ID, location=REGION)

    credentials, _ = google.auth.default(
        scopes=["https://www.googleapis.com/auth/cloud-platform"]
    )
    session = AuthorizedSession(credentials)
    tuning_job_name = (
        f"projects/{PROJECT_ID}/locations/{REGION}/tuningJobs/{args.tuning_job_id}"
    )
    display_name = (
        _fetch_display_name(session, tuning_job_name)
        or f"tuning job {args.tuning_job_id}"
    )

    tj = sft.SupervisedTuningJob(tuning_job_name)
    tm = GenerativeModel(tj.tuned_model_endpoint_name)

    new_input = _read_input(display_name)
    while new_input != DONE_COMMAND:
        # The whole transcript is resent each turn so the model sees the history.
        candidate_content = accumulated_content + "\nA: " + new_input
        response = tm.generate_content(candidate_content)
        response_txt = _response_text(response)
        if response_txt is None:
            # Drop the unanswered user turn rather than leaving it dangling.
            print("(The model returned no text; that turn was not kept.)")
        else:
            accumulated_content = candidate_content + "\nB: " + response_txt
            print(accumulated_content)
            print("")
        new_input = _read_input(display_name)


if __name__ == "__main__":
    main()