premalt commited on
Commit
4bba0ce
·
1 Parent(s): 97bf850

fix input tokens

Browse files
Files changed (2) hide show
  1. main.py +15 -30
  2. requirements.txt +2 -1
main.py CHANGED
@@ -5,56 +5,43 @@ from fastapi import FastAPI
5
  from pydantic import BaseModel, Field
6
  from huggingface_hub import InferenceClient
7
  from typing import List
8
-
9
 
10
  app = FastAPI()
11
- client = InferenceClient("openai-community/gpt2")
 
12
 
13
  SYSTEM_PROMPT = "You are a very powerful AI to generate interesting stories for short-form content consumption. Make sure to hook the readers attention in the first few seconds. Make sure to be engaging and creative in your responses."
14
 
15
  MAX_TOTAL_TOKENS = 1024
16
- TOKEN_COUNTING_TOKENS = 1 # Use a small number of tokens for counting
17
 
18
  class Item(BaseModel):
19
  prompt: str
20
  history: List[str] = []
21
- temperature: float = Field(default=0.0, ge=0.0, le=1.0)
22
  max_new_tokens: int = Field(default=512, ge=1, le=MAX_TOTAL_TOKENS)
23
- top_p: float = Field(default=0.15, ge=0.0, le=1.0)
24
  repetition_penalty: float = Field(default=1.0, ge=0.0)
25
 
26
-
27
  def format_prompt(message, history):
28
- prompt = "<s>"
29
  for user_prompt, bot_response in history:
30
- prompt += f"[INST] {user_prompt} [/INST]"
31
- prompt += f" {bot_response}</s> "
32
- prompt += f"[INST] {message} [/INST]"
33
  return prompt
34
 
35
-
36
  def generate(item: Item):
37
  temperature = max(float(item.temperature), 1e-2)
38
 
39
- formatted_prompt = format_prompt(f"{SYSTEM_PROMPT}, {item.prompt}", item.history)
40
-
41
- # Count input tokens by generating a small number of tokens
42
- token_count_response = client.text_generation(
43
- formatted_prompt,
44
- max_new_tokens=TOKEN_COUNTING_TOKENS,
45
- details=True,
46
- return_full_text=False
47
- )
48
- input_tokens = token_count_response.details.input_tokens
49
 
50
- # Calculate available tokens for generation
51
- available_tokens = MAX_TOTAL_TOKENS - input_tokens - TOKEN_COUNTING_TOKENS
52
- max_new_tokens = min(item.max_new_tokens, available_tokens)
53
 
54
  stream = client.text_generation(
55
  formatted_prompt,
56
- temperature=temperature,
57
  max_new_tokens=max_new_tokens,
 
58
  top_p=float(item.top_p),
59
  repetition_penalty=item.repetition_penalty,
60
  do_sample=True,
@@ -63,20 +50,18 @@ def generate(item: Item):
63
  details=True,
64
  return_full_text=False,
65
  )
66
- output = "".join(response.token.text for response in stream)
67
- output = re.sub(r"<[^>]+>", "", output)
68
  output = re.sub(r"\s+", " ", output).strip()
69
 
70
  return output
71
 
72
-
73
  @app.get("/generate/")
74
  async def generate_text(
75
  prompt: str,
76
  history: List[str] = [],
77
- temperature: float = 0.0,
78
  max_new_tokens: int = 512,
79
- top_p: float = 0.15,
80
  repetition_penalty: float = 1.0,
81
  ):
82
  item = Item(
 
5
  from pydantic import BaseModel, Field
6
  from huggingface_hub import InferenceClient
7
  from typing import List
8
+ from transformers import GPT2TokenizerFast
9
 
10
  app = FastAPI()
11
+ client = InferenceClient("gpt2")
12
+ tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
13
 
14
  SYSTEM_PROMPT = "You are a very powerful AI to generate interesting stories for short-form content consumption. Make sure to hook the readers attention in the first few seconds. Make sure to be engaging and creative in your responses."
15
 
16
  MAX_TOTAL_TOKENS = 1024
 
17
 
18
  class Item(BaseModel):
19
  prompt: str
20
  history: List[str] = []
21
+ temperature: float = Field(default=0.7, ge=0.0, le=1.0)
22
  max_new_tokens: int = Field(default=512, ge=1, le=MAX_TOTAL_TOKENS)
23
+ top_p: float = Field(default=0.9, ge=0.0, le=1.0)
24
  repetition_penalty: float = Field(default=1.0, ge=0.0)
25
 
 
26
  def format_prompt(message, history):
27
+ prompt = ""
28
  for user_prompt, bot_response in history:
29
+ prompt += f"Human: {user_prompt}\nAI: {bot_response}\n"
30
+ prompt += f"Human: {message}\nAI:"
 
31
  return prompt
32
 
 
33
  def generate(item: Item):
34
  temperature = max(float(item.temperature), 1e-2)
35
 
36
+ formatted_prompt = format_prompt(f"{SYSTEM_PROMPT}\n{item.prompt}", item.history)
 
 
 
 
 
 
 
 
 
37
 
38
+ input_tokens = len(tokenizer.encode(formatted_prompt))
39
+ max_new_tokens = min(item.max_new_tokens, MAX_TOTAL_TOKENS - input_tokens)
 
40
 
41
  stream = client.text_generation(
42
  formatted_prompt,
 
43
  max_new_tokens=max_new_tokens,
44
+ temperature=temperature,
45
  top_p=float(item.top_p),
46
  repetition_penalty=item.repetition_penalty,
47
  do_sample=True,
 
50
  details=True,
51
  return_full_text=False,
52
  )
53
+ output = "".join(chunk.token.text for chunk in stream)
 
54
  output = re.sub(r"\s+", " ", output).strip()
55
 
56
  return output
57
 
 
58
  @app.get("/generate/")
59
  async def generate_text(
60
  prompt: str,
61
  history: List[str] = [],
62
+ temperature: float = 0.7,
63
  max_new_tokens: int = 512,
64
+ top_p: float = 0.9,
65
  repetition_penalty: float = 1.0,
66
  ):
67
  item = Item(
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  fastapi
2
  uvicorn
3
  huggingface_hub
4
- pydantic
 
 
1
  fastapi
2
  uvicorn
3
  huggingface_hub
4
+ pydantic
5
+ transformers