"""Interchangeable generator backends for the RAG pipeline.

Backends, in the order ``make_llm`` tries them:

- ClaudeHaikuBackend: Claude Haiku through the Anthropic Messages HTTP API.
  Opt-in via ``make_llm(use_claude=True)``; needs ANTHROPIC_API_KEY.
- LiteRTLMBackend: gemma4-E2B-it run through mediapipe's LiteRT-LM runtime
  (on-device, CPU-class hardware). Used when the .task file exists.
- TransformersBackend: fallback that loads a causal LM (default
  Qwen/Qwen3-0.6B) through Hugging Face transformers. Slower, but works in
  any Python env including Colab.

All backends expose the same ``generate(prompt: str, max_new_tokens=512) -> str``
API.
"""
import json
import os
import urllib.request
from pathlib import Path


# ----------------------------------------------------------------------------
# Backend interface
# ----------------------------------------------------------------------------

class LLMBackend:
    """Common interface: turn a prompt string into generated text."""

    # Human-readable label; make_llm prints it once a backend is selected.
    name: str = "abstract"

    def generate(self, prompt: str, max_new_tokens: int = 512) -> str:
        """Return the model's completion for *prompt*. Subclasses override this."""
        raise NotImplementedError


# ----------------------------------------------------------------------------
# Preferred local backend: LiteRT-LM via mediapipe
# ----------------------------------------------------------------------------

class LiteRTLMBackend(LLMBackend):
    """Run gemma4-E2B-it (or any LiteRT-LM .task file) on CPU.

    Requires:
        pip install mediapipe            # most platforms
        pip install mediapipe-silicon    # Apple Silicon
    """

    name = "liteRT-LM (gemma4-E2B-it)"

    def __init__(self, model_path: str, max_tokens: int = 2048,
                 temperature: float = 0.0, top_k: int = 40):
        """Load the LiteRT-LM model at *model_path*.

        Args:
            model_path: path to the ``.task`` model file.
            max_tokens: token limit handed to mediapipe at construction time;
                generation is capped by this rather than by ``max_new_tokens``.
            temperature: sampling temperature.
            top_k: top-k sampling cutoff.

        Raises:
            ImportError: mediapipe is not installed (make_llm handles this).
            FileNotFoundError: *model_path* does not exist.
        """
        # Lazy import so this module loads even without mediapipe installed.
        from mediapipe.tasks.python.genai import llm_inference

        p = Path(model_path)
        if not p.exists():
            raise FileNotFoundError(
                f"LiteRT-LM model not found: {model_path}\n"
                "Download gemma4-e2b-it.task from kaggle.com/models/google/gemma-4 "
                "and place it in data/."
            )
        options = llm_inference.LlmInferenceOptions(
            model_path=str(p),
            max_tokens=max_tokens,
            top_k=top_k,
            temperature=temperature,
            random_seed=42,  # fixed seed so repeated runs are comparable
        )
        self.llm = llm_inference.LlmInference.create_from_options(options)

    def generate(self, prompt: str, max_new_tokens: int = 512) -> str:
        """Return the model's response to *prompt*.

        The mediapipe API caps output via the construction-time ``max_tokens``;
        ``max_new_tokens`` is accepted for interface compatibility only and is
        informational here (used to guide prompt budgeting upstream).
        """
        return self.llm.generate_response(prompt)


# ----------------------------------------------------------------------------
# Fallback: Hugging Face transformers
# ----------------------------------------------------------------------------

