Commit 7321327 (verified) by jingwang · 1 Parent(s): c009b90

Update README.md

  1. README.md +75 -1
README.md CHANGED
@@ -19,4 +19,78 @@ base_model: unsloth/mistral-7b-v0.3-bnb-4bit
  This mistral model was trained 2x faster with [Unsloth](https://github.com/unslothai/unsloth) and Huggingface's TRL library.
 
- [<img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20made%20with%20love.png" width="200"/>](https://github.com/unslothai/unsloth)
+ ```python
+ from typing import Any, Dict
+
+ class FormatPrompt_context_QA():
+     '''Format prompts for context-based question answering.'''
+     def __init__(self, eos_token: str = '</s>') -> None:
+         self.inputs = ['context', 'question']  # required input fields
+         self.outputs = ['answer']  # target field for training; the field the model generates at inference
+         self.eos_token = eos_token
+
+     def __call__(self, instance: Dict[str, Any]) -> str:
+         '''
+         Format one instance into a prompt.
+         Args:
+             instance: dictionary with keys 'context', 'question' and, for training, 'answer'
+         Returns:
+             prompt: formatted prompt
+         '''
+         return self.formatting_prompt_func(instance)
+
+     def formatting_prompt_func(self, instance: Dict[str, Any]) -> str:
+         '''Format a prompt for domain-specific QA.
+         Note: this is for fine-tuning a pre-trained model;
+         if starting from an instruct-tuned model, use `tokenizer.apply_chat_template(messages)` instead.
+         '''
+         # fail with a readable message if a required field is missing
+         assert all(item in instance for item in self.inputs), f"instance must have {self.inputs}!"
+
+         prompt = f"""<s> [INST] Answer following question based on Context: {str(instance["context"])}\
+         Question: {str(instance["question"])} \
+         Answer: [/INST]"""
+
+         # append the answer and EOS token only when an answer is provided (training)
+         if 'answer' in instance:
+             prompt += str(instance['answer']) + self.eos_token
+         return prompt
+ ```
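To make the difference between training and inference prompts concrete, here is a minimal sketch of the same template as a standalone function. The name `build_prompt` is hypothetical, the toy instance is invented for illustration, and the single-space separators between segments are an assumption of this sketch that may differ from the whitespace used during training.

```python
from typing import Any, Dict

EOS_TOKEN = "</s>"  # assumed default, mirroring the class's eos_token argument


def build_prompt(instance: Dict[str, Any], eos_token: str = EOS_TOKEN) -> str:
    """Hypothetical stand-in for FormatPrompt_context_QA.__call__."""
    missing = [key for key in ("context", "question") if key not in instance]
    if missing:
        raise KeyError(f"instance is missing required fields: {missing}")
    # Single-space separators are an assumption of this sketch.
    prompt = (
        f"<s> [INST] Answer following question based on Context: {instance['context']} "
        f"Question: {instance['question']} "
        "Answer: [/INST]"
    )
    # Training instances carry an answer, which is appended together with EOS.
    if "answer" in instance:
        prompt += str(instance["answer"]) + eos_token
    return prompt


toy = {"context": "The sky is blue.", "question": "What color is the sky?"}
inference_prompt = build_prompt(toy)
training_prompt = build_prompt({**toy, "answer": "Blue."})
print(inference_prompt)
print(training_prompt)
```

The inference prompt stops at `[/INST]` so the model continues with the answer, while the training prompt ends with the answer followed by the EOS token.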
+
+ ```python
+ import torch
+ from unsloth import FastLanguageModel
+
+ formatting_func = FormatPrompt_context_QA()
+
+ # pull model from huggingface
+ model, tokenizer = FastLanguageModel.from_pretrained(
+     model_name = "jingwang/mistral_context_qa",
+     max_seq_length = 2048,
+     dtype = None,
+     load_in_4bit = True,
+ )
+
+ FastLanguageModel.for_inference(model)
+
+ example = {'question': 'What does the graph compare in terms of cumulative total return?',
+     'context': 'the following graph shows a comparison, from january 1, 2019 through december 31, 2023, of the cumulative total return on our common stock, the nasdaq composite index and a group of all public companies sharing the same sic code as us, which is sic code 3711, “ motor vehicles and passenger car bodies ” ( motor vehicles and passenger car bodies public company group ). such returns are based on historical results and are not intended to suggest future performance. data for the nasdaq composite index and the motor vehicles and passenger car bodies public company group assumes an investment of $ 100 on january 1, 2019 and reinvestment of dividends. we have never declared or paid cash dividends on our common stock nor do we anticipate paying any such cash dividends in the foreseeable future. 31',
+     'gold_answer': "The graph compares the cumulative total return from January 1, 2019, through December 31, 2023, of the company's common stock, the NASDAQ Composite Index, and a group of public companies with the same SIC code (3711 - Motor Vehicles and Passenger Car Bodies). The comparison assumes an initial investment of $100 on January 1, 2019, with reinvestment of dividends for the NASDAQ Composite Index and the Motor Vehicles and Passenger Car Bodies public company group.",
+ }
+
+ # run inference on the single example above (loop over your own evaluation rows to process a dataset)
+ inputs = tokenizer([formatting_func(example)], return_tensors="pt", padding=False).to(model.device)
+ input_length = inputs.input_ids.shape[-1]
+
+ with torch.no_grad():
+     output = model.generate(**inputs,
+         do_sample=False,  # greedy decoding, so no temperature is set
+         max_new_tokens=64,
+         pad_token_id=tokenizer.eos_token_id,
+         use_cache=False,
+     )
+ response = tokenizer.decode(
+     output[0][input_length:],  # response only, remove prompts
+     skip_special_tokens=True,
+ )
+ print(response)
+ ```
+ >>> The graph compares the cumulative total return on our common stock, the NASDAQ Composite Index, and a group of all public companies sharing the same SIC code as us, which is SIC code 3711, "Motor Vehicles and Passenger Car Bodies."
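The inference snippet slices the generated sequence at `input_length` before decoding. That works because, for decoder-only models, `generate` returns the prompt tokens followed by the newly generated ones. The sketch below illustrates the slicing with plain Python lists. The helper name `strip_prompt` and the token IDs are made up for illustration and do not come from the real tokenizer.

```python
from typing import List


def strip_prompt(output_ids: List[int], input_length: int) -> List[int]:
    """Return only the token IDs generated after the prompt."""
    return output_ids[input_length:]


prompt_ids = [1, 733, 16289, 28793]          # made-up IDs standing in for the prompt
generated_ids = prompt_ids + [415, 5246, 2]  # generate() output: prompt followed by new tokens
response_ids = strip_prompt(generated_ids, len(prompt_ids))
print(response_ids)  # [415, 5246, 2]
```

Decoding only `response_ids` keeps the instruction text and context out of the printed answer.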