pr0methium committed
Commit 3d4ff6b · verified · 1 parent: 72c4724

Update main.py

Files changed (1)
main.py  +39 -28
main.py CHANGED
@@ -1,72 +1,83 @@
-import os
+import uvicorn
 import re
 import asyncio
 from fastapi import FastAPI
-from pydantic import BaseModel, Field
+from pydantic import BaseModel
 from huggingface_hub import InferenceClient
 from typing import List
 
-# Set the cache directory to a writable location
-os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface'
 
 app = FastAPI()
-client = InferenceClient("EleutherAI/gpt-neo-125M")
+client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
 
 SYSTEM_PROMPT = "You are a very powerful AI to generate interesting stories for short-form content consumption. Make sure to hook the readers attention in the first few seconds. Make sure to be engaging and creative in your responses."
 
-MAX_TOTAL_TOKENS = 2048
 
 class Item(BaseModel):
     prompt: str
     history: List[str] = []
-    temperature: float = Field(default=0.8, ge=0.0, le=1.0)
-    max_new_tokens: int = Field(default=1024, ge=1, le=MAX_TOTAL_TOKENS)
-    top_p: float = Field(default=0.9, ge=0.0, le=1.0)
-    repetition_penalty: float = Field(default=1.1, ge=0.0)
+    # system_prompt: str = "You are a very powerful AI assistant."
+    temperature: float = 0.0
+    max_new_tokens: int = 1048
+    top_p: float = 0.15
+    repetition_penalty: float = 1.0
+
 
 def format_prompt(message, history):
-    prompt = ""
+    prompt = "<s>"
     for user_prompt, bot_response in history:
-        prompt += f"Human: {user_prompt}\nAI: {bot_response}\n"
-    prompt += f"Human: {message}\nAI:"
+        prompt += f"[INST] {user_prompt} [/INST]"
+        prompt += f" {bot_response}</s> "
+    prompt += f"[INST] {message} [/INST]"
     return prompt
 
+
 def generate(item: Item):
     temperature = max(float(item.temperature), 1e-2)
+    # generate_kwargs = dict(
+    #     temperature=temperature,
+    #     max_new_tokens=item.max_new_tokens,
+    #     top_p=float(item.top_p),
+    #     repetition_penalty=item.repetition_penalty,
+    #     do_sample=True,
+    #     seed=42,
+    # )
 
-    formatted_prompt = format_prompt(f"{SYSTEM_PROMPT}\n{item.prompt}", item.history)
-
-    # A simple approximation for token count
-    estimated_input_tokens = len(formatted_prompt.split())
-    max_new_tokens = min(item.max_new_tokens, MAX_TOTAL_TOKENS - estimated_input_tokens)
-
-    response = client.text_generation(
+    formatted_prompt = format_prompt(f"{SYSTEM_PROMPT}, {item.prompt}", item.history)
+    stream = client.text_generation(
         formatted_prompt,
-        max_new_tokens=max_new_tokens,
         temperature=temperature,
+        max_new_tokens=item.max_new_tokens,
         top_p=float(item.top_p),
         repetition_penalty=item.repetition_penalty,
         do_sample=True,
         seed=42,
+        stream=True,
+        details=True,
+        return_full_text=False,
     )
-
-    output = response.strip()
-    output = re.sub(r"\s+", " ", output)
+    output = "".join(response.token.text for response in stream)
+    # Remove unwanted sequences or patterns (e.g., <s>, [/INST], etc.)
+    output = re.sub(r"<[^>]+>", "", output)  # Remove any HTML-like tags
+    output = re.sub(r"\s+", " ", output).strip()  # Clean up extra whitespace
 
     return output
 
+
 @app.get("/generate/")
 async def generate_text(
     prompt: str,
     history: List[str] = [],
-    temperature: float = 0.8,
-    max_new_tokens: int = 1024,
-    top_p: float = 0.9,
-    repetition_penalty: float = 1.1,
+    # system_prompt: str = "You are a very powerful AI assistant.",
+    temperature: float = 0.0,
+    max_new_tokens: int = 1048,
+    top_p: float = 0.15,
+    repetition_penalty: float = 1.0,
 ):
     item = Item(
         prompt=prompt,
         history=history,
+        # system_prompt=system_prompt,
         temperature=temperature,
         max_new_tokens=max_new_tokens,
         top_p=top_p,
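
In substance, this commit points the InferenceClient at mistralai/Mixtral-8x7B-Instruct-v0.1 instead of EleutherAI/gpt-neo-125M, switches format_prompt from a "Human:/AI:" transcript to Mixtral's instruct layout (<s>[INST] ... [/INST] answer</s> ...), and moves generation to streaming, joining the token stream and cleaning it before returning.

For reference, a minimal sketch of calling the updated endpoint once the app is served (for example with uvicorn main:app). The host, port, and parameter values are illustrative assumptions, not part of the commit; history is omitted because it is declared List[str] but format_prompt unpacks (user_prompt, bot_response) pairs, so callers would need to pass pairs for it to work:

import requests

# Hypothetical client call against the /generate/ route defined above.
# Assumes the FastAPI app is running locally on port 8000 (an assumption).
resp = requests.get(
    "http://localhost:8000/generate/",  # assumed local dev address
    params={
        "prompt": "Tell me a short story about a lighthouse keeper.",
        "temperature": 0.7,  # overrides the new 0.0 default (clamped to 1e-2 in generate)
        "max_new_tokens": 256,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
    },
)
print(resp.status_code, resp.text)  # exact response shape depends on code beyond this hunk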