Luigi committed on
Commit
c16840d
·
1 Parent(s): 01dc9b6

add -c option to force cpu only

Browse files
Files changed (1) hide show
  1. summarize_transcript.py +6 -5
summarize_transcript.py CHANGED
@@ -8,14 +8,14 @@ import argparse
8
  from llama_cpp import Llama
9
  from huggingface_hub import hf_hub_download
10
 
11
- def load_model(repo_id, filename):
12
  """Load the model from Hugging Face Hub."""
13
 
14
- # Initialize the model with SYCL support
15
  llm = Llama.from_pretrained(
16
  repo_id=repo_id,
17
  filename=filename,
18
- n_gpu_layers=-1, # Use all layers on GPU
19
  seed=1337,
20
  n_ctx=32768, # Context size
21
  verbose=True, # Reduced verbosity for cleaner output
@@ -88,6 +88,7 @@ def main():
88
  parser.add_argument("-m", "--model", type=str,
89
  default="bartowski/baidu_ERNIE-4.5-0.3B-PT-GGUF:Q6_K",
90
  help="HuggingFace model in format repo_id:quant (e.g., Luigi/Falcon-H1-Tiny-Multilingual-100M-Instruct-GGUF:IQ4_NL)")
 
91
  args = parser.parse_args()
92
 
93
  # Parse model argument if provided
@@ -98,10 +99,10 @@ def main():
98
  print(f"Error: Invalid model format '{args.model}'. Expected format: repo_id:quant")
99
  return
100
 
101
- print(f"Loading model: {repo_id} ({filename}) with SYCL acceleration...")
102
 
103
  # Load the model
104
- llm = load_model(repo_id, filename)
105
 
106
  # Read the transcript
107
  transcript_path = args.input
 
8
  from llama_cpp import Llama
9
  from huggingface_hub import hf_hub_download
10
 
11
+ def load_model(repo_id, filename, cpu_only=False):
12
  """Load the model from Hugging Face Hub."""
13
 
14
+ # Initialize the model with SYCL support (or CPU only if requested)
15
  llm = Llama.from_pretrained(
16
  repo_id=repo_id,
17
  filename=filename,
18
+ n_gpu_layers=0 if cpu_only else -1, # 0 for CPU, -1 for all layers on GPU
19
  seed=1337,
20
  n_ctx=32768, # Context size
21
  verbose=True, # Reduced verbosity for cleaner output
 
88
  parser.add_argument("-m", "--model", type=str,
89
  default="bartowski/baidu_ERNIE-4.5-0.3B-PT-GGUF:Q6_K",
90
  help="HuggingFace model in format repo_id:quant (e.g., Luigi/Falcon-H1-Tiny-Multilingual-100M-Instruct-GGUF:IQ4_NL)")
91
+ parser.add_argument("-c", "--cpu", action="store_true", help="Force CPU only inference")
92
  args = parser.parse_args()
93
 
94
  # Parse model argument if provided
 
99
  print(f"Error: Invalid model format '{args.model}'. Expected format: repo_id:quant")
100
  return
101
 
102
+ print(f"Loading model: {repo_id} ({filename}) with {'CPU only' if args.cpu else 'SYCL acceleration'}...")
103
 
104
  # Load the model
105
+ llm = load_model(repo_id, filename, cpu_only=args.cpu)
106
 
107
  # Read the transcript
108
  transcript_path = args.input