# gemma-sage / llm_module.py
# feat: Agentic Positive Prompting & Scenario Tests
# (commit 058ae1e, author: neuralworm)
import os
import torch
import transformers
import logging
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor
logger = logging.getLogger("app.llm")
# Module-level cache for the last loaded model/processor pair.
# NOTE(review): get_llm() performs cache *lookups* via an attribute stashed on
# `sys` ("sage_llm_cache"), not via these globals — they only mirror the most
# recent load. Confirm whether both caches are intentional.
LLM_MODEL = None
LLM_PROCESSOR = None
CURRENT_MODEL_SIZE = None
def get_device() -> torch.device:
    """Select the compute device: CUDA when available, otherwise CPU."""
    kind = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(kind)
def get_llm(model_size: str = "1b"):
    """Load the Gemma 3 chat model and its processor, with process-wide caching.

    Args:
        model_size: "1b" (default, Space-friendly) or "4b" (gated checkpoint;
            requires HF auth or local files).

    Returns:
        Tuple of (model, processor). The processor is an ``AutoProcessor`` when
        the checkpoint ships one, otherwise an ``AutoTokenizer``.
    """
    import sys
    global LLM_MODEL, LLM_PROCESSOR, CURRENT_MODEL_SIZE

    # Cache is stashed as an attribute on `sys` so it survives module reloads
    # (e.g. Gradio hot-reload), which would reset module-level globals.
    cache_key = "sage_llm_cache"
    if hasattr(sys, cache_key):
        cached_model, cached_proc, cached_size = getattr(sys, cache_key)
        if cached_size == model_size:
            return cached_model, cached_proc

    # 1B is the default for the hosted app; 4B is intended for local runs.
    llm_model_id = "google/gemma-3-1b-it"
    if model_size == "4b":
        llm_model_id = "google/gemma-3-4b-it"  # Note: gated, requires auth or local files

    device = get_device()
    # bfloat16 only on GPU; CPU inference stays in float32 for compatibility.
    dtype = torch.bfloat16 if "cuda" in device.type else torch.float32
    logger.info(f"Loading {llm_model_id} on {device}...")

    LLM_MODEL = AutoModelForCausalLM.from_pretrained(
        llm_model_id,
        dtype=dtype,
        device_map="auto",
    ).eval()

    try:
        LLM_PROCESSOR = AutoProcessor.from_pretrained(llm_model_id)
    except Exception:
        # Bare `except:` would also swallow KeyboardInterrupt/SystemExit.
        # Text-only checkpoints may lack a processor config; fall back to the tokenizer.
        logger.warning("AutoProcessor unavailable for %s; falling back to AutoTokenizer.", llm_model_id)
        LLM_PROCESSOR = AutoTokenizer.from_pretrained(llm_model_id)

    CURRENT_MODEL_SIZE = model_size
    setattr(sys, cache_key, (LLM_MODEL, LLM_PROCESSOR, model_size))
    return LLM_MODEL, LLM_PROCESSOR
def detect_language(text: str) -> str:
    """Use the LLM to name the language of *text*; falls back to "English"."""
    import re

    # Too little signal to classify — skip the model call entirely.
    if not text or len(text) < 5:
        return "English"

    model, processor = get_llm()
    prompt = f"Detect the language of the following text and return ONLY the language name:\n\n\"{text}\""
    messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
    inputs = processor.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(inputs, max_new_tokens=10, do_sample=False)
    # Decode only the newly generated tokens, not the echoed prompt.
    raw = processor.batch_decode(outputs[:, inputs.shape[1]:], skip_special_tokens=True)[0].strip()

    known = ("English", "German", "French", "Spanish", "Italian",
             "Portuguese", "Russian", "Japanese", "Chinese")
    found = next((name for name in known if re.search(rf"\b{name}\b", raw, re.I)), None)
    return found or "English"