raksama19 committed on
Commit
6ed2a3f
·
verified ·
1 Parent(s): 2160785

Update gemma_inference.py

Browse files
Files changed (1) hide show
  1. gemma_inference.py +12 -2
gemma_inference.py CHANGED
@@ -9,6 +9,7 @@ from typing import Generator, Optional
9
  import numpy as np
10
  from utils.snac_utils import generate_audio_data, get_snac
11
  from utils.vad import get_speech_timestamps, collect_chunks
 
12
 
13
  class GemmaOmniInference:
14
  """
@@ -20,15 +21,24 @@ class GemmaOmniInference:
20
  self.device = device
21
  self.model_id = model_id
22
 
 
 
 
 
 
 
 
 
23
  # Initialize models
24
  print("Loading Gemma 3n model...")
25
  self.model = Gemma3nForConditionalGeneration.from_pretrained(
26
  model_id,
27
  device_map="auto",
28
- torch_dtype=torch.bfloat16
 
29
  ).eval()
30
 
31
- self.processor = AutoProcessor.from_pretrained(model_id)
32
 
33
  # Keep the audio processing models
34
  print("Loading audio processing models...")
 
9
  import numpy as np
10
  from utils.snac_utils import generate_audio_data, get_snac
11
  from utils.vad import get_speech_timestamps, collect_chunks
12
+ from huggingface_hub import login
13
 
14
  class GemmaOmniInference:
15
  """
 
21
  self.device = device
22
  self.model_id = model_id
23
 
24
+ # Authenticate with Hugging Face
25
+ hf_token = os.getenv("HF_TOKEN")
26
+ if hf_token:
27
+ print("Authenticating with Hugging Face...")
28
+ login(token=hf_token)
29
+ else:
30
+ print("Warning: HF_TOKEN not found. Make sure to set it in Space secrets.")
31
+
32
  # Initialize models
33
  print("Loading Gemma 3n model...")
34
  self.model = Gemma3nForConditionalGeneration.from_pretrained(
35
  model_id,
36
  device_map="auto",
37
+ torch_dtype=torch.bfloat16,
38
+ token=hf_token
39
  ).eval()
40
 
41
+ self.processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
42
 
43
  # Keep the audio processing models
44
  print("Loading audio processing models...")