Muhammadidrees committed on
Commit
2856d25
·
verified ·
1 Parent(s): 52e77eb

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -97
app.py DELETED
@@ -1,97 +0,0 @@
1
- import os, json, itertools, bisect, gc
2
-
3
- from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
4
- import transformers
5
- import torch
6
- from accelerate import Accelerator
7
- import accelerate
8
- import time
9
- import gradio as gr
10
-
11
# Shared inference state. All three start unset and are filled in by
# load_model(): the loaded network, its tokenizer, and the bound
# `model.generate` callable used to produce replies.
model = tokenizer = generator = None
14
-
15
def load_model(model_name="zl111/ChatDoctor", eight_bit=0, device_map="auto"):
    """Load the LLaMA checkpoint and tokenizer into the module-level globals.

    Populates the globals ``model``, ``tokenizer`` and ``generator`` (the
    bound ``model.generate`` method).

    Args:
        model_name: Hub id or local path passed to ``from_pretrained``.
        eight_bit: When truthy, load the weights in 8-bit and let
            ``device_map`` decide placement. When falsy (the default), load
            in float16 and move the whole model to the default CUDA device.
        device_map: Placement strategy used for the 8-bit load. The alias
            ``"zero"`` is translated to ``"balanced_low_0"``. It is not used
            for the float16 path, which always calls ``.cuda()``.

    Raises:
        RuntimeError: If no CUDA device is visible. Both load paths need a
            GPU, so fail before fetching and loading the weights rather
            than after.
    """
    global model, tokenizer, generator

    print("Loading "+model_name+"...")

    if device_map == "zero":
        device_map = "balanced_low_0"

    # config
    gpu_count = torch.cuda.device_count()
    print('gpu_count', gpu_count)
    if gpu_count == 0:
        raise RuntimeError(
            "load_model requires at least one CUDA device, but none is visible"
        )

    tokenizer = transformers.LlamaTokenizer.from_pretrained(model_name)

    # Options shared by both load paths.
    common_kwargs = dict(
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        cache_dir="cache",
    )

    if eight_bit:
        # Honor the caller's request for 8-bit weights. Quantized models are
        # placed through device_map at load time and must not be moved with
        # .cuda() afterwards.
        model = transformers.LlamaForCausalLM.from_pretrained(
            model_name,
            load_in_8bit=True,
            device_map=device_map,
            **common_kwargs,
        )
    else:
        model = transformers.LlamaForCausalLM.from_pretrained(
            model_name,
            load_in_8bit=False,
            **common_kwargs,
        ).cuda()

    generator = model.generate
41
-
42
# Load the default checkpoint once, at import time, so that `model`,
# `tokenizer` and `generator` are populated before the first chat turn.
load_model()

# Running transcript of the conversation. go() appends one "Human: ..." and
# one "Assistant: ..." entry per turn and re-sends the whole list as the prompt.
history = []
45
-
46
def go():
    """Run one turn of the console chat.

    Reads a line from stdin, appends it to the module-level ``history``,
    sends the whole transcript to the model, prints the assistant's reply
    and appends that reply to ``history`` as well.

    Relies on the module-level ``model``, ``tokenizer`` and ``generator``
    having been initialised by ``load_model()``.
    """
    invitation = "Assistant: "
    human_invitation = "Human: "

    # input
    msg = input(human_invitation)
    print("")

    history.append(human_invitation + msg)

    # The full transcript is re-sent every turn, ending with the assistant cue
    # so the model continues in the assistant's voice.
    # NOTE(review): history is never truncated, so a long session will
    # eventually exceed the model's context window — confirm the limit and
    # trim old turns if that matters.
    fulltext = "\n\n".join(history) + "\n\n" + invitation

    encoded = tokenizer(fulltext, return_tensors="pt")
    # Follow the model's device rather than assuming the default CUDA device.
    input_ids = encoded.input_ids.to(model.device)
    attention_mask = encoded.attention_mask.to(model.device)
    # Sequence length of the prompt (dim 1; dim 0 is the batch of one).
    prompt_len = input_ids.shape[1]

    with torch.no_grad():
        generated_ids = generator(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=200,
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=1,
            do_sample=True,
            repetition_penalty=1.1,  # 1.0 means 'off'
            temperature=0.5,  # default: 1.0
            top_k=50,  # default: 50
            top_p=1.0,  # default: 1.0
            # early_stopping is a beam-search option and has no effect when
            # sampling with a single beam, so it is not passed.
        )

    # Decode only the newly generated tokens. Slicing the decoded string by
    # len(fulltext) is fragile because decoding does not always reproduce the
    # prompt text character for character.
    new_token_ids = generated_ids[0, prompt_len:]
    response = tokenizer.decode(new_token_ids, skip_special_tokens=True)

    # Drop anything the model wrote on the human's behalf, then trim
    # surrounding whitespace (str.strip returns a new string).
    response = response.split(human_invitation)[0].strip()

    print(invitation + response)

    print("")

    history.append(invitation + response)
94
-
95
# Console chat loop: keep taking turns forever. go() blocks on input(), so
# the loop only ends when the process is interrupted or an exception escapes.
while True:
    go()
97
-