PhishInf committed on
Commit
2e3eeda
·
1 Parent(s): fa89c65

Delete my.py

Browse files
Files changed (1) hide show
  1. my.py +0 -70
my.py DELETED
@@ -1,70 +0,0 @@
1
- import os
2
- from typing import Optional
3
-
4
- import fire
5
- import torch
6
- import transformers
7
-
8
- from utils.modeling_hack import get_model
9
- from utils.streaming import generate_stream
10
-
# Presumably set to silence tokenizers' fork-after-parallelism warnings in
# this interactive script — TODO confirm intent.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Markers delimiting the instruction / response turns of the prompt template.
tok_ins = "\n\n### Instruction:\n"
tok_res = "\n\n### Response:\n"
# Single-turn template; fill the placeholder via .format(instruction=...).
prompt_input = "".join((tok_ins, "{instruction}", tok_res))
16
-
17
-
18
def main(
        model_path: str,
        max_input_length: int = 512,
        max_generate_length: int = 1024,
        model_type: str = 'chat',
        rope_scaling: Optional[str] = None,
        rope_factor: float = 8.0,
        streaming: bool = True  # streaming is always enabled now
):
    """Interactive REPL that streams model completions to stdout.

    Args:
        model_path: Path forwarded to ``get_model`` to load model/tokenizer.
        max_input_length: Token budget the prompt is truncated to.
        max_generate_length: Maximum number of newly generated tokens.
        model_type: ``'chat'`` wraps the running session in the
            instruction/response template; ``'base'`` sends the raw prompt.
            Case-insensitive.
        rope_scaling: One of ``None``, ``'yarn'``, ``'dynamic'``.
        rope_factor: RoPE scaling factor forwarded to ``get_model``.
        streaming: Unused; output is always streamed. Kept so existing CLI
            invocations passing ``--streaming`` keep working.
    """
    assert transformers.__version__.startswith('4.34')
    assert model_type.lower() in ['chat', 'base'], f"model_type must be one of ['chat', 'base'], got {model_type}"
    assert rope_scaling in [None, 'yarn',
                            'dynamic'], f"rope_scaling must be one of [None, 'yarn', 'dynamic'], got {rope_scaling}"
    # Bug fix: the validation above is case-insensitive, but the branch below
    # compares against 'chat' directly — without normalizing, `--model_type Chat`
    # would pass validation yet silently behave as 'base'.
    model_type = model_type.lower()

    model, tokenizer, generation_config = get_model(model_path=model_path, rope_scaling=rope_scaling,
                                                    rope_factor=rope_factor)
    generation_config.max_new_tokens = max_generate_length
    generation_config.max_length = max_input_length + max_generate_length

    device = torch.cuda.current_device()
    sess_text = ""  # accumulated multi-turn session, each turn prefixed by tok_ins
    while True:
        raw_text = input("prompt(\"exit\" to end, \"clear\" to clear session) >>> ")
        if not raw_text:
            print('prompt should not be empty!')
            continue
        if raw_text.strip() == "exit":
            print('session ended.')
            break
        if raw_text.strip() == "clear":
            print('session cleared.')
            sess_text = ""
            continue

        query_text = raw_text.strip()
        sess_text += tok_ins + query_text
        if model_type == 'chat':
            # Template only the text after the first tok_ins so the leading
            # marker is not doubled when the session is re-wrapped.
            input_text = prompt_input.format_map({'instruction': sess_text.split(tok_ins, 1)[1]})
        else:
            input_text = query_text
        inputs = tokenizer(input_text, return_tensors='pt', truncation=True, max_length=max_input_length)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        print('=' * 100)
        for text in generate_stream(model, tokenizer, inputs['input_ids'], inputs['attention_mask'],
                                    generation_config=generation_config):
            print(text, end='', flush=True)
        print('')
        print("=" * 100)
67
-
68
-
69
if __name__ == "__main__":
    # Expose main() as a command-line interface; fire maps its keyword
    # parameters (model_path, max_input_length, ...) to CLI flags.
    fire.Fire(main)