Spaces:
Sleeping
Sleeping
"""Interchangeable backends for the generator side of the RAG pipeline.

Preferred (local): LiteRTLMBackend runs gemma4-E2B-it through MediaPipe's
LiteRT-LM runtime (on-device, runs on CPU-class hardware).
Fallback (local): TransformersBackend loads a causal LM through Hugging Face
transformers (Qwen/Qwen3-0.6B by default). Slower, but works in any Python
environment, including Colab.
Optional (remote): ClaudeHaikuBackend calls Claude Haiku through the Anthropic
Messages API when an ANTHROPIC_API_KEY is available.

All backends expose the same `generate(prompt: str, max_new_tokens=512) -> str` API.
"""
| from pathlib import Path | |
| # ---------------------------------------------------------------------------- | |
| # Backend interface | |
| # ---------------------------------------------------------------------------- | |
class LLMBackend:
    """Common interface implemented by every generator backend.

    Subclasses set ``name`` to a human-readable label and override
    ``generate`` to turn a prompt string into generated text.
    """

    # Human-readable label; concrete backends replace it.
    name: str = "abstract"

    def generate(self, prompt: str, max_new_tokens: int = 512) -> str:
        """Return the completion for *prompt*.

        The base class has no model behind it, so this always raises
        ``NotImplementedError``; concrete backends must override it.
        """
        raise NotImplementedError()
| # ---------------------------------------------------------------------------- | |
| # Preferred: LiteRT-LM via mediapipe | |
| # ---------------------------------------------------------------------------- | |
class LiteRTLMBackend(LLMBackend):
    """Run gemma4-E2B-it (or any LiteRT-LM .task file) on CPU.

    Requires:
        pip install mediapipe            # most platforms
        pip install mediapipe-silicon    # Apple Silicon
    """

    name = "liteRT-LM (gemma4-E2B-it)"

    def __init__(
        self,
        model_path: str,
        max_tokens: int = 2048,
        temperature: float = 0.0,
        top_k: int = 40,
        random_seed: int = 42,
    ):
        """Load a LiteRT-LM model through MediaPipe's LLM inference task.

        Args:
            model_path: Path to the ``.task`` model file.
            max_tokens: Construction-time token cap handed to MediaPipe; it is
                the only length limit this backend applies. NOTE(review):
                presumably it covers prompt + generated tokens — confirm
                against the MediaPipe docs.
            temperature: Sampling temperature forwarded to MediaPipe.
            top_k: Top-k value forwarded to MediaPipe.
            random_seed: Seed forwarded to MediaPipe (defaults to the value
                that used to be hard-coded, 42).

        Raises:
            ImportError: mediapipe is not installed (checked first, because
                the import runs before the file check).
            FileNotFoundError: ``model_path`` does not exist.
        """
        # Imported lazily so this module can be imported without mediapipe;
        # callers such as make_llm catch ImportError and fall back.
        from mediapipe.tasks.python.genai import llm_inference

        p = Path(model_path)
        if not p.exists():
            raise FileNotFoundError(
                f"LiteRT-LM model not found: {model_path}\n"
                "Download gemma4-e2b-it.task from kaggle.com/models/google/gemma-4 "
                "and place it in data/."
            )
        options = llm_inference.LlmInferenceOptions(
            model_path=str(p),
            max_tokens=max_tokens,
            top_k=top_k,
            temperature=temperature,
            random_seed=random_seed,
        )
        self.llm = llm_inference.LlmInference.create_from_options(options)

    def generate(self, prompt: str, max_new_tokens: int = 512) -> str:
        """Return MediaPipe's response for *prompt*.

        ``max_new_tokens`` is accepted only for interface compatibility and is
        not forwarded: the MediaPipe length cap is the construction-time
        ``max_tokens``. Callers should budget prompt length upstream.
        """
        return self.llm.generate_response(prompt)
| # ---------------------------------------------------------------------------- | |
| # Fallback: Hugging Face transformers | |
| # ---------------------------------------------------------------------------- | |
class TransformersBackend(LLMBackend):
    """Load a causal LM via Hugging Face transformers when LiteRT-LM is not available.

    Default model: 'Qwen/Qwen3-0.6B'. Swap as needed via ``model_name``.
    """

    # Class-level default label; __init__ overrides it per instance so the
    # label reflects whichever model was actually loaded.
    name = "transformers (Qwen3-0.6B)"

    def __init__(
        self,
        model_name: str = "Qwen/Qwen3-0.6B",
        repetition_penalty: float = 1.3,
        no_repeat_ngram_size: int = 4,
    ):
        """Load tokenizer and model for *model_name* on CUDA if available, else CPU.

        Args:
            model_name: Hugging Face hub id or local path of a causal LM.
            repetition_penalty: Forwarded to ``model.generate``; 1.0 disables it.
            no_repeat_ngram_size: Forwarded to ``model.generate``; 0 disables it.
                NOTE(review): transformers applies both penalties to the whole
                sequence, prompt included — verify for your version; if so
                they discourage quoting the retrieved RAG context verbatim.
        """
        import torch
        from transformers import AutoTokenizer, AutoModelForCausalLM

        # Label from the last path component, e.g. "Qwen/Qwen3-0.6B" -> "Qwen3-0.6B",
        # so make_llm reports the model that was really loaded.
        self.name = f"transformers ({model_name.rstrip('/').rsplit('/', 1)[-1]})"
        self.repetition_penalty = repetition_penalty
        self.no_repeat_ngram_size = no_repeat_ngram_size
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tok = AutoTokenizer.from_pretrained(model_name)
        # bfloat16 only on GPU; CPU stays in float32.
        dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
        # Move the model explicitly instead of passing device_map=..., which
        # makes from_pretrained require the optional `accelerate` package.
        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype)
        self.model.to(self.device)
        self.model.eval()

    def generate(self, prompt: str, max_new_tokens: int = 512) -> str:
        """Greedy-decode up to *max_new_tokens* and return only the new text.

        The prompt is fed as raw text (no chat template is applied), and the
        prompt tokens are sliced off before decoding.
        """
        import torch

        enc = self.tok(prompt, return_tensors="pt").to(self.device)
        prompt_len = enc["input_ids"].shape[-1]
        with torch.no_grad():
            out = self.model.generate(
                **enc,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                repetition_penalty=self.repetition_penalty,
                no_repeat_ngram_size=self.no_repeat_ngram_size,
            )
        return self.tok.decode(out[0][prompt_len:], skip_special_tokens=True)
