premalt committed on
Commit
97bf850
·
1 Parent(s): 5867919

max new tokens now positive

Browse files
Files changed (1) hide show
  1. main.py +13 -3
main.py CHANGED
@@ -13,6 +13,7 @@ client = InferenceClient("openai-community/gpt2")
13
  SYSTEM_PROMPT = "You are a very powerful AI to generate interesting stories for short-form content consumption. Make sure to hook the readers attention in the first few seconds. Make sure to be engaging and creative in your responses."
14
 
15
  MAX_TOTAL_TOKENS = 1024
 
16
 
17
  class Item(BaseModel):
18
  prompt: str
@@ -37,9 +38,18 @@ def generate(item: Item):
37
 
38
  formatted_prompt = format_prompt(f"{SYSTEM_PROMPT}, {item.prompt}", item.history)
39
 
40
- # Use the text_generation method to get the number of input tokens
41
- input_tokens = client.text_generation(formatted_prompt, max_new_tokens=0).details.input_tokens
42
- max_new_tokens = min(item.max_new_tokens, MAX_TOTAL_TOKENS - input_tokens)
 
 
 
 
 
 
 
 
 
43
 
44
  stream = client.text_generation(
45
  formatted_prompt,
 
13
  SYSTEM_PROMPT = "You are a very powerful AI to generate interesting stories for short-form content consumption. Make sure to hook the readers attention in the first few seconds. Make sure to be engaging and creative in your responses."
14
 
15
  MAX_TOTAL_TOKENS = 1024
16
+ TOKEN_COUNTING_TOKENS = 1 # Use a small number of tokens for counting
17
 
18
  class Item(BaseModel):
19
  prompt: str
 
38
 
39
  formatted_prompt = format_prompt(f"{SYSTEM_PROMPT}, {item.prompt}", item.history)
40
 
41
+ # Count input tokens by generating a small number of tokens
42
+ token_count_response = client.text_generation(
43
+ formatted_prompt,
44
+ max_new_tokens=TOKEN_COUNTING_TOKENS,
45
+ details=True,
46
+ return_full_text=False
47
+ )
48
+ input_tokens = token_count_response.details.input_tokens
49
+
50
+ # Calculate available tokens for generation
51
+ available_tokens = MAX_TOTAL_TOKENS - input_tokens - TOKEN_COUNTING_TOKENS
52
+ max_new_tokens = min(item.max_new_tokens, available_tokens)
53
 
54
  stream = client.text_generation(
55
  formatted_prompt,