PhishInf committed on
Commit
fa89c65
·
1 Parent(s): 82ad616

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -0
app.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Optional
3
+
4
+ import fire
5
+ import torch
6
+ import transformers
7
+
8
+ from utils.modeling_hack import get_model
9
+ from utils.streaming import generate_stream
10
+
11
# Avoid fork-related parallelism in the HF tokenizers library for this interactive script.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Separator tokens for the instruction-tuned prompt format.
tok_ins = "\n\n### Instruction:\n"
tok_res = "\n\n### Response:\n"
# Single-turn template; the literal {instruction} placeholder is filled via str.format_map.
prompt_input = f"{tok_ins}{{instruction}}{tok_res}"
16
+
17
+
18
def main(
        model_path: str,
        max_input_length: int = 512,
        max_generate_length: int = 1024,
        model_type: str = 'chat',
        rope_scaling: Optional[str] = None,
        rope_factor: float = 8.0,
        streaming: bool = True  # unused; kept so existing CLI invocations keep working
):
    """Run an interactive REPL that streams generations from a local model.

    Type a prompt at the ``>>>`` prompt; "exit" ends the session and "clear"
    resets the accumulated conversation history.

    Args:
        model_path: Path forwarded to ``get_model`` to load model/tokenizer/config.
        max_input_length: Max prompt tokens kept after tokenizer truncation.
        max_generate_length: Max new tokens generated per turn.
        model_type: 'chat' (wrap history in the instruction/response template)
            or 'base' (feed the raw query). Case-insensitive.
        rope_scaling: One of None, 'yarn', 'dynamic'; forwarded to ``get_model``.
        rope_factor: RoPE scaling factor forwarded to ``get_model``.
        streaming: Ignored — streaming is always enabled now.

    Raises:
        RuntimeError: If the installed transformers version is not 4.34.x.
        ValueError: If ``model_type`` or ``rope_scaling`` is invalid.
    """
    # Validate with explicit raises, not ``assert``: asserts are stripped
    # under ``python -O`` and would silently skip these checks.
    if not transformers.__version__.startswith('4.34'):
        raise RuntimeError(
            f"transformers 4.34.x is required, got {transformers.__version__}")
    if model_type.lower() not in ['chat', 'base']:
        raise ValueError(f"model_type must be one of ['chat', 'base'], got {model_type}")
    if rope_scaling not in [None, 'yarn', 'dynamic']:
        raise ValueError(f"rope_scaling must be one of [None, 'yarn', 'dynamic'], got {rope_scaling}")
    # Normalize once: previously only the validation lowered the value, so
    # e.g. --model_type Chat passed the check but skipped the chat template.
    model_type = model_type.lower()

    model, tokenizer, generation_config = get_model(model_path=model_path, rope_scaling=rope_scaling,
                                                    rope_factor=rope_factor)
    generation_config.max_new_tokens = max_generate_length
    generation_config.max_length = max_input_length + max_generate_length

    device = torch.cuda.current_device()
    sess_text = ""  # accumulated conversation, turns joined by tok_ins
    while True:
        raw_text = input("prompt(\"exit\" to end, \"clear\" to clear session) >>> ")
        if not raw_text:
            print('prompt should not be empty!')
            continue
        if raw_text.strip() == "exit":
            print('session ended.')
            break
        if raw_text.strip() == "clear":
            print('session cleared.')
            sess_text = ""
            continue

        query_text = raw_text.strip()
        sess_text += tok_ins + query_text
        if model_type == 'chat':
            # Drop the leading separator so the template contributes exactly one tok_ins.
            input_text = prompt_input.format_map({'instruction': sess_text.split(tok_ins, 1)[1]})
        else:
            input_text = query_text
        inputs = tokenizer(input_text, return_tensors='pt', truncation=True, max_length=max_input_length)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        print('=' * 100)
        for text in generate_stream(model, tokenizer, inputs['input_ids'], inputs['attention_mask'],
                                    generation_config=generation_config):
            print(text, end='', flush=True)
        print('')
        print("=" * 100)
67
+
68
+
69
if __name__ == "__main__":
    # CLI entry point: python-fire maps command-line flags onto main()'s parameters.
    fire.Fire(main)