genAI-Project / src /llm_backend.py
OGB2000's picture
Initial clean deployment
bf77be6
Raw
History Blame Contribute Delete
6.99 kB
"""Two interchangeable backends for the generator side of the RAG pipeline.
Preferred: LiteRTLMBackend with gemma4-E2B-it deployed via mediapipe's
LiteRT-LM runtime (on-device, runs on CPU-class hardware).
Fallback: TransformersBackend, loads a causal LM (Qwen/Qwen3-0.6B by default)
through Hugging Face transformers. Slower, but works in any Python env including Colab.
Both expose the same `generate(prompt: str, max_new_tokens=512) -> str` API.
"""
from pathlib import Path
# ----------------------------------------------------------------------------
# Backend interface
# ----------------------------------------------------------------------------
class LLMBackend:
    """Common interface shared by every generator backend in this module.

    Subclasses override `generate`; this base class only defines the
    signature and a placeholder label.
    """

    # Label identifying the backend (printed by the factory when selected).
    name: str = "abstract"

    def generate(self, prompt: str, max_new_tokens: int = 512) -> str:
        """Return the model's completion for `prompt`.

        Args:
            prompt: Full text prompt handed to the model.
            max_new_tokens: Requested upper bound on generated tokens.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()
# ----------------------------------------------------------------------------
# Preferred: LiteRT-LM via mediapipe
# ----------------------------------------------------------------------------
class LiteRTLMBackend(LLMBackend):
    """Generator backend that runs a LiteRT-LM ``.task`` model through mediapipe.

    Meant for gemma4-E2B-it on CPU, but any LiteRT-LM .task file can be used.

    Requires:
        pip install mediapipe            # most platforms
        pip install mediapipe-silicon    # Apple Silicon
    """

    name = "liteRT-LM (gemma4-E2B-it)"

    def __init__(self, model_path: str, max_tokens: int = 2048, temperature: float = 0.0, top_k: int = 40):
        """Create the mediapipe inference engine for the given model file.

        Args:
            model_path: Location of the LiteRT-LM .task file.
            max_tokens: Forwarded to ``LlmInferenceOptions(max_tokens=...)``.
            temperature: Forwarded to ``LlmInferenceOptions(temperature=...)``.
            top_k: Forwarded to ``LlmInferenceOptions(top_k=...)``.

        Raises:
            ImportError: If mediapipe cannot be imported.
            FileNotFoundError: If `model_path` does not exist.
        """
        # Imported here rather than at module level, so mediapipe is only
        # needed when this backend is actually instantiated.
        from mediapipe.tasks.python.genai import llm_inference

        task_file = Path(model_path)
        if not task_file.exists():
            message = (
                f"LiteRT-LM model not found: {model_path}\n"
                "Download gemma4-e2b-it.task from kaggle.com/models/google/gemma-4 "
                "and place it in data/."
            )
            raise FileNotFoundError(message)

        # The random seed is pinned to 42.
        engine_options = llm_inference.LlmInferenceOptions(
            model_path=str(task_file),
            max_tokens=max_tokens,
            top_k=top_k,
            temperature=temperature,
            random_seed=42,
        )
        self.llm = llm_inference.LlmInference.create_from_options(engine_options)

    def generate(self, prompt: str, max_new_tokens: int = 512) -> str:
        """Return the engine's response to `prompt`.

        `max_new_tokens` is not passed to mediapipe: the length cap is the
        `max_tokens` value fixed at construction time. The argument is kept
        for interface compatibility and to guide prompt budgeting upstream.
        """
        response = self.llm.generate_response(prompt)
        return response
# ----------------------------------------------------------------------------
# Fallback: Hugging Face transformers
# ----------------------------------------------------------------------------
class TransformersBackend(LLMBackend):
    """Load a causal LM via Hugging Face transformers when LiteRT-LM is not available.

    Default model: 'Qwen/Qwen3-0.6B'. Swap as needed.
    """

    # Class-level default label; each instance replaces it with a label built
    # from the model it was actually asked to load.
    name = "transformers (Qwen3-0.6B)"

    def __init__(self, model_name: str = "Qwen/Qwen3-0.6B"):
        """Load tokenizer and model weights for `model_name`.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                tokenizer and the model.

        Raises:
            ImportError: If torch or transformers cannot be imported.
        """
        import torch
        from transformers import AutoTokenizer, AutoModelForCausalLM

        # Keep the reported backend name in sync with the requested model.
        # The default "Qwen/Qwen3-0.6B" yields exactly the class-level label.
        self.name = self._display_name(model_name)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tok = AutoTokenizer.from_pretrained(model_name)
        # bfloat16 on GPU, full float32 precision on CPU.
        dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=dtype, device_map=self.device,
        )
        self.model.eval()

    @staticmethod
    def _display_name(model_name: str) -> str:
        """Build the backend label from the last path component of `model_name`."""
        short_name = model_name.rstrip("/").rsplit("/", 1)[-1] or model_name
        return f"transformers ({short_name})"

    def generate(self, prompt: str, max_new_tokens: int = 512) -> str:
        """Return the decoded continuation for `prompt`.

        Args:
            prompt: Text to tokenize and feed to the model.
            max_new_tokens: Passed to ``model.generate`` as the cap on newly
                generated tokens.

        Returns:
            The decoded continuation only, with special tokens skipped.
        """
        import torch

        enc = self.tok(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            out = self.model.generate(
                **enc,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                repetition_penalty=1.3,
                no_repeat_ngram_size=4,
            )
        # Slice off the prompt tokens so only the continuation is decoded.
        prompt_len = enc["input_ids"].shape[-1]
        return self.tok.decode(out[0][prompt_len:], skip_special_tokens=True)
# ----------------------------------------------------------------------------
# Claude Haiku Backend (Anthropic API)
# ----------------------------------------------------------------------------
class ClaudeHaikuBackend(LLMBackend):
    """Answer prompts with Claude Haiku through the Anthropic Messages API.

    Requires ANTHROPIC_API_KEY in the environment or in a ``.env`` file in the
    current working directory.
    """

    name = "claude-haiku (API)"

    def __init__(self, model: str = "claude-haiku-4-5-20251001"):
        """Resolve the API key and remember which model to call.

        Args:
            model: Model identifier sent with every request.

        Raises:
            ValueError: If ANTHROPIC_API_KEY is missing or empty after the
                optional ``.env`` file has been loaded.
        """
        import os

        # Load .env file if present (existing environment variables win).
        self._load_env_file(Path(".env"))

        self.api_key = os.environ.get("ANTHROPIC_API_KEY", "")
        if not self.api_key:
            raise ValueError(
                "ANTHROPIC_API_KEY non trouvée. "
                "Crée un fichier .env avec: ANTHROPIC_API_KEY=sk-ant-..."
            )
        self.model = model
        print(f"✅ Claude Haiku backend prêt ({model})")

    @staticmethod
    def _load_env_file(env_path) -> None:
        """Copy KEY=VALUE pairs from `env_path` into os.environ without overriding.

        Blank lines, comment lines (including indented ones) and lines without
        ``=`` are skipped. One pair of matching surrounding quotes is removed
        from the value, so ``KEY="value"`` and ``KEY=value`` are equivalent.
        """
        import os

        if not env_path.exists():
            return
        # utf-8-sig drops a leading BOM that would otherwise become part of the
        # first key; errors="replace" keeps loading best-effort on odd bytes.
        text = env_path.read_text(encoding="utf-8-sig", errors="replace")
        for raw_line in text.splitlines():
            line = raw_line.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue
            key, value = line.split("=", 1)
            key = key.strip()
            value = value.strip()
            if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
                value = value[1:-1]
            # An empty variable name cannot be exported; ignore such lines.
            if key:
                os.environ.setdefault(key, value)

    def generate(self, prompt: str, max_new_tokens: int = 512) -> str:
        """Send `prompt` as a single user message and return the reply text.

        Args:
            prompt: Content of the user message.
            max_new_tokens: Requested output budget; the request never asks
                for more than 1024 tokens.

        Returns:
            The ``text`` field of the first content block in the response.

        No exception is caught here: failures from ``urlopen`` (the request
        uses a 60-second timeout), JSON decoding, or an unexpected response
        shape propagate to the caller.
        """
        import json
        import urllib.request

        payload = json.dumps({
            "model": self.model,
            "max_tokens": min(max_new_tokens, 1024),
            "messages": [{"role": "user", "content": prompt}],
        }).encode()
        # Supplying `data` makes urllib issue a POST request.
        req = urllib.request.Request(
            "https://api.anthropic.com/v1/messages",
            data=payload,
            headers={
                "Content-Type": "application/json",
                "x-api-key": self.api_key,
                "anthropic-version": "2023-06-01",
            },
        )
        with urllib.request.urlopen(req, timeout=60) as resp:
            data = json.loads(resp.read())
        return data["content"][0]["text"]
# ----------------------------------------------------------------------------
# Factory
# ----------------------------------------------------------------------------
def make_llm(
    model_path: str = "data/gemma4-e2b-it.task",
    fallback_hf: str = "Qwen/Qwen3-0.6B",
    use_claude: bool = False,
) -> LLMBackend:
    """Build and return the first backend that initializes successfully.

    Backend priority:
      1. Claude Haiku  (if use_claude=True)
      2. LiteRT-LM     (if the .task file is present)
      3. Transformers  (fallback)

    Args:
        model_path: Location of the LiteRT-LM .task file.
        fallback_hf: Model name handed to TransformersBackend.
        use_claude: Try the Claude Haiku API backend first.

    Returns:
        The selected backend instance. Each decision is reported on stdout.

    Failures of the Claude and LiteRT-LM backends are reported and skipped;
    a failure of the final Transformers fallback propagates to the caller.
    """
    if use_claude:
        try:
            llm = ClaudeHaikuBackend()
        except Exception as e:
            # Selection boundary: any init failure just moves on to the next
            # candidate, which is LiteRT-LM and not necessarily transformers.
            print(f"Claude Haiku init failed ({e}); falling back to local backends.")
        else:
            print(f"using backend: {llm.name}")
            return llm

    if Path(model_path).exists():
        try:
            llm = LiteRTLMBackend(model_path)
        except ImportError:
            print("mediapipe not installed; falling back to transformers.")
        except Exception as e:
            print(f"LiteRT-LM init failed ({e}); falling back to transformers.")
        else:
            print(f"using backend: {llm.name}")
            return llm
    else:
        print(f"no .task file at {model_path}; falling back to transformers.")

    llm = TransformersBackend(fallback_hf)
    print(f"using backend: {llm.name}")
    return llm