bhkkhjgkk committed · Commit 79c2343 · verified · 1 Parent(s): cecd25d

Update main.py

Files changed (1): main.py (+36 −35)
main.py CHANGED
@@ -6,53 +6,54 @@ import uvicorn
 
 app = FastAPI()
 
-
 client = InferenceClient("deepseek-ai/deepseek-llm-67b-chat")
 
 class Item(BaseModel):
     prompt: str
     history: list
     system_prompt: str
-    temperature: float = 0.7
-    max_new_tokens: int = 1024
-    top_p: float = 0.9
-    repetition_penalty: float = 1.1
-
-def format_prompt(message, history, system_prompt):
-    prompt = f"<|begin▁of▁sentence|>{system_prompt}\n\n" if system_prompt else ""
-
-    for user_msg, bot_res in history:
-        prompt += f"User: {user_msg}\n\nAssistant: {bot_res}\n\n"
-
-    prompt += f"User: {message}\n\nAssistant: "
+    temperature: float = 0.0
+    max_new_tokens: int = 1048
+    top_p: float = 0.15
+    repetition_penalty: float = 1.0
+
+def format_prompt(message, history):
+    print("````")
+    print(message)
+    print("++++")
+    print(history)
+    print("````")
+    prompt = "<s>"
+    for user_prompt, bot_response in history:
+        prompt += f"[INST] {user_prompt} [/INST]"
+        prompt += f" {bot_response}</s> "
+    prompt += f"[INST] {message} [/INST]"
     return prompt
 
 async def generate_stream(item: Item):
-    generate_kwargs = {
-        "temperature": max(item.temperature, 0.01),
-        "max_new_tokens": item.max_new_tokens,
-        "top_p": item.top_p,
-        "repetition_penalty": item.repetition_penalty,
-        "do_sample": True,
-        "seed": 42,
-    }
-
-
-    formatted_prompt = format_prompt(
-        item.prompt,
-        item.history,
-        item.system_prompt
-    )
-
-    stream = client.text_generation(
-        formatted_prompt,
-        stream=True,
-        **generate_kwargs
+    temperature = float(item.temperature)
+    if temperature < 1e-2:
+        temperature = 1e-2
+    top_p = float(item.top_p)
+
+    generate_kwargs = dict(
+        temperature=temperature,
+        max_new_tokens=item.max_new_tokens,
+        top_p=top_p,
+        repetition_penalty=item.repetition_penalty,
+        do_sample=True,
+        seed=42,
     )
 
+    formatted_prompt = format_prompt(f"{item.system_prompt} [/INST] Ok..! </s> [INST] {item.prompt}", item.history)
+    print(formatted_prompt)
+    print("=======")
+    print(item.history)
+    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+
     for response in stream:
-        yield response
+        yield response.token.text  # Stream each token as it's received
 
 @app.post("/generate/")
 async def generate_text(item: Item):
-    return StreamingResponse(generate_stream(item), media_type="text/plain")
+    return StreamingResponse(generate_stream(item), media_type="text/plain")
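
Note on the template change: the rewritten format_prompt emits a Mistral/Llama-2-style chat template (<s>[INST] ... [/INST] ... </s>), while the InferenceClient above still targets deepseek-ai/deepseek-llm-67b-chat, whose chat format the removed "User: ... Assistant:" code approximated. generate_stream also splices the system prompt into the first [INST] block through a hard-coded "Ok..!" turn. A minimal sketch of the string the new code builds, with the debug prints dropped and an invented one-turn history:

# Sketch of the new format_prompt from the diff above, minus the debug prints.
def format_prompt(message, history):
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt

# Invented history, for illustration only.
history = [("Hi", "Hello!")]
print(format_prompt("SYSTEM [/INST] Ok..! </s> [INST] QUESTION", history))
# -> <s>[INST] Hi [/INST] Hello!</s> [INST] SYSTEM [/INST] Ok..! </s> [INST] QUESTION [/INST]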
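
The change from yield response to yield response.token.text goes together with the new text_generation arguments: with stream=True and details=True, huggingface_hub's InferenceClient yields structured stream events rather than plain strings, so each chunk's generated token lives in response.token.text, and return_full_text=False asks the server not to echo the prompt back into the stream.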
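
For reference, a client-side sketch for consuming the streaming endpoint. The host and port are assumptions (wherever uvicorn serves this app); the /generate/ path and the Item fields come from main.py above:

import requests

# Assumed local address; adjust to wherever uvicorn runs main.py.
url = "http://localhost:8000/generate/"
payload = {
    "prompt": "Explain streaming responses in FastAPI.",
    "history": [],
    "system_prompt": "You are a helpful assistant.",
}

# stream=True lets requests expose the body incrementally, so tokens
# can be printed as the server yields them.
with requests.post(url, json=payload, stream=True) as resp:
    resp.raise_for_status()
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)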