File size: 4,468 Bytes
8c8d036 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import os
import torch
import torch.nn as nn
from flask import Flask, request, jsonify
from transformers import (
AutoTokenizer,
AutoModel,
AutoConfig,
PretrainedConfig,
PreTrainedModel,
)
# ============================================================
# Hugging Face cache location, model path, port and Flask app
# ============================================================
# Keep Hugging Face downloads under /app/hf_cache, a directory that is
# expected to be writable in the deployment container.
# NOTE(review): these variables are assigned after `transformers` has already
# been imported above; confirm the library still honours them at this point.
CACHE_DIR = "/app/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
os.environ.update(
    {
        "HF_HOME": CACHE_DIR,
        "TRANSFORMERS_CACHE": CACHE_DIR,
        "HF_HUB_DISABLE_SYMLINKS_WARNING": "1",
    }
)

# Directory the model files are loaded from, and the port the API listens on
# (read from the PORT environment variable, 7860 when it is unset).
MODEL_DIR = "./"
PORT = int(os.getenv("PORT", 7860))

app = Flask(__name__)
# ============================================================
# Register Custom SNP Architecture
# ============================================================
class CustomSNPConfig(PretrainedConfig):
    """Configuration class for the custom SNP model.

    The only thing defined here is ``model_type``; everything else is
    inherited from ``PretrainedConfig``.
    """

    # Identifier used to look this configuration up by model type.
    model_type = "custom_snp"
