Xeltron-cloud commited on
Commit
169c067
·
verified ·
1 Parent(s): 748f7be

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -1
app.py CHANGED
@@ -2,10 +2,22 @@ from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  from huggingface_hub import login
 
5
  import os
6
  import torch
7
  import uvicorn
8
 
 
 
 
 
 
 
 
 
 
 
 
9
  login(os.getenv("HF_TOKEN"))
10
 
11
  app = FastAPI(
@@ -36,13 +48,18 @@ async def generate_text(request: GenerateRequest):
36
  inputs = tokenizer(request.prompt, return_tensors="pt").to(model.device)
37
 
38
  with torch.no_grad():
 
 
 
 
39
  outputs = model.generate(
40
  **inputs,
41
  max_new_tokens=request.max_new_tokens,
42
  temperature=request.temperature,
43
  do_sample=True,
44
  repetition_penalty=1.1,
45
- pad_token_id=tokenizer.eos_token_id
 
46
  )
47
 
48
  full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
2
  from pydantic import BaseModel
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  from huggingface_hub import login
5
+ from transformers import StoppingCriteria, StoppingCriteriaList
6
  import os
7
  import torch
8
  import uvicorn
9
 
10
+ class StopOnStrings(StoppingCriteria):
11
+ def __init__(self, tokenizer, stop_strings):
12
+ self.tokenizer = tokenizer
13
+ self.stop_ids = [tokenizer.encode(s, add_special_tokens=False) for s in stop_strings]
14
+
15
+ def __call__(self, input_ids, scores, **kwargs):
16
+ for stop_id in self.stop_ids:
17
+ if input_ids[0][-len(stop_id):].tolist() == stop_id:
18
+ return True
19
+ return False
20
+
21
  login(os.getenv("HF_TOKEN"))
22
 
23
  app = FastAPI(
 
48
  inputs = tokenizer(request.prompt, return_tensors="pt").to(model.device)
49
 
50
  with torch.no_grad():
51
+ stopping = StoppingCriteriaList([
52
+ StopOnStrings(tokenizer, ["\n\n", "###", "END"])
53
+ ])
54
+
55
  outputs = model.generate(
56
  **inputs,
57
  max_new_tokens=request.max_new_tokens,
58
  temperature=request.temperature,
59
  do_sample=True,
60
  repetition_penalty=1.1,
61
+ pad_token_id=tokenizer.eos_token_id,
62
+ stopping_criteria=stopping
63
  )
64
 
65
  full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)