pr0methium committed
Commit 72c4724 · verified · 1 Parent(s): 13c78d8

Update main.py

Files changed (1)
main.py +8 -13
main.py CHANGED
@@ -7,16 +7,15 @@ from huggingface_hub import InferenceClient
 from typing import List
 
 # Set the cache directory to a writable location
-os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
+os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface'
 
 app = FastAPI()
-client = InferenceClient("facebook/opt-1.3b")
+client = InferenceClient("EleutherAI/gpt-neo-125M")
 
 SYSTEM_PROMPT = "You are a very powerful AI to generate interesting stories for short-form content consumption. Make sure to hook the readers attention in the first few seconds. Make sure to be engaging and creative in your responses."
 
 MAX_TOTAL_TOKENS = 2048
 
-
 class Item(BaseModel):
     prompt: str
     history: List[str] = []
@@ -25,21 +24,18 @@ class Item(BaseModel):
     top_p: float = Field(default=0.9, ge=0.0, le=1.0)
     repetition_penalty: float = Field(default=1.1, ge=0.0)
 
-
 def format_prompt(message, history):
-    prompt = "".join(
-        f"Human: {user_prompt}\nAI: {bot_response}\n"
-        for user_prompt, bot_response in history
-    )
+    prompt = ""
+    for user_prompt, bot_response in history:
+        prompt += f"Human: {user_prompt}\nAI: {bot_response}\n"
     prompt += f"Human: {message}\nAI:"
     return prompt
 
-
 def generate(item: Item):
     temperature = max(float(item.temperature), 1e-2)
 
     formatted_prompt = format_prompt(f"{SYSTEM_PROMPT}\n{item.prompt}", item.history)
-
+
     # A simple approximation for token count
     estimated_input_tokens = len(formatted_prompt.split())
     max_new_tokens = min(item.max_new_tokens, MAX_TOTAL_TOKENS - estimated_input_tokens)
@@ -53,13 +49,12 @@ def generate(item: Item):
         do_sample=True,
         seed=42,
     )
-
+
     output = response.strip()
     output = re.sub(r"\s+", " ", output)
 
     return output
 
-
 @app.get("/generate/")
 async def generate_text(
     prompt: str,
@@ -80,4 +75,4 @@ async def generate_text(
 
     response = await asyncio.to_thread(generate, item)
 
-    return {"response": response}
+    return {"response": response}
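
For a quick sanity check, the rewritten format_prompt can be exercised on its own. A minimal standalone sketch of the new loop, with a made-up two-turn history (the pairs are illustrative, not from the repo):

```python
# Standalone copy of the new format_prompt from this commit.
def format_prompt(message, history):
    prompt = ""
    for user_prompt, bot_response in history:
        prompt += f"Human: {user_prompt}\nAI: {bot_response}\n"
    prompt += f"Human: {message}\nAI:"
    return prompt

# Illustrative history; the real app builds this from the request.
history = [
    ("Give me a story hook.", "A lighthouse keeper finds a door in the sea."),
    ("Make it darker.", "The door only opens for the already lost."),
]
print(format_prompt("Continue the story.", history))
# Human: Give me a story hook.
# AI: A lighthouse keeper finds a door in the sea.
# Human: Make it darker.
# AI: The door only opens for the already lost.
# Human: Continue the story.
# AI:
```

One wrinkle worth flagging: Item declares history: List[str], yet format_prompt unpacks each entry into a (user_prompt, bot_response) pair, so callers presumably send pair-like entries rather than flat strings.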
 
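The token-budget arithmetic is untouched by this diff: input size is approximated with a whitespace split, and the generation budget is clipped against MAX_TOTAL_TOKENS. A minimal sketch of the same arithmetic (the helper name and sample numbers are mine, not from the repo):

```python
MAX_TOTAL_TOKENS = 2048  # same cap as main.py

def budget_new_tokens(formatted_prompt: str, requested: int) -> int:
    # Whitespace word count as a cheap stand-in for a real token count;
    # tokenizers usually emit more tokens than words, so this tends to
    # overestimate the room left for generation.
    estimated_input_tokens = len(formatted_prompt.split())
    return min(requested, MAX_TOTAL_TOKENS - estimated_input_tokens)

print(budget_new_tokens("Human: tell me a story\nAI:", 512))  # -> 512
```

Note that the clip can go negative for a prompt longer than the cap, which the downstream text_generation call would likely reject; guarding with max(0, ...) would be a natural hardening.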
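
Since /generate/ is a GET route, the updated Space can be smoke-tested with a plain HTTP request. A hedged sketch: the base URL is a placeholder, and only the prompt query parameter is visible in this diff (the rest of the signature is truncated above):

```python
import requests

# Placeholder base URL; point it at wherever the FastAPI app is served.
BASE_URL = "http://localhost:8000"

resp = requests.get(
    f"{BASE_URL}/generate/",
    params={"prompt": "A clockmaker discovers she can repair memories."},
)
resp.raise_for_status()
print(resp.json()["response"])  # the route returns {"response": ...}
```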