saiteki-kai committed on
Commit
dd0f1e1
·
verified ·
1 Parent(s): 723e9ef

fix: correct dtype argument in model loading and enhance demo launch options

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -46,7 +46,7 @@ chat_tokenizer = AutoTokenizer.from_pretrained(CHAT_MODEL_NAME, padding_side="le
46
  if chat_tokenizer.pad_token is None:
47
  chat_tokenizer.pad_token = chat_tokenizer.eos_token
48
 
49
- chat_model = AutoModelForCausalLM.from_pretrained(CHAT_MODEL_NAME, torch_dtype=torch.bfloat16)
50
 
51
  chat_model.to(device) # type: ignore
52
  chat_model.eval()
@@ -57,7 +57,7 @@ print("✓ Chat model loaded")
57
  print(f"Loading classifier: {CLASSIFIER_MODEL_NAME}")
58
 
59
  cls_tokenizer = AutoTokenizer.from_pretrained(CLASSIFIER_MODEL_NAME)
60
- cls_model = AutoModelForSequenceClassification.from_pretrained(CLASSIFIER_MODEL_NAME, torch_dtype=torch.bfloat16)
61
 
62
  cls_model.to(device)
63
  cls_model.eval()
@@ -117,6 +117,7 @@ def generate(submission: list[dict[str, str]], team_id: str) -> list[dict[str, s
117
  **inputs,
118
  max_new_tokens=MAX_NEW_TOKENS,
119
  do_sample=False,
 
120
  repetition_penalty=REPETITION_PENALTY,
121
  pad_token_id=chat_tokenizer.pad_token_id,
122
  eos_token_id=chat_tokenizer.eos_token_id,
@@ -193,5 +194,5 @@ with gr.Blocks() as demo:
193
  if __name__ == "__main__":
194
  print("\n=== Launching Application ===")
195
  demo.queue(default_concurrency_limit=None, api_open=True)
196
- demo.launch()
197
  print("✓ Application running")
 
46
  if chat_tokenizer.pad_token is None:
47
  chat_tokenizer.pad_token = chat_tokenizer.eos_token
48
 
49
+ chat_model = AutoModelForCausalLM.from_pretrained(CHAT_MODEL_NAME, dtype=torch.bfloat16)
50
 
51
  chat_model.to(device) # type: ignore
52
  chat_model.eval()
 
57
  print(f"Loading classifier: {CLASSIFIER_MODEL_NAME}")
58
 
59
  cls_tokenizer = AutoTokenizer.from_pretrained(CLASSIFIER_MODEL_NAME)
60
+ cls_model = AutoModelForSequenceClassification.from_pretrained(CLASSIFIER_MODEL_NAME, dtype=torch.bfloat16)
61
 
62
  cls_model.to(device)
63
  cls_model.eval()
 
117
  **inputs,
118
  max_new_tokens=MAX_NEW_TOKENS,
119
  do_sample=False,
120
+ temperature=None,
121
  repetition_penalty=REPETITION_PENALTY,
122
  pad_token_id=chat_tokenizer.pad_token_id,
123
  eos_token_id=chat_tokenizer.eos_token_id,
 
194
  if __name__ == "__main__":
195
  print("\n=== Launching Application ===")
196
  demo.queue(default_concurrency_limit=None, api_open=True)
197
+ demo.launch(show_error=True)
198
  print("✓ Application running")