|
|
import os |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from flask import Flask, request, jsonify |
|
|
from transformers import ( |
|
|
AutoTokenizer, |
|
|
AutoModel, |
|
|
AutoConfig, |
|
|
PretrainedConfig, |
|
|
PreTrainedModel, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
# --- Runtime configuration ---------------------------------------------------

# Directory for Hugging Face downloads; created up front so it always exists.
CACHE_DIR = "/app/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# Point the Hugging Face libraries at CACHE_DIR and silence the symlink warning.
# NOTE(review): these are assigned after `transformers` has already been imported
# at the top of the file; huggingface_hub typically resolves HF_HOME at import
# time, so confirm they actually take effect (passing cache_dir= explicitly to
# from_pretrained is the reliable alternative). TRANSFORMERS_CACHE is also
# deprecated in recent transformers releases in favour of HF_HOME.
os.environ.update(
    {
        "HF_HOME": CACHE_DIR,
        "TRANSFORMERS_CACHE": CACHE_DIR,
        "HF_HUB_DISABLE_SYMLINKS_WARNING": "1",
    }
)

# Config, tokenizer and weights are loaded from the current working directory.
MODEL_DIR = "./"
# HTTP port; overridable through the PORT environment variable.
PORT = int(os.environ.get("PORT", 7860))

app = Flask(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CustomSNPConfig(PretrainedConfig):
    """Configuration class for the custom SNP architecture.

    Declares only ``model_type``; everything else is inherited from
    ``PretrainedConfig``. Fields the model reads (e.g. ``hidden_size``) are
    presumably supplied by the loaded config.json -- verify against MODEL_DIR.
    """

    model_type = "custom_snp"
|
|
|
|
|
|
|
|
class CustomSNPModel(PreTrainedModel):
    """Small feed-forward network mapping token ids to a 6-dimensional output.

    Forward pass: token ids are cast to float and flattened per batch row.
    Each row is then zero-padded or truncated to ``hidden_size`` features.
    Finally the rows go through
    ``encoder -> mirror_head (Tanh) -> prism_head (Tanh) -> projection``.
    """

    config_class = CustomSNPConfig

    def __init__(self, config):
        super().__init__(config)
        # Fall back to 768 features when the config does not define hidden_size.
        hidden_size = getattr(config, "hidden_size", 768)

        self.encoder = nn.Linear(hidden_size, hidden_size)
        self.mirror_head = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh())
        self.prism_head = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh())
        self.projection = nn.Linear(hidden_size, 6)

    def _fit_to_encoder_width(self, x):
        """Zero-pad (on the right) or truncate dim 1 of 2-D ``x`` to ``encoder.in_features``."""
        hidden_dim = self.encoder.in_features
        seq_dim = x.size(1)
        if seq_dim < hidden_dim:
            return torch.nn.functional.pad(x, (0, hidden_dim - seq_dim))
        # Slicing is a no-op when the width already matches.
        return x[:, :hidden_dim]

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Compute the 6-dim output for a batch of token ids.

        Args:
            input_ids: Tensor whose dim 0 is the batch; all remaining dims are
                flattened into one feature row per batch element. Integer or
                float dtype (it is cast to float here).
            attention_mask: Accepted so tokenizer output can be splatted in,
                but NOT used. NOTE(review): padded positions therefore
                contribute their pad-token ids as features.
            **kwargs: Ignored (e.g. ``token_type_ids``).

        Returns:
            Tensor of shape ``(batch, 6)``.

        Raises:
            ValueError: if ``input_ids`` is None. Previously this case fell
                through and failed inside ``mirror_head`` with an opaque error.
        """
        if input_ids is None:
            raise ValueError("CustomSNPModel.forward requires `input_ids`.")

        # reshape (not view) so non-contiguous inputs are also accepted.
        x = input_ids.float()
        x = x.reshape(x.size(0), -1)
        x = self._fit_to_encoder_width(x)

        x = self.encoder(x)
        x = self.mirror_head(x)
        x = self.prism_head(x)
        return self.projection(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Register the custom classes with the Auto* factories, presumably so that the
# AutoConfig/AutoModel.from_pretrained calls on MODEL_DIR can map a
# model_type of "custom_snp" to them. model_type is read from the config class
# so the key cannot drift from the class definition.
AutoConfig.register(CustomSNPConfig.model_type, CustomSNPConfig)
AutoModel.register(CustomSNPConfig, CustomSNPModel)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Model / tokenizer loading (runs at import time) -------------------------
try:
    print("Loading model from:", MODEL_DIR)
    config = AutoConfig.from_pretrained(MODEL_DIR, trust_remote_code=True)

    # Imported here because it is only needed for the tokenizer fallback.
    from transformers import RobertaTokenizer

    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    except Exception as tokenizer_error:
        # Deliberate best-effort fallback: keep serving with the stock RoBERTa
        # tokenizer, but report why the bundled tokenizer could not be loaded.
        print(f"⚠️ Falling back to default RoBERTa tokenizer. Reason: {tokenizer_error}")
        # This is a Hub download. HF_HOME / TRANSFORMERS_CACHE are assigned only
        # after `transformers` was imported, and the Hub libraries typically
        # resolve their cache path at import time, so pass the cache dir
        # explicitly to make sure the download lands in CACHE_DIR.
        tokenizer = RobertaTokenizer.from_pretrained("roberta-base", cache_dir=CACHE_DIR)

    # NOTE(review): trust_remote_code=True executes Python shipped alongside the
    # checkpoint in MODEL_DIR -- confirm that directory's contents are trusted.
    model = AutoModel.from_pretrained(MODEL_DIR, config=config, trust_remote_code=True)
    model.eval()  # inference mode
    print("✅ Custom SNP model loaded successfully.")
except Exception as e:
    # Fail fast: log the cause, then re-raise with the original traceback so the
    # module import (and thus the server) aborts instead of running without a model.
    print("❌ Error loading custom model:", e)
    raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.route("/", methods=["GET"])
def home():
    """Root endpoint: confirm that the API process is up and serving."""
    payload = {"status": "SNP Universal Embedding API running"}
    return jsonify(payload)
|
|
|
|
|
|
|
|
@app.route("/health", methods=["GET"])
def health():
    """Liveness probe: always answers with a static 'healthy' status."""
    payload = {"status": "healthy"}
    return jsonify(payload)
|
|
|
|
|
|
|
|
@app.route("/embed", methods=["POST"])
def embed():
    """Return the model output for ``text`` as a JSON list.

    Request body: a JSON object with a ``text`` field. The field is either a
    string or a list of strings; a list is tokenized as a padded batch.

    Returns:
        200 ``{"embedding": [...]}`` on success.
        400 ``{"error": ...}`` when the body is not a JSON object, or ``text``
        is missing, empty, or not a string / list of strings.
    """
    # silent=True: malformed JSON yields None instead of raising, so it takes the
    # same JSON 400 path as any other non-object body. Previously a valid but
    # non-object body (e.g. a list or null) crashed with a 500 on `.get`.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    text = data.get("text", "")
    if not text:
        return jsonify({"error": "Text is required"}), 400
    is_text_list = isinstance(text, list) and all(isinstance(item, str) for item in text)
    if not (isinstance(text, str) or is_text_list):
        return jsonify({"error": "Text must be a string or a list of strings"}), 400

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)

    with torch.no_grad():
        # Ids are cast to float before the call (kept from the original).
        # NOTE(review): /reason passes integer ids; confirm which the loaded
        # model expects if it is not the bundled CustomSNPModel.
        inputs["input_ids"] = inputs["input_ids"].float()
        embeddings = model(**inputs)

    # Normalize the possible output types to a single tensor.
    if hasattr(embeddings, "last_hidden_state"):
        # Average over dim 1 (presumably the token axis).
        # NOTE(review): this ignores attention_mask, so padding is averaged in.
        embeddings = embeddings.last_hidden_state.mean(dim=1)
    elif isinstance(embeddings, tuple):
        embeddings = embeddings[0]

    return jsonify({"embedding": embeddings.tolist()})
|
|
|
|
|
|
|
|
@app.route("/reason", methods=["POST"])
def reason():
    """Return a scalar "reasoning score" for a premise/hypothesis pair.

    Request body: a JSON object with optional string fields ``premise`` and
    ``hypothesis``; at least one of them must be non-empty.

    The score is the mean of every element of the model output.

    Returns:
        200 ``{"reasoning_score": float}`` on success.
        400 ``{"error": ...}`` when the body is not a JSON object, a field is
        not a string, or both fields are empty.
    """
    # silent=True: malformed JSON becomes None and is reported as a 400 below
    # instead of crashing with a 500 on `.get`.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    premise = data.get("premise", "")
    hypothesis = data.get("hypothesis", "")
    if not isinstance(premise, str) or not isinstance(hypothesis, str):
        return jsonify({"error": "Premise and hypothesis must be strings"}), 400
    if not premise and not hypothesis:
        return jsonify({"error": "Premise or hypothesis is required"}), 400

    # Joined with one space rather than encoded as a tokenizer sentence pair.
    # Kept as-is because changing the encoding would change returned scores.
    combined = f"{premise} {hypothesis}"
    inputs = tokenizer(combined, return_tensors="pt", truncation=True, padding=True)

    with torch.no_grad():
        output = model(**inputs)

    # Unwrap the output the same way /embed does. `.mean()` is only defined on
    # a tensor, not on a ModelOutput or tuple.
    if hasattr(output, "last_hidden_state"):
        output = output.last_hidden_state
    elif isinstance(output, tuple):
        output = output[0]

    score = float(output.mean().item())
    return jsonify({"reasoning_score": score})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Flask's built-in development server. It binds every interface (0.0.0.0),
    # presumably so the port is reachable from outside a container.
    print(f"🚀 Starting SNP Universal Embedding API on port {PORT}")
    app.run(host="0.0.0.0", port=PORT)
|
|
|
|
|
|