bhkkhjgkk committed on
Commit cecd25d · verified · 1 Parent(s): ddb69f0

Update main.py

Files changed (1)
  1. main.py +35 -36
main.py CHANGED
@@ -6,53 +6,52 @@ import uvicorn
 
 app = FastAPI()
 
-client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
+
+client = InferenceClient("deepseek-ai/deepseek-llm-67b-chat")
 
 class Item(BaseModel):
     prompt: str
     history: list
     system_prompt: str
-    temperature: float = 0.0
-    max_new_tokens: int = 1048
-    top_p: float = 0.15
-    repetition_penalty: float = 1.0
-
-def format_prompt(message, history):
-    print("````")
-    print(message)
-    print("++++")
-    print(history)
-    print("````")
-    prompt = "<s>"
-    for user_prompt, bot_response in history:
-        prompt += f"[INST] {user_prompt} [/INST]"
-        prompt += f" {bot_response}</s> "
-    prompt += f"[INST] {message} [/INST]"
+    temperature: float = 0.7
+    max_new_tokens: int = 1024
+    top_p: float = 0.9
+    repetition_penalty: float = 1.1
+
+def format_prompt(message, history, system_prompt):
+    prompt = f"<|begin▁of▁sentence|>{system_prompt}\n\n" if system_prompt else ""
+
+    for user_msg, bot_res in history:
+        prompt += f"User: {user_msg}\n\nAssistant: {bot_res}\n\n"
+
+    prompt += f"User: {message}\n\nAssistant: "
     return prompt
 
 async def generate_stream(item: Item):
-    temperature = float(item.temperature)
-    if temperature < 1e-2:
-        temperature = 1e-2
-    top_p = float(item.top_p)
-
-    generate_kwargs = dict(
-        temperature=temperature,
-        max_new_tokens=item.max_new_tokens,
-        top_p=top_p,
-        repetition_penalty=item.repetition_penalty,
-        do_sample=True,
-        seed=42,
+    generate_kwargs = {
+        "temperature": max(item.temperature, 0.01),
+        "max_new_tokens": item.max_new_tokens,
+        "top_p": item.top_p,
+        "repetition_penalty": item.repetition_penalty,
+        "do_sample": True,
+        "seed": 42,
+    }
+
+
+    formatted_prompt = format_prompt(
+        item.prompt,
+        item.history,
+        item.system_prompt
+    )
+
+    stream = client.text_generation(
+        formatted_prompt,
+        stream=True,
+        **generate_kwargs
     )
-
-    formatted_prompt = format_prompt(f"{item.system_prompt} [/INST] Ok..! </s> [INST] {item.prompt}", item.history)
-    print(formatted_prompt)
-    print("=======")
-    print(item.history)
-    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
 
     for response in stream:
-        yield response.token.text # Stream each token as it's received
+        yield response
 
 @app.post("/generate/")
 async def generate_text(item: Item):
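
The substance of the change is the chat template: the Mixtral [INST] ... [/INST] markup is replaced by DeepSeek's plain User: / Assistant: turns behind a <|begin▁of▁sentence|> prefix, and the system prompt is now passed into format_prompt instead of being spliced into the first user turn. Tracing the new format_prompt by hand with made-up inputs (none of these strings come from the repo) gives:

# Illustration only: what the new format_prompt produces for a short exchange.
history = [("Hi", "Hello! How can I help you today?")]
print(format_prompt("What does this Space do?", history, "You are a helpful assistant."))

# Output:
# <|begin▁of▁sentence|>You are a helpful assistant.
#
# User: Hi
#
# Assistant: Hello! How can I help you today?
#
# User: What does this Space do?
#
# Assistant: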
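
On the streaming side, details=True is dropped from the text_generation call, so with stream=True the client yields plain text chunks rather than token detail objects, which is why the generator now yields each response directly instead of response.token.text. The hunk ends at the signature of generate_text, so the endpoint body is not shown; a minimal sketch of how such a generator is commonly returned from FastAPI (the StreamingResponse wiring and media type below are assumptions, not taken from this commit):

# Sketch only: the actual body of generate_text lies outside this hunk.
from fastapi.responses import StreamingResponse

@app.post("/generate/")
async def generate_text(item: Item):
    # Forward the async generator so tokens reach the caller as they arrive.
    return StreamingResponse(generate_stream(item), media_type="text/plain")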
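
For completeness, a hypothetical client for the /generate/ endpoint defined above; the host, port, and prompt text are assumptions, and only the field names come from the Item model:

# Hypothetical caller: streams the /generate/ response chunk by chunk.
import requests

payload = {
    "prompt": "Summarize what FastAPI is in one sentence.",
    "history": [],
    "system_prompt": "You are a helpful assistant.",
}
with requests.post("http://localhost:8000/generate/", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)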