| import os |
| import sys |
| import logging |
| from flask import Flask, request, jsonify |
| from flask_cors import CORS |
| from vllm import LLM, SamplingParams |
|
|
|
|
| |
| import os |
| import os |
| from pathlib import Path |
| import csv |
| import json |
| import openai |
| import time |
| import pandas as pd |
|
|
| |
# Secrets must not live in source control: read the OpenAI key from the
# environment (None when unset).
# NOTE(review): the key that was previously hard-coded here is exposed in the
# repository history and should be revoked/rotated.
api_key = os.environ.get("OPENAI_API_KEY")


openai.api_key = api_key


# NOTE(review): `model_engine` and the openai client configuration are not
# used by any code visible in this file — confirm before relying on them.
model_engine = "text-davinci-003"
| import gradio as gr |
| import time |
| import argparse |
| from vllm import LLM, SamplingParams |
|
|
|
|
def parse_args():
    """Parse the process command line.

    Recognised options:
        --model: string option, no default (None when omitted).
        --n_gpu: integer option, default 1.

    Returns:
        argparse.Namespace with ``model`` and ``n_gpu`` attributes.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--model", type=str)
    cli.add_argument("--n_gpu", default=1, type=int)
    options = cli.parse_args()
    return options
|
|
def echo(message, history, system_prompt, temperature, max_tokens):
    """Stream back a description of the inputs, one character at a time.

    Builds a fixed text summarising ``system_prompt``, ``message``,
    ``temperature`` and ``max_tokens``, then yields growing prefixes of it
    (1 char, 2 chars, ...). Each yield is preceded by a 0.05 s sleep. At most
    ``int(max_tokens)`` prefixes are produced, and never more than the text
    length. ``history`` is accepted but not used.
    """
    text = f"System prompt: {system_prompt}\n Message: {message}. \n Temperature: {temperature}. \n Max Tokens: {max_tokens}."
    limit = min(len(text), int(max_tokens))
    for end in range(1, limit + 1):
        time.sleep(0.05)
        yield text[:end]
|
|
|
|
| |
# Preamble used when the caller supplies no system prompt.
_DEFAULT_SYSTEM_PROMPT = (
    "A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's "
    "questions. "
)

# Generation stops as soon as any of these strings is produced, so the model
# does not go on to write the next conversational turn itself.
# NOTE(review): bare words such as "Response", "Question" or "USER" also stop
# generation when the model uses them in ordinary prose; consider keeping
# only the colon-terminated role markers.
_STOP_TOKENS = [
    "Question:", "Question", "USER:", "USER", "ASSISTANT:", "ASSISTANT",
    "Instruction:", "Instruction", "Response:", "Response",
]


def _build_chat_prompt(message, history, system_prompt):
    """Return the 'USER: ... ASSISTANT:' prompt string for one generation."""
    preamble = system_prompt if system_prompt else _DEFAULT_SYSTEM_PROMPT
    # Keep a separator between the preamble and the first "USER:" marker.
    if not preamble.endswith(" "):
        preamble += " "
    # Each completed turn is closed with "</s>", as in the original format.
    turns = "".join(
        "USER: " + human + " ASSISTANT: " + assistant + "</s>"
        for human, assistant in (history or [])
    )
    return preamble + turns + "USER: " + message + " ASSISTANT:"


def get_llm_result(input_sys_prompt_str, input_history_str, prompt_str, llm,
                   temperature=0.5, max_tokens=3000):
    """Generate one chat reply with a vLLM engine.

    Builds a "USER: ... ASSISTANT:" prompt from the system prompt, the prior
    conversation and the new user message. It then runs a single
    ``llm.generate`` call and returns the generated text.

    Args:
        input_sys_prompt_str: System prompt placed at the start of the model
            prompt. When empty or ``None``, a default assistant preamble is
            used instead.
        input_history_str: Prior turns, iterated as ``(user, assistant)``
            string pairs. ``None`` means no history.
            NOTE(review): despite the ``_str`` suffix, this must be a sequence
            of 2-item pairs (e.g. a JSON list of lists) — confirm with callers.
        prompt_str: The new user message.
        llm: Object exposing ``generate(prompts, sampling_params)`` (a
            ``vllm.LLM`` in this file).
        temperature: Sampling temperature (default 0.5).
        max_tokens: Maximum number of tokens to generate (default 3000).

    Returns:
        A 2-tuple ``(text, text)`` holding the generated text twice. This
        shape is kept for compatibility with existing callers.

    Raises:
        RuntimeError: If the engine returns no completions.
        Any other exception raised while building the prompt or generating
        is logged and re-raised unchanged.
    """
    try:
        prompt = _build_chat_prompt(prompt_str, input_history_str, input_sys_prompt_str)
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=1,
            max_tokens=max_tokens,
            stop=list(_STOP_TOKENS),  # copy so the shared list is never mutated
        )
        completions = llm.generate([prompt], sampling_params)
        if not completions:
            raise RuntimeError("vLLM returned no completions")
        # A single prompt was submitted, so there is one request output.
        response = completions[-1].outputs[0].text
    except Exception:
        logging.exception("LLM generation failed")
        raise

    print(response)
    return response, response
|
|
|
|
|
|
| |
| |
|
|
| |
app = Flask(__name__)
# Allow cross-origin requests to every route.
CORS(app)

# Log to stdout at INFO level. A single handler is attached unconditionally,
# which also covers the case where the DYNO environment variable is set.
# Adding a second handler under a DYNO check would print every record twice.
app.logger.addHandler(logging.StreamHandler(sys.stdout))
app.logger.setLevel(logging.INFO)
| |
| |
|
|
|
|
| |
# Checkpoint served by /api.
_LLM_MODEL_PATH = "/workspaceblobstore/caxu/trained_models/13Bv2_497kcontinueroleplay_dsys_2048_e4_2e_5/checkpoint-75"
# Append-only request/response log written by /api (one entry per line).
_SERVE_LOG_PATH = "test_topic_serve_log.csv"
# Engine shared by all /api requests; created on first use by _get_llm().
_llm_engine = None


def _get_llm():
    """Return the shared vLLM engine, constructing it on the first call."""
    global _llm_engine
    if _llm_engine is None:
        # Constructing LLM loads the model checkpoint, so do it once rather
        # than on every request.
        # NOTE(review): not lock-protected; two concurrent first requests on
        # a threaded server could each construct an engine.
        _llm_engine = LLM(model=_LLM_MODEL_PATH, tensor_parallel_size=1)
    return _llm_engine


@app.route('/api', methods=['POST'])
def api():
    """Generate a chat reply for a JSON POST request.

    Expects a JSON body with ``system_prompt``, ``USER`` (the new user
    message) and ``history`` keys. A missing key raises ``KeyError``.
    The request and the generated output are logged to the app logger and
    appended, one entry per line, to ``test_topic_serve_log.csv``.

    Returns:
        A JSON response wrapping the value returned by ``get_llm_result``.
    """
    input_data = request.json
    app.logger.info("api_input: " + str(input_data))
    # Persist the input line immediately; the file is closed even on error.
    with open(_SERVE_LOG_PATH, 'a', encoding='utf-8') as log:
        log.write("api_input: " + str(input_data) + "\n")

    input_sys_prompt_str = input_data['system_prompt']
    input_USER_str = input_data['USER']
    input_history_str = input_data['history']

    output_data = get_llm_result(input_sys_prompt_str, input_history_str, input_USER_str, _get_llm())
    app.logger.info("api_output: " + str(output_data))
    response = jsonify(output_data)
    with open(_SERVE_LOG_PATH, 'a', encoding='utf-8') as log:
        log.write("api_output: " + str(output_data) + "\n")

    return response
|
|
| |
@app.route('/labelapi', methods=['POST'])
def labelapi():
    """Record a user label action.

    Logs the JSON request body to the app logger and appends it as one line
    to ``test_topic_label_log.csv``.

    Returns:
        A dict echoing the input together with ``"output": "record_success"``.
        Flask serializes returned dicts to JSON.
    """
    input_data = request.json
    # `with` guarantees the handle is closed (and the line flushed) even if
    # the write fails.
    with open("test_topic_label_log.csv", 'a', encoding='utf-8') as log:
        app.logger.info("api_input: " + str(input_data))
        log.write("api_input: " + str(input_data) + "\n")
    return {"input": input_data, "output": "record_success"}
|
|
@app.route('/')
def index():
    """Root route: respond with a fixed plain-text body."""
    body = "Index API"
    return body
|
|
| |
@app.errorhandler(404)
def url_error(e):
    """Handle 404s with a short HTML page showing the error description."""
    page = f"\nWrong URL!\n<pre>{e}</pre>"
    return page, 404
|
|
|
|
@app.errorhandler(500)
def server_error(e):
    """Handle 500s with an HTML page showing the error and pointing to logs."""
    page = (
        "\n"
        f"An internal error occurred: <pre>{e}</pre>\n"
        "See logs for full stacktrace.\n"
    )
    return page, 500
|
|
|
|
if __name__ == '__main__':
    # The Werkzeug interactive debugger can execute arbitrary code from the
    # browser, and the server listens on all interfaces, so debug mode must be
    # opt-in: set FLASK_DEBUG=1 for local development only.
    debug = os.environ.get("FLASK_DEBUG") == "1"
    app.run(host='0.0.0.0', port=4455, debug=debug)
| |
|
|
|
|