class CustomSNPModel(PreTrainedModel):
    """Small feed-forward network used as the custom SNP model.

    Data flows through ``encoder`` -> ``mirror_head`` -> ``prism_head`` ->
    ``projection``. Every layer except the last is ``hidden_size`` wide
    (``config.hidden_size``, or 768 when the config does not define it); the
    final projection produces 6 values.
    """

    config_class = CustomSNPConfig

    def __init__(self, config):
        """Build the layers from ``config`` (only ``hidden_size`` is read)."""
        super().__init__(config)
        hidden_size = getattr(config, "hidden_size", 768)
        # Mirror and Prism heads
        self.encoder = nn.Linear(hidden_size, hidden_size)
        self.mirror_head = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh())
        self.prism_head = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh())
        self.projection = nn.Linear(hidden_size, 6)

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Run the network on ``input_ids`` treated as a float feature vector.

        Args:
            input_ids: Tensor whose last dimension is used as the feature
                axis. It is cast to float and, when its width differs from
                the encoder's input width, zero-padded on the right or
                truncated so that the linear layers can consume it.
            attention_mask: Accepted for call compatibility; not used.
            **kwargs: Accepted for call compatibility; not used.

        Returns:
            Tensor with the same leading dimensions as ``input_ids`` and a
            last dimension of 6.

        Raises:
            ValueError: If ``input_ids`` is ``None``.
        """
        if input_ids is None:
            # Previously ``None`` was passed straight into the heads and
            # failed there with an unrelated-looking error.
            raise ValueError("input_ids is required")

        # Simulate encoded representations
        features = input_ids.float()

        # The token sequence length rarely equals the encoder's input width,
        # which used to make the first matmul fail with a shape error.
        # Inputs that already have the expected width are left untouched.
        expected_width = self.encoder.in_features
        actual_width = features.shape[-1]
        if actual_width < expected_width:
            features = nn.functional.pad(features, (0, expected_width - actual_width))
        elif actual_width > expected_width:
            features = features[..., :expected_width]

        x = self.encoder(features)
        x = self.mirror_head(x)
        x = self.prism_head(x)
        return self.projection(x)
# Teach the Auto* factories about the custom classes so that the
# "custom_snp" model type resolves to CustomSNPConfig / CustomSNPModel.
AutoConfig.register(CustomSNPConfig.model_type, CustomSNPConfig)
AutoModel.register(CustomSNPConfig, CustomSNPModel)
# ============================================================
# Load Model & Tokenizer
# ============================================================
# Load the config, tokenizer and model once at import time. Any failure is
# reported and re-raised so the service does not start without a model.
try:
    print("Loading model from:", MODEL_DIR)
    config = AutoConfig.from_pretrained(MODEL_DIR, trust_remote_code=True)
    # Try loading tokenizer; fallback if not mapped
    from transformers import RobertaTokenizer
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    except Exception as tokenizer_error:
        print("⚠️ Falling back to default RoBERTa tokenizer.")
        # Surface the reason instead of silently discarding it.
        print("Tokenizer load error:", tokenizer_error)
        # Pass the cache directory explicitly so the download lands in
        # CACHE_DIR regardless of whether the environment overrides set at
        # the top of the file were picked up by the library.
        tokenizer = RobertaTokenizer.from_pretrained(
            "roberta-base", cache_dir=CACHE_DIR
        )
    model = AutoModel.from_pretrained(MODEL_DIR, config=config, trust_remote_code=True)
    model.eval()
    print("✅ Custom SNP model loaded successfully.")
except Exception as e:
    print("❌ Error loading custom model:", e)
    # Bare raise keeps the original traceback intact.
    raise
# ============================================================
# Flask API Routes
# ============================================================
@app.route("/", methods=["GET"])
def home():
    """Root endpoint: report that the API is running."""
    payload = {"status": "SNP Universal Embedding API running"}
    return jsonify(payload)
@app.route("/health", methods=["GET"])
def health():
    """Health-check endpoint: always answers with a "healthy" status."""
    payload = {"status": "healthy"}
    return jsonify(payload)
@app.route("/embed", methods=["POST"])
def embed():
    """Return the model output for the ``text`` field of a JSON request.

    Request body: a JSON object with a ``text`` field holding a string or a
    list of strings.

    Returns:
        ``{"embedding": [...]}`` on success, or ``{"error": ...}`` with
        status 400 when the body is not a JSON object or ``text`` is missing,
        empty, or of the wrong type.
    """
    # silent=True yields None for unparseable bodies so they get the same
    # JSON 400 response as every other validation failure.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        # A JSON array/string/null body has no .get and used to cause a 500.
        return jsonify({"error": "Request body must be a JSON object"}), 400

    text = data.get("text", "")
    if not text:
        return jsonify({"error": "Text is required"}), 400

    is_single = isinstance(text, str)
    is_batch = isinstance(text, list) and all(isinstance(item, str) for item in text)
    if not (is_single or is_batch):
        # Anything else would fail inside the tokenizer and produce a 500.
        return jsonify({"error": "Text must be a string or a list of strings"}), 400

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        embeddings = model(**inputs)

    # Normalise the different return shapes a model may produce.
    if hasattr(embeddings, "last_hidden_state"):
        # Average over dimension 1 to get one vector per input.
        embeddings = embeddings.last_hidden_state.mean(dim=1)
    elif isinstance(embeddings, tuple):
        embeddings = embeddings[0]
    return jsonify({"embedding": embeddings.tolist()})
@app.route("/reason", methods=["POST"])
def reason():
    """Return a scalar score for a premise/hypothesis pair.

    Request body: a JSON object with ``premise`` and ``hypothesis`` fields.
    The two values are joined with a single space, tokenized, passed through
    the model, and the mean of the output tensor is returned.

    Returns:
        ``{"reasoning_score": <float>}`` on success, or ``{"error": ...}``
        with status 400 when the body is not a JSON object or both fields
        are missing/empty.
    """
    # silent=True yields None for unparseable bodies so they get a JSON 400.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        # A JSON array/string/null body has no .get and used to cause a 500.
        return jsonify({"error": "Request body must be a JSON object"}), 400

    premise = data.get("premise", "")
    hypothesis = data.get("hypothesis", "")
    if not premise and not hypothesis:
        # Nothing to score; mirrors the empty-text check of /embed.
        return jsonify({"error": "Premise or hypothesis is required"}), 400

    combined = f"{premise} {hypothesis}"
    inputs = tokenizer(combined, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        output = model(**inputs)

    # Accept the same return shapes as /embed before reducing to a scalar;
    # calling .mean() directly fails on non-tensor model outputs.
    if hasattr(output, "last_hidden_state"):
        output = output.last_hidden_state
    elif isinstance(output, tuple):
        output = output[0]

    score = float(output.mean().item())
    return jsonify({"reasoning_score": score})
# ============================================================
# Run Server
# ============================================================
if __name__ == "__main__":
    # Announce the port, then serve on all network interfaces.
    startup_message = f"🚀 Starting SNP Universal Embedding API on port {PORT}"
    print(startup_message)
    app.run(port=PORT, host="0.0.0.0")
|