class TransformersBackend(LLMBackend):
    """Load a causal LM via transformers when LiteRT-LM is not available.

    Default model: 'Qwen/Qwen3-0.6B'. Swap as needed.
    """

    # Class-level default label; instances override it with the loaded model.
    name = "transformers (Qwen3-0.6B)"

    def __init__(self, model_name: str = "Qwen/Qwen3-0.6B"):
        """Load tokenizer and model *model_name* onto CUDA if available, else CPU."""
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        # Report the model actually loaded (default still yields "Qwen3-0.6B").
        self.name = f"transformers ({model_name.rsplit('/', 1)[-1]})"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tok = AutoTokenizer.from_pretrained(model_name)
        # bfloat16 on GPU to save memory; float32 on CPU.
        dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=dtype, device_map=self.device,
        )
        self.model.eval()

    def generate(self, prompt: str, max_new_tokens: int = 512) -> str:
        """Greedy-decode up to *max_new_tokens* tokens and return only the new text.

        The prompt is tokenized as-is (no chat template is applied).
        """
        import torch

        enc = self.tok(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            out = self.model.generate(
                **enc,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                # Discourage the looping output small models tend to produce.
                repetition_penalty=1.3,
                no_repeat_ngram_size=4,
            )
        # Drop the echoed prompt tokens; decode only the continuation.
        prompt_len = enc["input_ids"].shape[-1]
        return self.tok.decode(out[0][prompt_len:], skip_special_tokens=True)


# ----------------------------------------------------------------------------
# Claude Haiku Backend (Anthropic API)
# ----------------------------------------------------------------------------

class ClaudeHaikuBackend(LLMBackend):
    """Fast, high-quality answers via Claude Haiku (Anthropic Messages API).

    Requires ANTHROPIC_API_KEY in the environment or in a .env file.
    """

    name = "claude-haiku (API)"

    def __init__(self, model: str = "claude-haiku-4-5-20251001",
                 env_file: str = ".env", max_tokens_cap: int = 1024,
                 timeout: float = 60.0):
        """Resolve the API key and store request settings.

        Args:
            model: Anthropic model id.
            env_file: optional dotenv-style file read before the key lookup;
                it never overrides variables already set in the environment.
            max_tokens_cap: upper bound applied to every request's max_tokens.
            timeout: HTTP timeout in seconds for each request.

        Raises:
            ValueError: ANTHROPIC_API_KEY is not available.
        """
        self._load_dotenv(Path(env_file))
        self.api_key = os.environ.get("ANTHROPIC_API_KEY", "")
        if not self.api_key:
            raise ValueError(
                "ANTHROPIC_API_KEY non trouvée. "
                "Crée un fichier .env avec: ANTHROPIC_API_KEY=sk-ant-..."
            )
        self.model = model
        self.max_tokens_cap = max_tokens_cap
        self.timeout = timeout
        print(f"✅ Claude Haiku backend prêt ({model})")

    @staticmethod
    def _load_dotenv(env_path: Path) -> None:
        """Copy KEY=VALUE lines from *env_path* into os.environ without overriding.

        Skips blank lines and ``#`` comments (even when indented), accepts an
        optional ``export `` prefix, and strips one pair of matching quotes
        around the value. Missing files and non-files are ignored.
        """
        if not env_path.is_file():
            return
        for raw in env_path.read_text(encoding="utf-8").splitlines():
            line = raw.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue
            if line.startswith("export "):
                line = line[len("export "):]
            key, _, value = line.partition("=")
            key, value = key.strip(), value.strip()
            # KEY="sk-ant-..." must yield the bare key, not one wrapped in quotes.
            if len(value) >= 2 and value[0] == value[-1] and value[0] in ("'", '"'):
                value = value[1:-1]
            if key:
                os.environ.setdefault(key, value)

    @staticmethod
    def _extract_text(data: dict) -> str:
        """Concatenate every ``type == "text"`` block of a Messages API response.

        Returns "" when the response carries no text blocks.
        """
        return "".join(
            block.get("text", "")
            for block in data.get("content", [])
            if block.get("type") == "text"
        )

    def generate(self, prompt: str, max_new_tokens: int = 512) -> str:
        """Send *prompt* as a single user message and return the reply text.

        Raises whatever ``urllib.request.urlopen`` raises on network/HTTP errors.
        """
        payload = json.dumps({
            "model": self.model,
            # Caller budget is clamped to max_tokens_cap (1024 by default).
            "max_tokens": min(max_new_tokens, self.max_tokens_cap),
            "messages": [{"role": "user", "content": prompt}],
        }).encode()
        req = urllib.request.Request(
            "https://api.anthropic.com/v1/messages",
            data=payload,
            headers={
                "Content-Type": "application/json",
                "x-api-key": self.api_key,
                "anthropic-version": "2023-06-01",
            },
        )
        with urllib.request.urlopen(req, timeout=self.timeout) as resp:
            data = json.loads(resp.read())
        return self._extract_text(data)


# ----------------------------------------------------------------------------
# Factory
# ----------------------------------------------------------------------------

def make_llm(
    model_path: str = "data/gemma4-e2b-it.task",
    fallback_hf: str = "Qwen/Qwen3-0.6B",
    use_claude: bool = False,
) -> LLMBackend:
    """Build the best available generator backend.

    Backend priority:
        1. Claude Haiku (only if ``use_claude=True``)
        2. LiteRT-LM (if the .task file at *model_path* exists)
        3. Transformers (fallback, loads *fallback_hf*)

    Init failures of the first two are printed and the next option is tried;
    a failure while loading the transformers fallback propagates to the caller.
    """
    if use_claude:
        try:
            llm = ClaudeHaikuBackend()
        except Exception as e:  # factory boundary: any init failure -> next backend
            print(f"Claude Haiku init failed ({e}); falling back to local backends.")
        else:
            print(f"using backend: {llm.name}")
            return llm

    if Path(model_path).exists():
        try:
            llm = LiteRTLMBackend(model_path)
        except ImportError:
            print("mediapipe not installed; falling back to transformers.")
        except Exception as e:  # factory boundary: any init failure -> fallback
            print(f"LiteRT-LM init failed ({e}); falling back to transformers.")
        else:
            print(f"using backend: {llm.name}")
            return llm
    else:
        print(f"no .task file at {model_path}; falling back to transformers.")

    llm = TransformersBackend(fallback_hf)
    print(f"using backend: {llm.name}")
    return llm