Spaces:
Running
Running
| import os | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig | |
| from huggingface_hub import snapshot_download # <-- needed to pull the folder | |
# ── 1) PATHS / VARS ──────────────────────────────────────────────────────────
REPO_ID = "MAS-AI-0000/Authentica"
TEXT_SUBFOLDER = "Lib/Models/Text"  # where config.json/model.safetensors live in the repo

# Download a local snapshot of just the Text folder and point MODEL_DIR at it.
# The download is guarded so that a network/hub failure does not abort the
# import: MODEL_DIR then degrades to the bare relative sub-folder path, and the
# model-loading step further down (which has its own error handling) decides
# whether anything usable exists there.
try:
    _snapshot_dir = snapshot_download(
        repo_id=REPO_ID,
        allow_patterns=[f"{TEXT_SUBFOLDER}/*"],
    )
except Exception as e:
    print(f"Error downloading text model snapshot: {e}")
    _snapshot_dir = ""
MODEL_DIR = os.path.join(_snapshot_dir, TEXT_SUBFOLDER)

# Individual file paths (in case you need them elsewhere).
CONFIG_PATH = os.path.join(MODEL_DIR, "config.json")
MODEL_SAFETENSORS_PATH = os.path.join(MODEL_DIR, "model.safetensors")
TOKENIZER_JSON_PATH = os.path.join(MODEL_DIR, "tokenizer.json")
TOKENIZER_CONFIG_PATH = os.path.join(MODEL_DIR, "tokenizer_config.json")
SPECIAL_TOKENS_MAP_PATH = os.path.join(MODEL_DIR, "special_tokens_map.json")
TRAINING_ARGS_BIN_PATH = os.path.join(MODEL_DIR, "training_args.bin")  # optional
TEXT_TXT_PATH = os.path.join(MODEL_DIR, "text.txt")  # optional

# Default truncation length (in tokens) used by the prediction helpers.
MAX_LEN = 512
# ── 2) Load model & tokenizer ────────────────────────────────────────────────
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Text prediction device: {device}")

# These stay None when loading fails; the prediction helpers check for that
# and return fallback responses instead of raising.
tokenizer = None
model = None
ID2LABEL = {0: "human", 1: "ai"}

try:
    # Load directly from the local MODEL_DIR.
    config = AutoConfig.from_pretrained(MODEL_DIR)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR, config=config)
    model.eval().to(device)

    # Override labels from the config only when it carries real label names.
    # transformers fills id2label with auto-generated "LABEL_<n>" placeholders
    # when a checkpoint defines none, so a plain truthiness check would always
    # replace the human/ai defaults above with those placeholders.
    _config_labels = getattr(model.config, "id2label", None) or {}
    if any(str(name) != f"LABEL_{idx}" for idx, name in _config_labels.items()):
        ID2LABEL = {int(k): v for k, v in _config_labels.items()}

    print("Text classification model loaded successfully")
    print("MODEL_DIR:", MODEL_DIR)
    print("Labels:", ID2LABEL)
except Exception as e:
    # Best-effort by design: report the problem and keep the module importable.
    print(f"Error loading text model: {e}")
    print("Text prediction will return fallback responses")
# ── 3) Inference ─────────────────────────────────────────────────────────────
def predict_text(text: str, max_length: int | None = None):
    """Classify a single text with the module-level model.

    Args:
        text: The text to classify.
        max_length: Token limit used for truncation. ``None`` selects the
            module-level ``MAX_LEN``.

    Returns:
        A dict with two keys:
        ``"predicted_class"`` - the capitalized label looked up in
        ``ID2LABEL`` (or the stringified class id when it has no entry), and
        ``"confidence"`` - the softmax probability of that class as a float.
        When the model or tokenizer is not loaded, or when inference raises,
        the fallback ``{"predicted_class": "Human", "confidence": 0.0}`` is
        returned instead.
    """
    if model is None or tokenizer is None:
        return {"predicted_class": "Human", "confidence": 0.0}
    if max_length is None:
        max_length = MAX_LEN
    try:
        enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
        enc = {k: v.to(device) for k, v in enc.items()}
        # Inference only: skip autograd bookkeeping so no graph is built and
        # activations are not kept alive for a backward pass.
        with torch.no_grad():
            logits = model(**enc).logits
        # squeeze(0) drops the batch dimension of the single input.
        probs = torch.softmax(logits, dim=-1).squeeze(0).cpu().numpy()
        pred_id = int(probs.argmax(-1))
        label = ID2LABEL.get(pred_id, str(pred_id)).capitalize()
        return {"predicted_class": label, "confidence": float(probs[pred_id])}
    except Exception as e:
        # Best-effort by design: report the error and return the fallback.
        print(f"Error during text prediction: {e}")
        return {"predicted_class": "Human", "confidence": 0.0}
# ── 4) Batch (optional) ──────────────────────────────────────────────────────
def predict_batch(texts, batch_size=16):
    """Classify several texts, running the model on ``batch_size`` at a time.

    Args:
        texts: An iterable of strings. A bare string is treated as a batch of
            one text rather than being iterated character by character.
        batch_size: Number of texts per forward pass; must be at least 1.

    Returns:
        A list with one dict per input text, in input order, each holding
        ``"text"``, ``"predicted_class"`` (capitalized label from
        ``ID2LABEL``, or the stringified class id when it has no entry) and
        ``"confidence"`` (softmax probability of that class). When the model
        or tokenizer is not loaded, every entry is the fallback
        ``"Human"`` / ``0.0`` prediction.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.

    Unlike ``predict_text``, errors raised during tokenization or inference
    are not caught here and propagate to the caller.
    """
    # Materialize the input so len()/slicing work for any iterable
    # (generators, tuples, ...), and so the tokenizer always receives a list.
    if isinstance(texts, str):
        texts = [texts]
    else:
        texts = list(texts)

    if model is None or tokenizer is None:
        # Carry "text" here as well so fallback entries have the same keys
        # as successful ones.
        return [{"text": t, "predicted_class": "Human", "confidence": 0.0} for t in texts]

    if batch_size < 1:
        # A negative step would make range() yield nothing and silently
        # return an empty result for non-empty input.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    results = []
    for start in range(0, len(texts), batch_size):
        chunk = texts[start:start + batch_size]
        enc = tokenizer(chunk, return_tensors="pt", truncation=True, max_length=MAX_LEN, padding=True)
        enc = {k: v.to(device) for k, v in enc.items()}
        # Inference only: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            logits = model(**enc).logits
        probs = torch.softmax(logits, dim=-1).cpu().numpy()
        ids = probs.argmax(-1)
        for t, pid, p in zip(chunk, ids, probs):
            pid = int(pid)
            label = ID2LABEL.get(pid, str(pid)).capitalize()
            results.append({"text": t, "predicted_class": label, "confidence": float(p[pid])})
    return results