StefanG2002 committed on
Commit
17257f3
·
verified ·
1 Parent(s): 80fc749

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +1 -24
main.py CHANGED
@@ -6,8 +6,6 @@ import uvicorn
6
 
7
  app = FastAPI()
8
 
9
- client = InferenceClient("google/gemma-1.1-7b-it")
10
-
11
  class Item(BaseModel):
12
  prompt: str
13
  system_prompt: str
@@ -16,28 +14,7 @@ class Item(BaseModel):
16
  top_p: float = 0.15
17
  repetition_penalty: float = 1.0
18
 
19
- def generate(item: Item):
20
- temperature = float(item.temperature)
21
- if temperature < 1e-2:
22
- temperature = 1e-2
23
- top_p = float(item.top_p)
24
-
25
- generate_kwargs = dict(
26
- temperature=temperature,
27
- max_new_tokens=item.max_new_tokens,
28
- top_p=top_p,
29
- repetition_penalty=item.repetition_penalty,
30
- do_sample=True,
31
- seed=42,
32
- )
33
-
34
- stream = client.text_generation(item.prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
35
- output = ""
36
-
37
- for response in stream:
38
- output += response.token.text
39
- return output
40
 
41
  @app.post("/generate/")
42
  async def generate_text(item: Item):
43
- return {"response": generate(item)}
 
6
 
7
  app = FastAPI()
8
 
 
 
9
  class Item(BaseModel):
10
  prompt: str
11
  system_prompt: str
 
14
  top_p: float = 0.15
15
  repetition_penalty: float = 1.0
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  @app.post("/generate/")
19
  async def generate_text(item: Item):
20
+ return {"response": item.prompt}