# (Hugging Face Spaces page residue — Space status: Sleeping)
| import os | |
| import torch | |
| import transformers | |
| import logging | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor | |
| logger = logging.getLogger("app.llm") | |
# Global cache for the loaded LLM so repeated calls to get_llm() do not
# re-download or re-initialize the weights within this process.
LLM_MODEL = None  # transformers model instance once loaded
LLM_PROCESSOR = None  # AutoProcessor (or AutoTokenizer fallback) for the model
CURRENT_MODEL_SIZE = None  # which checkpoint is cached ("1b" or "4b")
def get_device() -> torch.device:
    """Return the preferred torch device: CUDA when available, else CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
def get_llm(model_size: str = "1b"):
    """Load (or fetch from cache) a Gemma-3 model and its processor.

    Args:
        model_size: "4b" selects the gated 4B checkpoint; any other value
            (default "1b") selects the ungated 1B checkpoint.

    Returns:
        A ``(model, processor)`` tuple. ``processor`` is an AutoProcessor
        when the checkpoint ships one, otherwise an AutoTokenizer.
    """
    import sys

    cache_key = "sage_llm_cache"
    global LLM_MODEL, LLM_PROCESSOR, CURRENT_MODEL_SIZE

    # The cache is stashed on the long-lived `sys` module so it survives
    # module reloads (e.g. Gradio/Streamlit hot-reload re-importing this file).
    if hasattr(sys, cache_key):
        cached_model, cached_proc, cached_size = getattr(sys, cache_key)
        if cached_size == model_size:
            return cached_model, cached_proc

    # Force 1B for HF/Stability if needed, but here we respect model_size
    # Actually, user said 4b for local, 1b for app.py
    llm_model_id = "google/gemma-3-1b-it"
    if model_size == "4b":
        llm_model_id = "google/gemma-3-4b-it"  # Note: gated, requires auth or local files

    device = get_device()
    # bf16 on GPU; fp32 on CPU where bf16 support is unreliable.
    dtype = torch.bfloat16 if "cuda" in device.type else torch.float32
    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logger.info("Loading %s on %s...", llm_model_id, device)

    LLM_MODEL = AutoModelForCausalLM.from_pretrained(
        llm_model_id,
        dtype=dtype,
        device_map="auto",
    ).eval()

    # Multimodal checkpoints ship a processor; text-only ones raise here,
    # in which case the plain tokenizer is the right interface.
    try:
        LLM_PROCESSOR = AutoProcessor.from_pretrained(llm_model_id)
    except Exception:  # was a bare `except:` — must not swallow KeyboardInterrupt/SystemExit
        LLM_PROCESSOR = AutoTokenizer.from_pretrained(llm_model_id)

    CURRENT_MODEL_SIZE = model_size
    setattr(sys, cache_key, (LLM_MODEL, LLM_PROCESSOR, model_size))
    return LLM_MODEL, LLM_PROCESSOR
def detect_language(text: str) -> str:
    """Ask the LLM which language *text* is written in.

    Empty or very short inputs skip the model call and default to
    "English"; so does any model reply that matches none of the
    recognized language names.
    """
    # Guard: nothing meaningful to classify.
    if not text or len(text) < 5:
        return "English"

    model, processor = get_llm()
    prompt = f"Detect the language of the following text and return ONLY the language name:\n\n\"{text}\""
    messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
    encoded = processor.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    with torch.no_grad():
        generated = model.generate(encoded, max_new_tokens=10, do_sample=False)

    # Decode only the newly generated tokens, not the echoed prompt.
    reply = processor.batch_decode(
        generated[:, encoded.shape[1]:], skip_special_tokens=True
    )[0].strip()

    import re
    known = ["English", "German", "French", "Spanish", "Italian", "Portuguese", "Russian", "Japanese", "Chinese"]
    found = next((name for name in known if re.search(rf"\b{name}\b", reply, re.I)), None)
    return found or "English"