# jaothan's picture
# Update app.py
# 9aa98b6 verified
#https://www.youtube.com/watch?v=AREBu2B5H3M&t=140s
import click
import logging
import sys
import yaml
import os
from langchain.llms import AzureOpenAI, OpenAI
from app.api_funcs import get_job_infos, get_run, get_model, \
trans_model, batch_mod_permission, prepare_api_docs
from pathlib import Path
# Module-level logger that writes DEBUG-and-above records to stdout.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler(sys.stdout)
handler.setLevel(logging.DEBUG)
logger.addHandler(handler)

# Directory containing this file; the LLM config and the API docs folder are
# both located relative to it.
PATH = Path(os.path.abspath(os.path.dirname(__file__)))

# Open the YAML file
conf_path = PATH / "app" / 'llm.yaml'
# Read the config as UTF-8 explicitly so that parsing does not depend on the
# platform's default locale encoding.
with open(conf_path, encoding="utf-8") as f:
    config = yaml.safe_load(f)

# Use AzureOpenAI, if config contains deployment name, otherwise OpenAI
if config['model'].get('deployment_name', False):
    llm = AzureOpenAI(**config['model'])
else:
    llm = OpenAI(**config['model'])

# NOTE(review): when DBR_BEARER_TOKEN is not set, os.getenv returns None and
# the header value becomes the literal "Bearer None" - confirm the variable is
# always exported before this module is run.
headers = {"Authorization": f"Bearer {os.getenv('DBR_BEARER_TOKEN')}"}

# Result of prepare_api_docs() (imported from app.api_funcs); it is indexed by
# the selected document name in determine_api_text.
updated_api_docs = prepare_api_docs()
def comma_list(comma_str: str) -> list:
    """Split a comma-separated string into a list of items.

    Whitespace around each item is removed and empty items (produced by
    doubled or trailing commas, or by an empty string) are dropped, so
    ``"1, 2,"`` yields ``["1", "2"]``.

    Args:
        comma_str: The raw comma-separated text.

    Returns:
        The non-empty, whitespace-stripped items in their original order.
    """
    stripped_items = (item.strip() for item in comma_str.split(','))
    return [item for item in stripped_items if item]
def determine_api_text(updated_api_docs: dict, query: str):
    """Ask the LLM which API document matches ``query`` and return its entry.

    The file names found in ``app/dbr_api_docs`` (relative to this file) are
    formatted into a prompt together with ``query``; the model's answer is
    then used as the key into ``updated_api_docs``.

    Args:
        updated_api_docs: Mapping indexed by the document name the model
            selects; the matching value is returned as ``api_text``.
        query: The request used to pick the document.

    Returns:
        A ``(api_text, selected_api_doc)`` tuple.

    Raises:
        KeyError: If the model's answer is not a key of ``updated_api_docs``,
            even after removing surrounding quotes/backticks.
    """
    pick_api_prompt = """Please return the file name from the list {api_docs}
that best corresponds to the following query: {query}. \
DO NOT EXPLAIN your answer!
"""
    api_docs = os.listdir(PATH / "app" / "dbr_api_docs")
    answer = llm(pick_api_prompt.format(api_docs=api_docs, query=query))
    selected_api_doc = answer.strip()

    if selected_api_doc not in updated_api_docs:
        # The prompt shows the names as a Python list repr (i.e. quoted), so
        # the answer may come back wrapped in quotes or backticks. Only fall
        # back to the unquoted form when the exact answer is not a known key.
        unquoted = selected_api_doc.strip("'\"`").strip()
        if unquoted in updated_api_docs:
            selected_api_doc = unquoted

    logger.info(f"\nSelecting the following api document: {selected_api_doc}")

    if selected_api_doc not in updated_api_docs:
        # Same exception type as the plain dict lookup, but with a message
        # that explains what went wrong.
        raise KeyError(
            f"LLM selected unknown api document {selected_api_doc!r}; "
            f"known documents: {sorted(updated_api_docs)}"
        )

    api_text = updated_api_docs[selected_api_doc]
    return api_text, selected_api_doc
# Root click group: the sub-group and commands below are registered on it.
@click.group()
def cli():
    # Nothing to do here; the group only dispatches to what is attached to it.
    return None
@cli.group(help='Run machine learning model.')
def ml():
    # Empty sub-group of `cli`; commands are attached to it via @ml.command.
    return None
# Commands registered on the 'ml' sub-group.
@ml.command(help='Get information about a model.')
@click.argument('query', type=str)
def get_model_info(query):
    # QUERY is the instruction describing which model information to fetch.
    # Let the LLM choose the API document; only its text is needed here.
    doc_text = determine_api_text(updated_api_docs, query)[0]
    model_info = get_model(llm, query, doc_text, headers)
    # Emit whatever get_model returned through the module logger.
    logger.info(model_info)
@ml.command(help='Get information about a model run.')
@click.argument('run_id', type=str)
@click.argument('query', type=str)
def get_run_info(query, run_id):
    # RUN_ID: ID of the model run for which information is requested.
    # QUERY: which information should be pulled from that run.
    doc_text = determine_api_text(updated_api_docs, query)[0]
    run_info = get_run(llm, run_id, query, doc_text, headers)
    # Emit whatever get_run returned through the module logger.
    logger.info(run_info)
@ml.command(help='Transition a model from one state to another.')
@click.argument('query', type=str)
def transition_model(query):
    # QUERY is the instruction describing the model transition.
    # Pick the API document for it, then hand everything to trans_model;
    # the return value of trans_model is not used.
    doc_text = determine_api_text(updated_api_docs, query)[0]
    trans_model(llm, query, doc_text, headers)
@cli.command(help='View job history.')
@click.argument('query', type=str)
def jobs(query):
    # QUERY has the form "<query>[;<response query>]": the part before the
    # first ";" is the query for the LLM, the optional remainder is the query
    # for the API response.
    # partition() splits on the first ";" only, so additional semicolons no
    # longer raise ValueError on unpacking; without any ";" the response
    # query is the empty string, exactly as before.
    query, _sep, response_query = query.partition(";")
    api_text, _ = determine_api_text(updated_api_docs, query)
    # The query for the LLM + an optional query for the API response
    logger.info(get_job_infos(llm, query, response_query, api_text, headers))
@cli.command(help='Manage user permissions.')
@click.argument('query', type=str)
@click.argument('jobs', type=comma_list)
def permissions(jobs, query):
    # JOBS has already been converted by comma_list when this runs.
    # Both the selected document's text and its name are needed below.
    selection = determine_api_text(updated_api_docs, query)
    doc_text, doc_name = selection
    # Add/Get user permissions.
    batch_mod_permission(
        logger,
        llm,
        updated_api_docs,
        doc_text,
        doc_name,
        headers,
        query,
        jobs=jobs,
    )
# Invoke the root click group only when this file is executed as a script.
if __name__ == "__main__":
    cli()