vedant01 committed on
Commit
de9f38e
·
verified ·
1 Parent(s): 473963a

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +24 -12
main.py CHANGED
@@ -2,7 +2,7 @@ from fastapi import FastAPI
2
  from pydantic import BaseModel
3
  from huggingface_hub import InferenceClient
4
  import uvicorn
5
-
6
 
7
  app = FastAPI()
8
 
@@ -10,8 +10,8 @@ client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
10
 
11
  class Item(BaseModel):
12
  prompt: str
13
- history: list
14
- system_prompt: str
15
  temperature: float = 0.0
16
  max_new_tokens: int = 1048
17
  top_p: float = 0.15
@@ -26,15 +26,11 @@ def format_prompt(message, history):
26
  return prompt
27
 
28
  def generate(item: Item):
29
- temperature = float(item.temperature)
30
- if temperature < 1e-2:
31
- temperature = 1e-2
32
- top_p = float(item.top_p)
33
-
34
  generate_kwargs = dict(
35
  temperature=temperature,
36
  max_new_tokens=item.max_new_tokens,
37
- top_p=top_p,
38
  repetition_penalty=item.repetition_penalty,
39
  do_sample=True,
40
  seed=42,
@@ -48,7 +44,23 @@ def generate(item: Item):
48
  output += response.token.text
49
  return output
50
 
51
- @app.post("/generate/")
52
- async def generate_text(item: Item):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  return {"response": generate(item)}
54
-
 
2
  from pydantic import BaseModel
3
  from huggingface_hub import InferenceClient
4
  import uvicorn
5
+ from typing import List
6
 
7
  app = FastAPI()
8
 
 
10
 
11
  class Item(BaseModel):
12
  prompt: str
13
+ history: List[str] = []
14
+ system_prompt: str = "You are a very powerful AI assistant."
15
  temperature: float = 0.0
16
  max_new_tokens: int = 1048
17
  top_p: float = 0.15
 
26
  return prompt
27
 
28
  def generate(item: Item):
29
+ temperature = max(float(item.temperature), 1e-2)
 
 
 
 
30
  generate_kwargs = dict(
31
  temperature=temperature,
32
  max_new_tokens=item.max_new_tokens,
33
+ top_p=float(item.top_p),
34
  repetition_penalty=item.repetition_penalty,
35
  do_sample=True,
36
  seed=42,
 
44
  output += response.token.text
45
  return output
46
 
47
@app.get("/generate/")
async def generate_text(
    prompt: str,
    history: List[str] = [],
    system_prompt: str = "You are a very powerful AI assistant.",
    temperature: float = 0.0,
    max_new_tokens: int = 1048,
    top_p: float = 0.15,
    repetition_penalty: float = 1.0
):
    """Handle GET /generate/.

    Bundles the validated query parameters into an ``Item`` request model,
    runs the blocking ``generate`` helper on it, and returns the generated
    text under the ``response`` key as JSON.
    """
    # Gather the query parameters into keyword arguments for the request
    # model; generate() consumes the Item, not the raw parameters.
    fields = dict(
        prompt=prompt,
        history=history,
        system_prompt=system_prompt,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )
    return {"response": generate(Item(**fields))}