premalt committed on
Commit
72d1491
·
1 Parent(s): 4bba0ce

fix cache directory

Browse files
Files changed (2) hide show
  1. main.py +11 -11
  2. requirements.txt +1 -2
main.py CHANGED
@@ -1,15 +1,16 @@
1
- import uvicorn
2
  import re
3
  import asyncio
4
  from fastapi import FastAPI
5
  from pydantic import BaseModel, Field
6
  from huggingface_hub import InferenceClient
7
  from typing import List
8
- from transformers import GPT2TokenizerFast
 
 
9
 
10
  app = FastAPI()
11
  client = InferenceClient("gpt2")
12
- tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
13
 
14
  SYSTEM_PROMPT = "You are a very powerful AI to generate interesting stories for short-form content consumption. Make sure to hook the readers attention in the first few seconds. Make sure to be engaging and creative in your responses."
15
 
@@ -35,10 +36,11 @@ def generate(item: Item):
35
 
36
  formatted_prompt = format_prompt(f"{SYSTEM_PROMPT}\n{item.prompt}", item.history)
37
 
38
- input_tokens = len(tokenizer.encode(formatted_prompt))
39
- max_new_tokens = min(item.max_new_tokens, MAX_TOTAL_TOKENS - input_tokens)
 
40
 
41
- stream = client.text_generation(
42
  formatted_prompt,
43
  max_new_tokens=max_new_tokens,
44
  temperature=temperature,
@@ -46,12 +48,10 @@ def generate(item: Item):
46
  repetition_penalty=item.repetition_penalty,
47
  do_sample=True,
48
  seed=42,
49
- stream=True,
50
- details=True,
51
- return_full_text=False,
52
  )
53
- output = "".join(chunk.token.text for chunk in stream)
54
- output = re.sub(r"\s+", " ", output).strip()
 
55
 
56
  return output
57
 
 
1
+ import os
2
  import re
3
  import asyncio
4
  from fastapi import FastAPI
5
  from pydantic import BaseModel, Field
6
  from huggingface_hub import InferenceClient
7
  from typing import List
8
+
9
+ # Set the cache directory to a writable location
10
+ os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface'
11
 
12
  app = FastAPI()
13
  client = InferenceClient("gpt2")
 
14
 
15
  SYSTEM_PROMPT = "You are a very powerful AI to generate interesting stories for short-form content consumption. Make sure to hook the readers attention in the first few seconds. Make sure to be engaging and creative in your responses."
16
 
 
36
 
37
  formatted_prompt = format_prompt(f"{SYSTEM_PROMPT}\n{item.prompt}", item.history)
38
 
39
+ # A simple approximation for token count
40
+ estimated_input_tokens = len(formatted_prompt.split())
41
+ max_new_tokens = min(item.max_new_tokens, MAX_TOTAL_TOKENS - estimated_input_tokens)
42
 
43
+ response = client.text_generation(
44
  formatted_prompt,
45
  max_new_tokens=max_new_tokens,
46
  temperature=temperature,
 
48
  repetition_penalty=item.repetition_penalty,
49
  do_sample=True,
50
  seed=42,
 
 
 
51
  )
52
+
53
+ output = response.strip()
54
+ output = re.sub(r"\s+", " ", output)
55
 
56
  return output
57
 
requirements.txt CHANGED
@@ -1,5 +1,4 @@
1
  fastapi
2
  uvicorn
3
  huggingface_hub
4
- pydantic
5
- transformers
 
1
  fastapi
2
  uvicorn
3
  huggingface_hub
4
+ pydantic