StefanG2002 committed on
Commit
d6c82f4
·
verified ·
1 Parent(s): 68de808

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +1 -10
main.py CHANGED
@@ -6,7 +6,7 @@ import uvicorn
6
 
7
  app = FastAPI()
8
 
9
- client = InferenceClient("google/gemma-7b")
10
 
11
  class Item(BaseModel):
12
  prompt: str
@@ -17,14 +17,6 @@ class Item(BaseModel):
17
  top_p: float = 0.15
18
  repetition_penalty: float = 1.0
19
 
20
- # def format_prompt(message, history):
21
- # prompt = "<s>"
22
- # for user_prompt, bot_response in history:
23
- # prompt += f"[INST] {user_prompt} [/INST]"
24
- # prompt += f" {bot_response}</s> "
25
- # prompt += f"[INST] {message} [/INST]"
26
- # return prompt
27
-
28
  def generate(item: Item):
29
  temperature = float(item.temperature)
30
  if temperature < 1e-2:
@@ -40,7 +32,6 @@ def generate(item: Item):
40
  seed=42,
41
  )
42
 
43
- # formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
44
  stream = client.text_generation(item.prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
45
  output = ""
46
 
 
6
 
7
  app = FastAPI()
8
 
9
+ client = InferenceClient("google/codegemma-7b-it")
10
 
11
  class Item(BaseModel):
12
  prompt: str
 
17
  top_p: float = 0.15
18
  repetition_penalty: float = 1.0
19
 
 
 
 
 
 
 
 
 
20
  def generate(item: Item):
21
  temperature = float(item.temperature)
22
  if temperature < 1e-2:
 
32
  seed=42,
33
  )
34
 
 
35
  stream = client.text_generation(item.prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
36
  output = ""
37