TheUntraceable committed
Commit 13c78d8 · 1 Parent(s): 013032b

Format with Ruff

Files changed (1)
  1. main.py +9 -5
main.py CHANGED
@@ -7,14 +7,15 @@ from huggingface_hub import InferenceClient
 from typing import List
 
 # Set the cache directory to a writable location
-os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface'
+os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
 
 app = FastAPI()
 client = InferenceClient("facebook/opt-1.3b")
 
 SYSTEM_PROMPT = "You are a very powerful AI to generate interesting stories for short-form content consumption. Make sure to hook the readers attention in the first few seconds. Make sure to be engaging and creative in your responses."
 
-MAX_TOTAL_TOKENS = 2048
+MAX_TOTAL_TOKENS = 2048
+
 
 class Item(BaseModel):
     prompt: str
@@ -24,6 +25,7 @@ class Item(BaseModel):
     top_p: float = Field(default=0.9, ge=0.0, le=1.0)
     repetition_penalty: float = Field(default=1.1, ge=0.0)
 
+
 def format_prompt(message, history):
     prompt = "".join(
         f"Human: {user_prompt}\nAI: {bot_response}\n"
@@ -32,11 +34,12 @@ def format_prompt(message, history):
     prompt += f"Human: {message}\nAI:"
     return prompt
 
+
 def generate(item: Item):
     temperature = max(float(item.temperature), 1e-2)
 
     formatted_prompt = format_prompt(f"{SYSTEM_PROMPT}\n{item.prompt}", item.history)
-
+
     # A simple approximation for token count
     estimated_input_tokens = len(formatted_prompt.split())
     max_new_tokens = min(item.max_new_tokens, MAX_TOTAL_TOKENS - estimated_input_tokens)
@@ -50,12 +53,13 @@ def generate(item: Item):
         do_sample=True,
         seed=42,
     )
-
+
     output = response.strip()
     output = re.sub(r"\s+", " ", output)
 
     return output
 
+
 @app.get("/generate/")
 async def generate_text(
     prompt: str,
@@ -76,4 +80,4 @@ async def generate_text(
 
     response = await asyncio.to_thread(generate, item)
 
-    return {"response": response}
+    return {"response": response}
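A note on the token budget visible in this diff: len(formatted_prompt.split()) counts whitespace-separated words, which generally undercounts the subword tokens the model actually sees, so max_new_tokens can still overshoot MAX_TOTAL_TOKENS. A minimal sketch of a more precise count, assuming the transformers library is installed and the facebook/opt-1.3b tokenizer can be downloaded; none of this is part of the commit:

# Sketch only: count tokens with the model's own tokenizer instead of
# the whitespace approximation used in main.py.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")

def count_tokens(text: str) -> int:
    # encode() returns the token ids the model would actually consume
    return len(tokenizer.encode(text))

sample = "Human: Tell me a story\nAI:"
print(len(sample.split()), count_tokens(sample))  # word count vs. token count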
 
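For reference, a minimal sketch of exercising the endpoint after this change, assuming the app is served locally (for example with uvicorn main:app) on port 8000. Only the prompt query parameter is shown, since the endpoint's remaining parameters are truncated in the diff above:

# Sketch only: call the GET /generate/ endpoint defined in main.py.
import requests

resp = requests.get(
    "http://127.0.0.1:8000/generate/",
    params={"prompt": "Tell me a story about a lighthouse keeper."},
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["response"])  # the handler returns {"response": ...}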