import logging
import os
import re
import sys

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer

logger = logging.getLogger("app.llm")

# Global cache. These mirror the sys-attribute cache used in get_llm();
# they are kept as module globals for any external code that inspects them.
LLM_MODEL = None
LLM_PROCESSOR = None
CURRENT_MODEL_SIZE = None


def get_device() -> torch.device:
    """Return the preferred torch device: CUDA when available, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def get_llm(model_size: str = "1b"):
    """Load (or return a cached) Gemma-3 instruct model and its processor.

    Args:
        model_size: "1b" (default) or "4b". Any other value falls back to
            the 1B model id.

    Returns:
        Tuple of (model, processor). The processor is an ``AutoProcessor``
        when available, otherwise an ``AutoTokenizer``.
    """
    cache_key = "sage_llm_cache"
    global LLM_MODEL, LLM_PROCESSOR, CURRENT_MODEL_SIZE

    # The cache is stashed on the `sys` module so it survives module
    # re-execution (e.g. hot-reload frameworks re-import this module and
    # wipe ordinary module globals).
    if hasattr(sys, cache_key):
        cached_model, cached_proc, cached_size = getattr(sys, cache_key)
        if cached_size == model_size:
            return cached_model, cached_proc

    # Force 1B for HF/Stability if needed, but here we respect model_size
    # Actually, user said 4b for local, 1b for app.py
    llm_model_id = "google/gemma-3-1b-it"
    if model_size == "4b":
        llm_model_id = "google/gemma-3-4b-it"  # Note: gated, requires auth or local files

    device = get_device()
    # bfloat16 halves memory on GPU; CPU inference stays in float32.
    dtype = torch.bfloat16 if "cuda" in device.type else torch.float32

    logger.info(f"Loading {llm_model_id} on {device}...")
    LLM_MODEL = AutoModelForCausalLM.from_pretrained(
        llm_model_id,
        # NOTE(review): `dtype=` is the modern kwarg; older transformers
        # releases expect `torch_dtype=` — confirm the pinned version.
        dtype=dtype,
        device_map="auto",
    ).eval()

    try:
        LLM_PROCESSOR = AutoProcessor.from_pretrained(llm_model_id)
    except Exception:
        # Fixed: was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit. Text-only checkpoints may lack a
        # processor, so fall back to the plain tokenizer.
        LLM_PROCESSOR = AutoTokenizer.from_pretrained(llm_model_id)

    CURRENT_MODEL_SIZE = model_size
    setattr(sys, cache_key, (LLM_MODEL, LLM_PROCESSOR, model_size))
    return LLM_MODEL, LLM_PROCESSOR


def detect_language(text: str) -> str:
    """Detect the language of ``text`` using the cached LLM.

    Very short inputs (< 5 chars) are assumed to be English. The model's
    raw answer is matched against a fixed list of language names; any
    unrecognized answer also falls back to "English".

    Args:
        text: The text whose language should be identified.

    Returns:
        A language name such as "English" or "German".
    """
    if not text or len(text) < 5:
        return "English"

    model, processor = get_llm()
    prompt = f"Detect the language of the following text and return ONLY the language name:\n\n\"{text}\""
    messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    with torch.no_grad():
        # Greedy decoding; 10 new tokens is plenty for a language name.
        outputs = model.generate(inputs, max_new_tokens=10, do_sample=False)

    # Decode only the newly generated tokens (slice off the prompt).
    raw = processor.batch_decode(
        outputs[:, inputs.shape[1]:], skip_special_tokens=True
    )[0].strip()

    langs = ["English", "German", "French", "Spanish", "Italian", "Portuguese", "Russian", "Japanese", "Chinese"]
    for l in langs:
        # Whole-word, case-insensitive match against the model's reply.
        if re.search(rf"\b{l}\b", raw, re.I):
            return l
    return "English"