| # ---------------------------------------------------------------------------- | |
| # Claude Haiku Backend (Anthropic API) | |
| # ---------------------------------------------------------------------------- | |
class ClaudeHaikuBackend(LLMBackend):
    """Generate answers with Claude Haiku through the Anthropic Messages API.

    Requires ANTHROPIC_API_KEY in the environment or in a ``.env`` file in
    the current working directory.
    """

    name = "claude-haiku (API)"

    # Upper bound on the ``max_tokens`` sent per request; larger
    # ``max_new_tokens`` values are clamped to this.
    _MAX_TOKENS_CAP = 1024

    def __init__(self, model: str = "claude-haiku-4-5-20251001"):
        """Resolve the API key (environment first, then ``./.env``) and store *model*.

        Raises:
            ValueError: no ANTHROPIC_API_KEY could be found.
        """
        import os

        self._load_dotenv(Path(".env"))
        self.api_key = os.environ.get("ANTHROPIC_API_KEY", "")
        if not self.api_key:
            raise ValueError(
                "ANTHROPIC_API_KEY non trouvée. "
                "Crée un fichier .env avec: ANTHROPIC_API_KEY=sk-ant-..."
            )
        self.model = model
        print(f"✅ Claude Haiku backend prêt ({model})")

    @staticmethod
    def _load_dotenv(env_path: Path) -> None:
        """Copy KEY=VALUE pairs from *env_path* into os.environ without overriding.

        Blank lines and ``#`` comments (even indented) are skipped, an optional
        ``export `` prefix is removed, and one pair of matching surrounding
        quotes is stripped from the value. Variables already in the
        environment win (``os.environ.setdefault``). A missing file is a no-op.
        """
        import os

        if not env_path.exists():
            return
        # utf-8-sig also strips a BOM (common from Windows editors), which
        # would otherwise corrupt the first key name.
        for raw in env_path.read_text(encoding="utf-8-sig").splitlines():
            line = raw.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue
            if line.startswith("export "):
                line = line[len("export "):]
            key, value = line.split("=", 1)
            key, value = key.strip(), value.strip()
            # KEY="value" / KEY='value' -> value; keeping the quotes breaks auth.
            if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
                value = value[1:-1]
            if key:
                os.environ.setdefault(key, value)

    def generate(self, prompt: str, max_new_tokens: int = 512) -> str:
        """Send *prompt* as a single user message and return the reply text.

        ``max_new_tokens`` is clamped to ``_MAX_TOKENS_CAP``. All text content
        blocks of the reply are concatenated; an empty reply yields "".

        Raises:
            urllib.error.URLError / HTTPError: network or API errors (propagated).
        """
        import json
        import urllib.request

        payload = json.dumps({
            "model": self.model,
            "max_tokens": min(max_new_tokens, self._MAX_TOKENS_CAP),
            "messages": [{"role": "user", "content": prompt}],
        }).encode("utf-8")
        req = urllib.request.Request(
            "https://api.anthropic.com/v1/messages",
            data=payload,
            headers={
                "Content-Type": "application/json",
                "x-api-key": self.api_key,
                "anthropic-version": "2023-06-01",
            },
            method="POST",
        )
        with urllib.request.urlopen(req, timeout=60) as resp:
            data = json.loads(resp.read())
        # The reply is a list of content blocks; don't assume exactly one.
        return "".join(
            block.get("text", "")
            for block in data.get("content", [])
            if block.get("type") == "text"
        )
| # ---------------------------------------------------------------------------- | |
| # Factory | |
| # ---------------------------------------------------------------------------- | |
def make_llm(
    model_path: str = "data/gemma4-e2b-it.task",
    fallback_hf: str = "Qwen/Qwen3-0.6B",
    use_claude: bool = False,
) -> LLMBackend:
    """Build the best available generator backend.

    Backend priority:
      1. Claude Haiku (only when ``use_claude=True``)
      2. LiteRT-LM (when the ``.task`` file at ``model_path`` exists)
      3. Transformers with ``fallback_hf`` (always the last resort)

    Failures of options 1 and 2 are reported with ``print`` and fall through
    to the next option; errors from the final transformers backend propagate.
    """
    if use_claude:
        try:
            llm = ClaudeHaikuBackend()
        except Exception as e:
            # Claude is optional: report and continue with the local backends
            # (LiteRT-LM is tried next, not transformers directly).
            print(f"Claude Haiku init failed ({e}); falling back to local backends.")
        else:
            print(f"using backend: {llm.name}")
            return llm

    if not Path(model_path).exists():
        print(f"no .task file at {model_path}; falling back to transformers.")
    else:
        try:
            llm = LiteRTLMBackend(model_path)
        except ImportError:
            print("mediapipe not installed; falling back to transformers.")
        except Exception as e:
            print(f"LiteRT-LM init failed ({e}); falling back to transformers.")
        else:
            print(f"using backend: {llm.name}")
            return llm

    llm = TransformersBackend(fallback_hf)
    print(f"using backend: {llm.name}")
    return llm