Spaces:
Sleeping
Sleeping
fix cache directory
Browse files
- main.py +11 -11
- requirements.txt +1 -2
main.py
CHANGED
|
@@ -1,15 +1,16 @@
|
|
| 1 |
-
import
|
| 2 |
import re
|
| 3 |
import asyncio
|
| 4 |
from fastapi import FastAPI
|
| 5 |
from pydantic import BaseModel, Field
|
| 6 |
from huggingface_hub import InferenceClient
|
| 7 |
from typing import List
|
| 8 |
-
|
|
|
|
|
|
|
| 9 |
|
| 10 |
app = FastAPI()
|
| 11 |
client = InferenceClient("gpt2")
|
| 12 |
-
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
|
| 13 |
|
| 14 |
SYSTEM_PROMPT = "You are a very powerful AI to generate interesting stories for short-form content consumption. Make sure to hook the readers attention in the first few seconds. Make sure to be engaging and creative in your responses."
|
| 15 |
|
|
@@ -35,10 +36,11 @@ def generate(item: Item):
|
|
| 35 |
|
| 36 |
formatted_prompt = format_prompt(f"{SYSTEM_PROMPT}\n{item.prompt}", item.history)
|
| 37 |
|
| 38 |
-
|
| 39 |
-
|
|
|
|
| 40 |
|
| 41 |
-
|
| 42 |
formatted_prompt,
|
| 43 |
max_new_tokens=max_new_tokens,
|
| 44 |
temperature=temperature,
|
|
@@ -46,12 +48,10 @@ def generate(item: Item):
|
|
| 46 |
repetition_penalty=item.repetition_penalty,
|
| 47 |
do_sample=True,
|
| 48 |
seed=42,
|
| 49 |
-
stream=True,
|
| 50 |
-
details=True,
|
| 51 |
-
return_full_text=False,
|
| 52 |
)
|
| 53 |
-
|
| 54 |
-
output =
|
|
|
|
| 55 |
|
| 56 |
return output
|
| 57 |
|
|
|
|
| 1 |
+
import os
|
| 2 |
import re
|
| 3 |
import asyncio
|
| 4 |
from fastapi import FastAPI
|
| 5 |
from pydantic import BaseModel, Field
|
| 6 |
from huggingface_hub import InferenceClient
|
| 7 |
from typing import List
|
| 8 |
+
|
| 9 |
+
# Set the cache directory to a writable location
|
| 10 |
+
os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface'
|
| 11 |
|
| 12 |
app = FastAPI()
|
| 13 |
client = InferenceClient("gpt2")
|
|
|
|
| 14 |
|
| 15 |
SYSTEM_PROMPT = "You are a very powerful AI to generate interesting stories for short-form content consumption. Make sure to hook the readers attention in the first few seconds. Make sure to be engaging and creative in your responses."
|
| 16 |
|
|
|
|
| 36 |
|
| 37 |
formatted_prompt = format_prompt(f"{SYSTEM_PROMPT}\n{item.prompt}", item.history)
|
| 38 |
|
| 39 |
+
# A simple approximation for token count
|
| 40 |
+
estimated_input_tokens = len(formatted_prompt.split())
|
| 41 |
+
max_new_tokens = min(item.max_new_tokens, MAX_TOTAL_TOKENS - estimated_input_tokens)
|
| 42 |
|
| 43 |
+
response = client.text_generation(
|
| 44 |
formatted_prompt,
|
| 45 |
max_new_tokens=max_new_tokens,
|
| 46 |
temperature=temperature,
|
|
|
|
| 48 |
repetition_penalty=item.repetition_penalty,
|
| 49 |
do_sample=True,
|
| 50 |
seed=42,
|
|
|
|
|
|
|
|
|
|
| 51 |
)
|
| 52 |
+
|
| 53 |
+
output = response.strip()
|
| 54 |
+
output = re.sub(r"\s+", " ", output)
|
| 55 |
|
| 56 |
return output
|
| 57 |
|
requirements.txt
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
fastapi
|
| 2 |
uvicorn
|
| 3 |
huggingface_hub
|
| 4 |
-
pydantic
|
| 5 |
-
transformers
|
|
|
|
| 1 |
fastapi
|
| 2 |
uvicorn
|
| 3 |
huggingface_hub
|
| 4 |
+
pydantic
|
|
|