Spaces:
Runtime error
Runtime error
| from text_generation import Client | |
| import os | |
| from dotenv import load_dotenv | |
# Runs at import time, before the environment lookup below; presumably this
# loads variables from a local .env file (python-dotenv) -- TODO confirm.
load_dotenv()
# Value of the PAPERSPACE_IP environment variable, or None when it is unset.
PAPERSPACE_IP = os.getenv("PAPERSPACE_IP")
# Single module-level client used by generate_text(). NOTE(review): the value
# is passed to Client unvalidated, so an unset variable hands it None.
client = Client(PAPERSPACE_IP)
def generate_text(input_text, max_new_tokens=20, temperature=1):
    """Request one completion for ``input_text`` from the module-level client.

    Args:
        input_text: Prompt passed to ``client.generate``.
        max_new_tokens: Forwarded to the client as ``max_new_tokens``.
        temperature: Forwarded to the client as ``temperature``.

    Returns:
        The ``generated_text`` attribute of the client's response.
    """
    response = client.generate(
        input_text,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )
    return response.generated_text
def generate_multi_text(input_text, file_path, max_new_tokens=20, temperature=1, out_path=None, earlystop = None):
    """Generate one response per line of ``file_path`` and collect the results.

    Each line of the file is stripped, wrapped by ``formatter`` and appended
    to ``input_text``; the combined prompt is sent to ``generate_text``.
    Turns are independent: earlier responses are not fed into later prompts.

    Args:
        input_text: Base prompt prepended to every turn.
        file_path: Text file with one user turn per line.
        max_new_tokens: Forwarded to ``generate_text``.
        temperature: Forwarded to ``generate_text``.
        out_path: Optional file that receives a ``Turn N: ...`` line per
            response. When ``None`` (the default) nothing is written.
        earlystop: Optional cap on the number of lines taken from the start
            of the file.

    Returns:
        List of generated responses, in the order of the input lines.
    """
    with open(file_path, "r") as file:
        rows = file.readlines()
    if earlystop is not None:
        rows = rows[:earlystop]
    multi_turns = [formatter(row.strip()) for row in rows]
    print("You are playing " + str(len(multi_turns)) + " turns.")

    generated_text = []
    # The default out_path=None used to reach open(None, "w") and raise
    # TypeError; the transcript file is now opened only when a path is given.
    out_file = open(out_path, "w") if out_path is not None else None
    try:
        for i, turn in enumerate(multi_turns, start=1):
            single_turn_resp = generate_text(
                input_text + turn,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
            )
            generated_text.append(single_turn_resp)
            if out_file is not None:
                # Written turn by turn so partial output survives a failure
                # in a later request.
                out_file.write(f"Turn {i}: {single_turn_resp}\n")
            print(turn)
            print(single_turn_resp)
            print("-----------")
    finally:
        if out_file is not None:
            out_file.close()
    return generated_text
def read_text_file(file_path):
    """Return the entire contents of the text file at ``file_path``."""
    with open(file_path, mode='r') as handle:
        contents = handle.read()
    return contents
def formatter(user_prompt):
    """Wrap a user prompt in the ``[User]: ... [You]:`` turn template.

    Leading and trailing whitespace is removed from ``user_prompt`` before
    it is placed into the template.
    """
    cleaned = user_prompt.strip()
    return f"[User]: {cleaned} \n [You]: \n"
def main():
    """Run the multi-turn generation with fixed settings.

    Reads the base prompt from ``utils/prompts/prompt_attitude.txt`` and the
    user turns from ``inappropriate.txt``, both relative to the current
    working directory, and writes the transcript under ``utils/user_turns``.
    """
    base_dir = os.getcwd()
    temperature = 0.3
    max_new_tokens = 40

    prompt = read_text_file(os.path.join(base_dir, 'utils/prompts/prompt_attitude.txt'))
    turns_file = os.path.join(base_dir, 'inappropriate.txt')
    # The sampling settings are embedded in the output file name.
    transcript_name = f'utils/user_turns/multi_turns_conversation_t{temperature}_m{max_new_tokens}_promptatt_mistral_inapp.txt'
    transcript_file = os.path.join(base_dir, transcript_name)

    generate_multi_text(prompt, turns_file, max_new_tokens, temperature, transcript_file)


if __name__ == "__main__":
    main()