cwpkd committed
Commit cc84412 · verified
Parent: c672125

Update utils/llm_analyzer.py

Files changed (1)
  1. utils/llm_analyzer.py +24 -5
utils/llm_analyzer.py CHANGED
@@ -14,19 +14,28 @@ class LLMAnalyzer:
 
     def __init__(self):
         """Initialize Gemma model"""
+        import os
         print("Loading Gemma model...")
-        self.tokenizer = AutoTokenizer.from_pretrained(GEMMA_MODEL)
+
+        # Get token from environment
+        hf_token = os.environ.get("HF_TOKEN", None)
+
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            GEMMA_MODEL,
+            token=hf_token
+        )
         self.model = AutoModelForCausalLM.from_pretrained(
             GEMMA_MODEL,
             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-            device_map="auto"
+            device_map="auto",
+            token=hf_token
         )
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         print(f"Gemma loaded on {self.device}!")
 
     def generate_response(self, prompt: str, max_length: int = LLM_MAX_LENGTH) -> str:
         """
-        Generate response from Gemma
+        Generate response from LLM
 
         Args:
             prompt: Input prompt
@@ -35,8 +44,12 @@ class LLMAnalyzer:
         Returns:
             Generated text
         """
-        # Format prompt for Gemma
-        formatted_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
+        # Format prompt (works for both Gemma and Zephyr)
+        if "gemma" in GEMMA_MODEL.lower():
+            formatted_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
+        else:
+            # Zephyr format
+            formatted_prompt = f"<|user|>\n{prompt}</s>\n<|assistant|>\n"
 
         inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.device)
 
@@ -55,6 +68,12 @@ class LLMAnalyzer:
         # Extract only the model's response
         if "<start_of_turn>model" in response:
            response = response.split("<start_of_turn>model")[-1].strip()
+        elif "<|assistant|>" in response:
+            response = response.split("<|assistant|>")[-1].strip()
+
+        # Remove the original prompt if still present
+        if prompt in response:
+            response = response.replace(prompt, "").strip()
 
         return response
 
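
For context, here is a minimal usage sketch of the updated class. It assumes the layout implied by this commit (utils/llm_analyzer.py, with GEMMA_MODEL and LLM_MAX_LENGTH coming from the project's config) and that HF_TOKEN is set as a Space secret or shell variable; the fail-fast check is an illustrative addition, not part of the commit.

import os

from utils.llm_analyzer import LLMAnalyzer  # import path assumed from this commit

# The gated Gemma checkpoint is only downloadable with a valid access token,
# so bail out early with a clear message instead of a 401 from the Hub.
if not os.environ.get("HF_TOKEN"):
    raise SystemExit("Set HF_TOKEN (e.g. as a Space secret) before loading the model.")

analyzer = LLMAnalyzer()  # from_pretrained() now receives token=hf_token
print(analyzer.generate_response("In one sentence, what can this analyzer do?"))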
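
As a design note, the hand-rolled Gemma/Zephyr format strings (and the matching split markers in the extraction step) could also be derived from the tokenizer's built-in chat template, which transformers ships for both model families. A sketch of that variant, not what this commit does:

def format_prompt(tokenizer, prompt: str) -> str:
    """Build the model-specific chat prompt from the tokenizer's template.

    apply_chat_template() emits Gemma's <start_of_turn> markup or Zephyr's
    <|user|>/<|assistant|> markup on its own, so the if/else on the model
    name becomes unnecessary.
    """
    messages = [{"role": "user", "content": prompt}]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,              # return a string rather than token ids
        add_generation_prompt=True,  # append the model/assistant turn header
    )

The trade-off is that the template travels with the checkpoint, so swapping GEMMA_MODEL for another chat model would not require touching generate_response() at all.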