Arif committed
Commit be05fd6 · 1 Parent(s): 8c389ce

Fix: Add conditional import for MLX with CPU fallback

Files changed (1)
  1. src/generation/mlx_wrapper.py +49 -19
src/generation/mlx_wrapper.py CHANGED
@@ -1,29 +1,52 @@
-# src/generation/mlx_wrapper.py
 import os
 from typing import Any, List, Optional
 from langchain_core.callbacks.manager import CallbackManagerForLLMRun
 from langchain_core.language_models.llms import LLM
-from mlx_lm import load, generate
 from dotenv import load_dotenv
 
 load_dotenv()
 
+# --- CRITICAL FIX: Handle Import Error ---
+try:
+    from mlx_lm import load, generate
+    HAS_MLX = True
+except ImportError:
+    HAS_MLX = False
+# ----------------------------------------
+
 class MLXLLM(LLM):
-    """Custom LangChain Wrapper for MLX Models"""
+    """Custom LangChain Wrapper for MLX Models (with Cloud Fallback)"""
 
     model_id: str = os.getenv("MODEL_ID", "mlx-community/Llama-3.2-3B-Instruct-4bit")
     model: Any = None
     tokenizer: Any = None
     max_tokens: int = int(os.getenv("MAX_TOKENS", 512))
+    pipeline: Any = None  # For Cloud Fallback
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
-        print(f"🚀 Loading MLX Model: {self.model_id}")
-        self.model, self.tokenizer = load(self.model_id)
+
+        if HAS_MLX:
+            print(f"🚀 Loading MLX Model: {self.model_id}")
+            self.model, self.tokenizer = load(self.model_id)
+        else:
+            print(f"⚠️ MLX not found. Falling back to HuggingFace Transformers (CPU/Cloud).")
+            # Fallback: Use standard Transformers
+            from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
+
+            # Use the MODEL_ID env var (set to 'gpt2' or 'facebook/opt-125m' in HF Secrets)
+            # Do NOT use the MLX model ID here, as it requires MLX format.
+            cloud_model_id = os.getenv("MODEL_ID", "gpt2")
+
+            self.pipeline = pipeline(
+                "text-generation",
+                model=cloud_model_id,
+                max_new_tokens=self.max_tokens
+            )
 
     @property
     def _llm_type(self) -> str:
-        return "mlx_llama"
+        return "mlx_llama" if HAS_MLX else "transformers_fallback"
 
     def _call(
         self,
@@ -35,16 +58,23 @@ class MLXLLM(LLM):
         if stop is not None:
             raise ValueError("stop kwargs are not permitted.")
 
-        messages = [{"role": "user", "content": prompt}]
-        formatted_prompt = self.tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
-        )
-
-        response = generate(
-            self.model,
-            self.tokenizer,
-            prompt=formatted_prompt,
-            verbose=False,
-            max_tokens=self.max_tokens
-        )
-        return response
+        if HAS_MLX:
+            # MLX Generation Logic
+            messages = [{"role": "user", "content": prompt}]
+            formatted_prompt = self.tokenizer.apply_chat_template(
+                messages, tokenize=False, add_generation_prompt=True
+            )
+            response = generate(
+                self.model,
+                self.tokenizer,
+                prompt=formatted_prompt,
+                verbose=False,
+                max_tokens=self.max_tokens
+            )
+            return response
+        else:
+            # Cloud/CPU Fallback Logic
+            # Simple text generation for MVP
+            response = self.pipeline(prompt)[0]['generated_text']
+            # Remove the prompt from the response if needed
+            return response[len(prompt):]
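
For context, a minimal usage sketch of the patched wrapper follows. It assumes the project root is on PYTHONPATH so the class is importable as src.generation.mlx_wrapper (the exact import path depends on the repository layout), and that either mlx_lm or transformers is installed; everything else uses the standard LangChain Runnable interface.

# Minimal usage sketch (hypothetical import path; adjust
# "src.generation.mlx_wrapper" to how the package is exposed in this repo).
from langchain_core.prompts import PromptTemplate

from src.generation.mlx_wrapper import MLXLLM

# Instantiation picks MLX when mlx_lm is importable, otherwise the
# transformers text-generation pipeline configured in __init__.
llm = MLXLLM()

# Direct call through the LangChain Runnable interface
print(llm.invoke("Explain retrieval-augmented generation in one sentence."))

# Composed into a simple prompt -> LLM chain
prompt = PromptTemplate.from_template("Summarize in one line: {text}")
chain = prompt | llm
print(chain.invoke({"text": "MLX runs quantized models natively on Apple Silicon."}))

Because the branch between backends is decided at import time via HAS_MLX, calling code like the above does not change between a local Apple Silicon run and a CPU/cloud deployment.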