curiouscurrent commited on
Commit
e672e6e
·
verified ·
1 Parent(s): 9f08e65

Update AI_Agent/llm_adapters/hf_adapter.py

Browse files
AI_Agent/llm_adapters/hf_adapter.py CHANGED
@@ -1,20 +1,21 @@
 
1
  from transformers import AutoModelForCausalLM, AutoTokenizer
2
  import torch
3
  import asyncio
4
 
5
  class HuggingFaceAdapter:
6
- def __init__(self, model_name="tiiuae/falcon-7b-instruct"):
7
  self.model_name = model_name
8
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
9
  self.model = AutoModelForCausalLM.from_pretrained(
10
  model_name,
11
- dtype=torch.float16, # updated from torch_dtype to dtype
12
- device_map="auto" # requires accelerate
13
  )
14
 
15
- async def generate(self, prompt: str, max_tokens=500):
16
  def _sync_generate():
17
- inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
18
  outputs = self.model.generate(**inputs, max_new_tokens=max_tokens)
19
  text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
20
  return text
 
1
+ # AI_Agent/llm_adapters/hf_adapter.py
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
  import asyncio
5
 
6
  class HuggingFaceAdapter:
7
+ def __init__(self, model_name="google/gemma-3n-E2B-it"):
8
  self.model_name = model_name
9
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
10
  self.model = AutoModelForCausalLM.from_pretrained(
11
  model_name,
12
+ dtype=torch.float32, # CPU-friendly
13
+ device_map=None # CPU only
14
  )
15
 
16
+ async def generate(self, prompt: str, max_tokens=300):
17
  def _sync_generate():
18
+ inputs = self.tokenizer(prompt, return_tensors="pt") # no .to(self.model.device) needed
19
  outputs = self.model.generate(**inputs, max_new_tokens=max_tokens)
20
  text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
21
  return text