File size: 2,900 Bytes
40a04d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#!/usr/bin/env python
# Interactively play with a fine-tuned Vertex AI model, giving it back the
# accumulated prompt as necessary so it's not stateless.
import os
import re
import requests
import time
import random
import sys
import argparse
# Make this script's own directory importable so the sibling auth_setup
# module is found regardless of the current working directory.
sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)))

import vertexai
from vertexai.tuning import sft
from vertexai.generative_models import GenerativeModel
from google.cloud import aiplatform_v1
import google.auth
from google.auth.transport.requests import AuthorizedSession

# Project/region constants come from the local auth_setup module.
# ensure_gcloud() presumably validates gcloud credentials/config before any
# API call is made — confirm against auth_setup.
from auth_setup import PROJECT_ID, REGION, ZONE, ensure_gcloud
ensure_gcloud()

# NOTE(review): BOT_NAME is not referenced anywhere in this file — possibly
# vestigial or consumed indirectly via the prompt files; verify before removing.
BOT_NAME = "hotbot"   # Frobot/Hotbot/Coolbot/etc

if __name__ == "__main__":

    # Map each prompt flag to its instruction file. "N" (no prompt file) is
    # handled separately below; keeping the mapping in one place avoids the
    # membership check and the file-selection chain drifting apart.
    PROMPT_FILES = {
        "f": "../prompts/experiment/frobot_prompt.txt",
        "c": "../prompts/experiment/coolbot_prompt.txt",
        "h": "../prompts/experiment/hotbot_prompt.txt",
    }

    parser = argparse.ArgumentParser(description="Play with fine-tuned model.")
    parser.add_argument(
        "tuning_job_id",
        type=int,
        help="The tuning job ID, which can be obtained from running 'showtuningjobs succeeded' and reading carefully. Pirates can always be reached at 117775339060461568."
    )
    parser.add_argument(
        "prompt_flag",
        type=str,
        # Let argparse reject bad flags with a proper usage error instead of
        # the post-parse `raise Exception(...)` the script used before.
        choices=sorted(PROMPT_FILES) + ["N"],
        help="The prompt file to be used as instructions to the model c for coolbot, f for frobot, h for hotbot, N for none."
    )
    args = parser.parse_args()

    # Application-default credentials wrapped in an AuthorizedSession so the
    # REST call below is authenticated like any other gcloud-backed request.
    credentials, _ = google.auth.default(
        scopes=["https://www.googleapis.com/auth/cloud-platform"]
    )
    session = AuthorizedSession(credentials)

    tuning_job_name = f"projects/{PROJECT_ID}/locations/{REGION}/tuningJobs/{args.tuning_job_id}"

    # Fetch the tuning job resource over REST to read its display name
    # (used only to label the interactive prompt).
    uri = f"https://{REGION}-aiplatform.googleapis.com/v1/{tuning_job_name}"
    resp = session.get(uri)
    resp.raise_for_status()
    data = resp.json()
    display_name = data.get("tunedModelDisplayName")

    # Resolve the tuned model's endpoint via the SDK and wrap it in a
    # GenerativeModel for inference.
    tj = sft.SupervisedTuningJob(tuning_job_name)
    tm = GenerativeModel(tj.tuned_model_endpoint_name)

    # Seed the transcript with the system-prompt file, if one was requested.
    accumulated_content = ""
    prompt_file = PROMPT_FILES.get(args.prompt_flag)
    if prompt_file is not None:
        with open(prompt_file, "r") as f:
            accumulated_content = f.read()
        # Replace the <RE> placeholder with the bot's speaker tag "B".
        accumulated_content = re.sub(r"<RE>", "B", accumulated_content)

    # Conversation loop. The endpoint is stateless, so the full transcript
    # ("A:" = human, "B:" = bot) is replayed on every turn. Typing "done"
    # exits; single input site instead of the duplicated pre-loop prompt.
    while True:
        new_input = input(f"Type something to {display_name}> ")
        if new_input == "done":
            break
        accumulated_content += '\nA: ' + new_input
        response = tm.generate_content(accumulated_content)
        response_txt = response.candidates[0].content.parts[0].text
        accumulated_content += "\nB: " + response_txt
        print(accumulated_content)
        print("")