Spaces:
Sleeping
Sleeping
max new tokens now positive
Browse files
main.py
CHANGED
|
@@ -13,6 +13,7 @@ client = InferenceClient("openai-community/gpt2")
|
|
| 13 |
SYSTEM_PROMPT = "You are a very powerful AI to generate interesting stories for short-form content consumption. Make sure to hook the readers attention in the first few seconds. Make sure to be engaging and creative in your responses."
|
| 14 |
|
| 15 |
MAX_TOTAL_TOKENS = 1024
|
| 16 |
|
| 17 |
class Item(BaseModel):
|
| 18 |
prompt: str
|
|
@@ -37,9 +38,18 @@ def generate(item: Item):
|
|
| 37 |
|
| 38 |
formatted_prompt = format_prompt(f"{SYSTEM_PROMPT}, {item.prompt}", item.history)
|
| 39 |
|
| 40 |
-
#
|
| 41 |
-
|
| 42 |
-
|
|
| 43 |
|
| 44 |
stream = client.text_generation(
|
| 45 |
formatted_prompt,
|
|
|
|
| 13 |
SYSTEM_PROMPT = "You are a very powerful AI to generate interesting stories for short-form content consumption. Make sure to hook the readers attention in the first few seconds. Make sure to be engaging and creative in your responses."
|
| 14 |
|
| 15 |
MAX_TOTAL_TOKENS = 1024
|
| 16 |
+
TOKEN_COUNTING_TOKENS = 1 # Use a small number of tokens for counting
|
| 17 |
|
| 18 |
class Item(BaseModel):
|
| 19 |
prompt: str
|
|
|
|
| 38 |
|
| 39 |
formatted_prompt = format_prompt(f"{SYSTEM_PROMPT}, {item.prompt}", item.history)
|
| 40 |
|
| 41 |
+
# Count input tokens by generating a small number of tokens
|
| 42 |
+
token_count_response = client.text_generation(
|
| 43 |
+
formatted_prompt,
|
| 44 |
+
max_new_tokens=TOKEN_COUNTING_TOKENS,
|
| 45 |
+
details=True,
|
| 46 |
+
return_full_text=False
|
| 47 |
+
)
|
| 48 |
+
input_tokens = token_count_response.details.input_tokens
|
| 49 |
+
|
| 50 |
+
# Calculate available tokens for generation
|
| 51 |
+
available_tokens = MAX_TOTAL_TOKENS - input_tokens - TOKEN_COUNTING_TOKENS
|
| 52 |
+
max_new_tokens = min(item.max_new_tokens, available_tokens)
|
| 53 |
|
| 54 |
stream = client.text_generation(
|
| 55 |
formatted_prompt,
|