ta4tsering commited on
Commit
aeca817
·
1 Parent(s): 15c943a

refactor: update model loading logic to improve compatibility and adjust output textbox settings

Browse files
Files changed (1) hide show
  1. app.py +3 -5
app.py CHANGED
@@ -59,24 +59,23 @@ def load_model():
59
  model_path = snapshot_download(
60
  repo_id=MODEL_ID,
61
  local_dir=MODEL_DIR,
62
- local_dir_use_symlinks=False,
63
  )
64
 
65
  patch_configuration_dots(model_path)
66
  sys.path.insert(0, model_path)
67
 
68
- # Try flash_attention_2 first, fall back to sdpa
69
  attn_impl = "flash_attention_2"
70
  try:
71
  import flash_attn # noqa: F401
72
  except ImportError:
73
- attn_impl = "sdpa"
74
 
75
  print(f"Loading model with attn_implementation={attn_impl} ...")
76
  model = AutoModelForCausalLM.from_pretrained(
77
  model_path,
78
  attn_implementation=attn_impl,
79
- torch_dtype=torch.bfloat16,
80
  device_map="auto",
81
  trust_remote_code=True,
82
  )
@@ -195,7 +194,6 @@ print(result)
195
  output_text = gr.Textbox(
196
  label="Model Output",
197
  lines=20,
198
- show_copy_button=True,
199
  )
200
 
201
  run_btn.click(
 
59
  model_path = snapshot_download(
60
  repo_id=MODEL_ID,
61
  local_dir=MODEL_DIR,
 
62
  )
63
 
64
  patch_configuration_dots(model_path)
65
  sys.path.insert(0, model_path)
66
 
67
+ # Try flash_attention_2 first, fall back to eager for compatibility.
68
  attn_impl = "flash_attention_2"
69
  try:
70
  import flash_attn # noqa: F401
71
  except ImportError:
72
+ attn_impl = "eager"
73
 
74
  print(f"Loading model with attn_implementation={attn_impl} ...")
75
  model = AutoModelForCausalLM.from_pretrained(
76
  model_path,
77
  attn_implementation=attn_impl,
78
+ dtype=torch.bfloat16,
79
  device_map="auto",
80
  trust_remote_code=True,
81
  )
 
194
  output_text = gr.Textbox(
195
  label="Model Output",
196
  lines=20,
 
197
  )
198
 
199
  run_btn.click(