File size: 5,066 Bytes
ff41043 bcede05 ff41043 bcede05 ff41043 2127a29 5c20eb3 2127a29 da92ced ff41043 da92ced 197bff9 2127a29 da92ced 2127a29 b313d1c ff41043 b313d1c ff41043 bcede05 b313d1c ff41043 da92ced 2127a29 ff41043 da92ced 7545ff7 da92ced ff41043 da92ced bcede05 b313d1c 7545ff7 2127a29 6cfd530 ff41043 2127a29 ff41043 bcede05 da92ced 2127a29 bcede05 2127a29 bcede05 da92ced bcede05 6cfd530 ff41043 b313d1c ff41043 bcede05 2127a29 ff41043 bcede05 ff41043 bcede05 2127a29 ff41043 da92ced bcede05 2127a29 ff41043 da92ced ff41043 5c20eb3 ff41043 2127a29 ff41043 bcede05 ff41043 da92ced |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import os
import torch
import torch.nn as nn
from flask import Flask, request, jsonify
from transformers import (
AutoTokenizer,
AutoModel,
AutoConfig,
PretrainedConfig,
PreTrainedModel,
)
# ============================================================
# Point the Hugging Face caches at /app/hf_cache, a directory created here.
# NOTE(review): these variables are assigned after `transformers` has already
# been imported above; confirm the library still honours them when set this late.
CACHE_DIR = "/app/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
os.environ.update(
    {
        "HF_HOME": CACHE_DIR,
        "TRANSFORMERS_CACHE": CACHE_DIR,
        "HF_HUB_DISABLE_SYMLINKS_WARNING": "1",
    }
)

# Directory the model, config and tokenizer are loaded from.
MODEL_DIR = "./"
# Port the server listens on; overridable through the PORT environment variable.
PORT = int(os.environ.get("PORT", 7860))

app = Flask(__name__)
# ============================================================
# Register Custom SNP Architecture
# ============================================================
class CustomSNPConfig(PretrainedConfig):
    """Configuration for the custom SNP architecture.

    Adds nothing to ``PretrainedConfig`` except the ``model_type`` tag that
    identifies this architecture.
    """

    model_type = "custom_snp"
