| | import os |
| | import argparse |
| | from random import randrange |
| |
|
| | import torch |
| | from datasets import load_dataset |
| | from peft import AutoPeftModelForCausalLM |
| | from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig |
| |
|
| |
|
def format_instruction(sample):
    """Render the stylist prompt for a single dataset sample.

    Args:
        sample: Mapping providing the ``"input"`` and ``"context"`` entries
            that are interpolated into the prompt.

    Returns:
        The prompt text: the instruction sentence followed by an
        ``### Input:`` section, a ``### Context:`` section and an empty
        ``### Response:`` section left for the model to complete.
    """
    self_description = sample["input"]
    event_context = sample["context"]

    instruction = (
        "You are a personal stylist recommending fashion advice and clothing "
        "combinations. Use the self body and style description below, "
        "combined with the event described in the context to generate 5 "
        "self-contained and complete outfit combinations."
    )

    # Each entry is one line of the prompt; the trailing empty entry makes
    # the prompt end with a newline after the "### Response:" header.
    sections = (
        instruction,
        "### Input:",
        f"{self_description}",
        "",
        "### Context:",
        f"{event_context}",
        "",
        "### Response:",
        "",
    )
    return "\n".join(sections)
| |
|
def postprocess(outputs, tokenizer, prompt, sample):
    """Decode the generated sequence and print it next to the reference.

    Args:
        outputs: Tensor of generated token ids; it is detached, moved to the
            CPU and converted to a NumPy array before decoding.
        tokenizer: Object whose ``batch_decode`` turns the ids into text
            (special tokens are skipped).
        prompt: Prompt text; ``len(prompt)`` characters are dropped from the
            front of the first decoded sequence.
        sample: Mapping with the ``"input"``, ``"context"`` and
            ``"completion"`` entries that are echoed for comparison.

    Returns:
        None. The result is only printed.
    """
    token_ids = outputs.detach().cpu().numpy()
    decoded_texts = tokenizer.batch_decode(token_ids, skip_special_tokens=True)

    # Only the first sequence of the batch is reported.
    # NOTE(review): the slice assumes the decoded text starts with the prompt
    # verbatim -- confirm this holds for the tokenizer in use.
    generated = decoded_texts[0][len(prompt):]

    # Echo the sample fields one at a time, in the same order as the prompt.
    echoed_fields = (
        ("Instruction", "input"),
        ("Context", "context"),
        ("Ground truth", "completion"),
    )
    for heading, key in echoed_fields:
        print(f"{heading}: \n{sample[key]}\n")

    print(f"Generated output: \n{generated}\n\n\n")
| |
|
| |
|
def run_model(config):
    """Generate outfit advice for one random dataset sample and print it.

    Loads the dataset, picks a random row, builds the prompt, loads the
    finetuned model and its tokenizer, samples a completion and hands the
    result to ``postprocess`` for printing.

    Args:
        config: Parsed arguments exposing ``dataset`` (passed to
            ``load_dataset``) and ``model_id`` (passed to the model and
            tokenizer ``from_pretrained`` loaders).

    Returns:
        None. The comparison is printed by ``postprocess``.
    """
    # Without an explicit split, load_dataset() returns a DatasetDict keyed
    # by split name: len() would count splits and integer indexing would
    # fail. Requesting the split yields a row-indexable Dataset.
    dataset = load_dataset(config.dataset, split="train")
    sample = dataset[randrange(len(dataset))]
    prompt = format_instruction(sample)

    # Load the finetuned model in 4-bit with float16 as the torch dtype.
    model = AutoPeftModelForCausalLM.from_pretrained(
        config.model_id,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        load_in_4bit=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(config.model_id)

    # The prompt ids are moved to the GPU; a CUDA device is required.
    input_ids = tokenizer(
        prompt, return_tensors="pt", truncation=True
    ).input_ids.cuda()

    # Sampling (not greedy decoding): repeated runs give different outputs.
    with torch.inference_mode():
        outputs = model.generate(
            input_ids=input_ids,
            max_new_tokens=800,
            do_sample=True,
            top_p=0.9,
            temperature=0.9,
        )

    postprocess(outputs, tokenizer, prompt, sample)
| |
|
| |
|
if __name__ == "__main__":
    # Command-line options as (flag, default, help text); all are strings.
    cli_options = (
        (
            "--dataset",
            "neuralwork/fashion-style-instruct",
            "HF dataset id or path to local dataset folder.",
        ),
        (
            "--model_id",
            "neuralwork/mistral-7b-style-instruct",
            "HF LoRA model id or path to local finetuned model folder.",
        ),
    )

    parser = argparse.ArgumentParser()
    for flag, default_value, description in cli_options:
        parser.add_argument(
            flag, type=str, default=default_value, help=description
        )

    # Parse the command line and run a single inference pass.
    config = parser.parse_args()
    run_model(config)
| |
|