class CustomSNPModel(PreTrainedModel):
    """Feed-forward model that maps a token-id sequence to six output values.

    The token ids themselves (cast to float) are used as input features: each
    sample is flattened, zero-padded or trimmed to ``hidden_size`` features,
    then passed through a linear encoder, the "mirror" and "prism" heads
    (linear + tanh each) and a final linear projection with 6 outputs.
    """

    config_class = CustomSNPConfig

    def __init__(self, config):
        """Create the layers; ``hidden_size`` falls back to 768 if the config lacks it."""
        super().__init__(config)
        hidden_size = getattr(config, "hidden_size", 768)
        # Mirror and Prism heads
        self.encoder = nn.Linear(hidden_size, hidden_size)
        self.mirror_head = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh())
        self.prism_head = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh())
        self.projection = nn.Linear(hidden_size, 6)

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Run the encoder, mirror head, prism head and projection.

        Args:
            input_ids: tensor of token ids; everything after the first
                dimension is flattened into one feature vector per sample.
            attention_mask: accepted for tokenizer-output compatibility but
                not used by this model.
            **kwargs: ignored.

        Returns:
            Tensor with 6 values per sample (output of ``self.projection``).

        Raises:
            ValueError: if ``input_ids`` is None.
        """
        if input_ids is None:
            # Fail early with a clear message instead of passing None into
            # the linear layers, which produces an opaque error inside torch.
            raise ValueError("input_ids is required")

        x = input_ids.float()
        # Flatten token sequence for single-vector representation.
        # reshape (rather than view) also accepts non-contiguous tensors.
        x = x.reshape(x.size(0), -1)

        # Match encoder input dimension
        hidden_dim = self.encoder.in_features
        seq_dim = x.size(1)
        if seq_dim < hidden_dim:
            # Pad with zeros on the right if sequence shorter than hidden size
            x = torch.nn.functional.pad(x, (0, hidden_dim - seq_dim))
        elif seq_dim > hidden_dim:
            # Trim if sequence longer than hidden size
            x = x[:, :hidden_dim]

        # Encoder → Mirror → Prism → Projection
        x = self.encoder(x)
        x = self.mirror_head(x)
        x = self.prism_head(x)
        return self.projection(x)
# Register the custom architecture with the Auto* factories: the "custom_snp"
# model type is mapped to CustomSNPConfig, and CustomSNPConfig is mapped to
# CustomSNPModel, so AutoConfig / AutoModel recognize it.
AutoConfig.register("custom_snp", CustomSNPConfig)
AutoModel.register(CustomSNPConfig, CustomSNPModel)
# ============================================================
# Load Model & Tokenizer
# ============================================================
# Load config, tokenizer and model once at import time. Any failure is
# reported and re-raised, so the service does not start without a model.
try:
    print("Loading model from:", MODEL_DIR)
    config = AutoConfig.from_pretrained(MODEL_DIR, trust_remote_code=True)

    # Try loading tokenizer; fallback if not mapped
    from transformers import RobertaTokenizer

    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    except Exception as tokenizer_error:
        # Best-effort fallback is deliberate, but report why it happened
        # instead of swallowing the cause silently.
        print("⚠️ Falling back to default RoBERTa tokenizer.")
        print("   Reason:", tokenizer_error)
        # Pass the cache directory explicitly: the cache environment variables
        # are assigned after `transformers` is imported, so the library may
        # not pick them up for this download.
        tokenizer = RobertaTokenizer.from_pretrained("roberta-base", cache_dir=CACHE_DIR)

    model = AutoModel.from_pretrained(MODEL_DIR, config=config, trust_remote_code=True)
    model.eval()  # inference only
    print("✅ Custom SNP model loaded successfully.")
except Exception as e:
    print("❌ Error loading custom model:", e)
    # Bare raise re-raises the active exception unchanged.
    raise
# ============================================================
# Flask API Routes
# ============================================================
@app.route("/", methods=["GET"])
def home():
    """Return a JSON banner showing that the API is running."""
    banner = {"status": "SNP Universal Embedding API running"}
    return jsonify(banner)
@app.route("/health", methods=["GET"])
def health():
    """Health-check endpoint: always responds with a JSON "healthy" status."""
    report = {"status": "healthy"}
    return jsonify(report)
@app.route("/embed", methods=["POST"])
def embed():
    """Embed the ``text`` field of the JSON request body with the loaded model.

    Request body: a JSON object with a non-empty ``text`` that is a string
    or a list of strings.

    Returns:
        ``{"embedding": [...]}`` on success, or ``{"error": ...}`` with
        status 400 when the body is not a JSON object or ``text`` is
        missing, empty, or of the wrong type.
    """
    # silent=True turns unparsable bodies into None so every client error
    # is answered with the same JSON 400 shape.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    text = data.get("text", "")
    if not text:
        return jsonify({"error": "Text is required"}), 400
    is_text_batch = isinstance(text, list) and all(isinstance(item, str) for item in text)
    if not (isinstance(text, str) or is_text_batch):
        # Reject e.g. numbers or nested objects before the tokenizer raises.
        return jsonify({"error": "Text must be a string or a list of strings"}), 400

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    # input_ids are passed through unchanged: the custom model casts them to
    # float itself, and a standard model needs integer ids.
    with torch.no_grad():
        embeddings = model(**inputs)

    # Normalize the different output shapes a model may return.
    if hasattr(embeddings, "last_hidden_state"):
        embeddings = embeddings.last_hidden_state.mean(dim=1)
    elif isinstance(embeddings, tuple):
        embeddings = embeddings[0]
    return jsonify({"embedding": embeddings.tolist()})
@app.route("/reason", methods=["POST"])
def reason():
    """Score a premise/hypothesis pair with the loaded model.

    Request body: a JSON object with ``premise`` and/or ``hypothesis``. The
    two values are joined with a single space and run through the model;
    the score is the mean of the model output.

    Returns:
        ``{"reasoning_score": float}`` on success, or ``{"error": ...}``
        with status 400 when the body is not a JSON object or both fields
        are missing or blank.
    """
    # silent=True turns unparsable bodies into None so every client error
    # is answered with the same JSON 400 shape.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    premise = data.get("premise", "")
    hypothesis = data.get("hypothesis", "")
    combined = f"{premise} {hypothesis}"
    if not combined.strip():
        # Nothing to score: both fields were missing or blank.
        return jsonify({"error": "Premise or hypothesis is required"}), 400

    inputs = tokenizer(combined, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        output = model(**inputs)

    # Unwrap the output the same way /embed does before reducing it.
    if hasattr(output, "last_hidden_state"):
        output = output.last_hidden_state
    elif isinstance(output, tuple):
        output = output[0]
    score = float(output.mean().item())
    return jsonify({"reasoning_score": score})
# ============================================================
# Run Server
# ============================================================
if __name__ == "__main__":
    # Announce the port, then serve on all network interfaces.
    startup_message = f"🚀 Starting SNP Universal Embedding API on port {PORT}"
    print(startup_message)
    app.run(port=PORT, host="0.0.0.0")
|