366degrees committed on
Commit
8c8d036
·
verified ·
1 Parent(s): 66c5911

Upload 4 files

Browse files
Files changed (4) hide show
  1. Dockerfile +28 -0
  2. api_inference.py +132 -0
  3. requirements.txt +7 -0
  4. snp_universal_embedding.py +62 -0
Dockerfile ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Use lightweight Python base
FROM python:3.10-slim

# Set working directory
WORKDIR /app

# Install dependencies first: this layer is then reused from the build cache
# whenever only application code (and not requirements.txt) changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Create the non-root runtime user and a cache directory that it owns,
# rather than making the cache world-writable.
RUN useradd -m -u 1000 appuser \
    && mkdir -p /app/hf_cache \
    && chown -R appuser:appuser /app/hf_cache

# Set environment variables for Hugging Face cache
ENV HF_HOME=/app/hf_cache
ENV TRANSFORMERS_CACHE=/app/hf_cache

# Copy application files (after the dependency layer, owned by the runtime user)
COPY --chown=appuser:appuser . .

# Expose Space port
EXPOSE 7860

# Switch to non-root user
USER appuser

# Run Flask directly (no Gunicorn)
CMD ["python", "api_inference.py"]
api_inference.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ from flask import Flask, request, jsonify
5
+ from transformers import (
6
+ AutoTokenizer,
7
+ AutoModel,
8
+ AutoConfig,
9
+ PretrainedConfig,
10
+ PreTrainedModel,
11
+ )
12
+
# ============================================================
# Redirect Hugging Face cache to /app/hf_cache (always writable)
# NOTE(review): transformers is imported before these variables are set —
# confirm the library still honors them when they are assigned after import.
CACHE_DIR = "/app/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
os.environ.update(
    {
        "HF_HOME": CACHE_DIR,
        "TRANSFORMERS_CACHE": CACHE_DIR,
        "HF_HUB_DISABLE_SYMLINKS_WARNING": "1",
    }
)

# Directory the model files are loaded from (relative to the working directory).
MODEL_DIR = "./"
# Listening port; the PORT environment variable overrides the 7860 default.
PORT = int(os.environ.get("PORT", 7860))

# Flask application object that the route decorators attach to.
app = Flask(__name__)
25
+
26
+
27
+ # ============================================================
28
+ # Register Custom SNP Architecture
29
+ # ============================================================
class CustomSNPConfig(PretrainedConfig):
    """Configuration for the custom SNP architecture.

    Adds nothing to ``PretrainedConfig`` beyond the ``model_type`` tag.
    """

    model_type = "custom_snp"
32
+
33
+
class CustomSNPModel(PreTrainedModel):
    """Stack of linear layers mapping a ``hidden_size`` vector to 6 outputs.

    Layers are applied in order: ``encoder`` -> ``mirror_head`` ->
    ``prism_head`` -> ``projection``.
    """

    config_class = CustomSNPConfig

    def __init__(self, config):
        """Build the layers; ``hidden_size`` falls back to 768 if the config lacks it."""
        super().__init__(config)
        hidden_size = getattr(config, "hidden_size", 768)
        # Mirror and Prism heads
        self.encoder = nn.Linear(hidden_size, hidden_size)
        self.mirror_head = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh())
        self.prism_head = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh())
        self.projection = nn.Linear(hidden_size, 6)

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Run ``input_ids`` (cast to float) through the layer stack.

        Args:
            input_ids: Tensor whose last dimension equals the encoder's
                ``in_features``. Required despite the ``None`` default.
            attention_mask: Accepted for call compatibility; not used.
            **kwargs: Accepted for call compatibility; not used.

        Returns:
            Tensor produced by ``projection`` (last dimension 6).

        Raises:
            ValueError: If ``input_ids`` is missing or its last dimension
                does not match the encoder's input width.
        """
        if input_ids is None:
            # Previously None was passed straight into mirror_head and failed
            # with an unrelated-looking TypeError.
            raise ValueError("input_ids is required")

        expected = self.encoder.in_features
        if input_ids.dim() == 0 or input_ids.shape[-1] != expected:
            # nn.Linear would otherwise fail with an opaque matmul shape error.
            raise ValueError(
                f"input_ids last dimension must be {expected}, "
                f"got shape {tuple(input_ids.shape)}"
            )

        # Simulate encoded representations
        x = self.encoder(input_ids.float())
        x = self.mirror_head(x)
        x = self.prism_head(x)
        return self.projection(x)
52
+
53
+
# Register model so AutoModel recognizes it
AutoConfig.register("custom_snp", CustomSNPConfig)
AutoModel.register(CustomSNPConfig, CustomSNPModel)


# ============================================================
# Load Model & Tokenizer
# ============================================================
# Runs at import time: `tokenizer` and `model` become module globals used by
# the route handlers. Any failure is logged and re-raised so the process does
# not start without a model.
try:
    print("Loading model from:", MODEL_DIR)
    config = AutoConfig.from_pretrained(MODEL_DIR, trust_remote_code=True)

    # Try loading tokenizer; fallback if not mapped
    from transformers import RobertaTokenizer
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    except Exception as tokenizer_error:
        # Report why the local tokenizer could not be used instead of
        # discarding the error silently.
        print("⚠️ Falling back to default RoBERTa tokenizer.")
        print("Tokenizer load error:", tokenizer_error)
        # Explicit cache_dir so the download lands in CACHE_DIR regardless of
        # when the cache environment variables were read.
        tokenizer = RobertaTokenizer.from_pretrained("roberta-base", cache_dir=CACHE_DIR)

    model = AutoModel.from_pretrained(MODEL_DIR, config=config, trust_remote_code=True)
    model.eval()
    print("✅ Custom SNP model loaded successfully.")

except Exception as e:
    print("❌ Error loading custom model:", e)
    # Bare raise re-raises the active exception with its original traceback.
    raise
81
+
82
+
83
+ # ============================================================
84
+ # Flask API Routes
85
+ # ============================================================
@app.route("/", methods=["GET"])
def home():
    """Return a JSON status banner for the root URL."""
    banner = {"status": "SNP Universal Embedding API running"}
    return jsonify(banner)
89
+
90
+
@app.route("/health", methods=["GET"])
def health():
    """Return a fixed JSON payload reporting the service as healthy."""
    payload = {"status": "healthy"}
    return jsonify(payload)
94
+
95
+
@app.route("/embed", methods=["POST"])
def embed():
    """Embed the ``text`` field of the JSON request body.

    Returns:
        JSON ``{"embedding": [...]}`` on success, or ``{"error": ...}`` with
        status 400 when the body is not a JSON object or ``text`` is missing,
        empty, or neither a string nor a list of strings.
    """
    # silent=True yields None for malformed JSON, so it receives the same JSON
    # error response as every other bad request.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    text = data.get("text", "")
    if not text:
        return jsonify({"error": "Text is required"}), 400
    is_batch = isinstance(text, list) and all(isinstance(item, str) for item in text)
    if not (isinstance(text, str) or is_batch):
        return jsonify({"error": "Text must be a string or a list of strings"}), 400

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        embeddings = model(**inputs)
        # Reduce whatever the model returned to a plain tensor: mean over
        # dim 1 of last_hidden_state if present, else the first tuple item.
        if hasattr(embeddings, "last_hidden_state"):
            embeddings = embeddings.last_hidden_state.mean(dim=1)
        elif isinstance(embeddings, tuple):
            embeddings = embeddings[0]
    return jsonify({"embedding": embeddings.tolist()})
111
+
112
+
@app.route("/reason", methods=["POST"])
def reason():
    """Score a premise/hypothesis pair with the loaded model.

    The two strings are joined with a single space, tokenized, and passed
    through the model; the score is the mean of the model's output tensor.

    Returns:
        JSON ``{"reasoning_score": float}`` on success, or ``{"error": ...}``
        with status 400 when the body is not a JSON object, a field is not a
        string, or both fields are blank.
    """
    # silent=True yields None for malformed JSON so it is rejected below with
    # a JSON error instead of an HTML error page.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    premise = data.get("premise", "")
    hypothesis = data.get("hypothesis", "")
    if not isinstance(premise, str) or not isinstance(hypothesis, str):
        return jsonify({"error": "premise and hypothesis must be strings"}), 400
    if not (premise.strip() or hypothesis.strip()):
        return jsonify({"error": "premise or hypothesis is required"}), 400

    combined = f"{premise} {hypothesis}"
    inputs = tokenizer(combined, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        output = model(**inputs)
        # Unwrap structured or tuple outputs to a tensor before averaging;
        # a bare tensor passes through unchanged.
        if hasattr(output, "last_hidden_state"):
            output = output.last_hidden_state
        elif isinstance(output, tuple):
            output = output[0]
        score = float(output.mean().item())
    return jsonify({"reasoning_score": score})
124
+
125
+
# ============================================================
# Run Server
# ============================================================
# Only start the built-in Flask server when executed as a script.
if __name__ == "__main__":
    startup_banner = f"🚀 Starting SNP Universal Embedding API on port {PORT}"
    print(startup_banner)
    # Listen on every interface at the configured port.
    app.run(host="0.0.0.0", port=PORT)
132
+
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # requirements.txt
2
+ torch
3
+ transformers
4
+ sentence-transformers
5
+ flask
6
+ numpy
7
+ scikit-learn
snp_universal_embedding.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import AutoModel, AutoTokenizer
4
+ import os, json
5
+
# Startup banner: confirms the imports above succeeded and reports the
# installed torch version.
print("✅ Environment ready")
print(f"Torch: {torch.__version__}")
8
+
9
+ # ============================================================
10
+ # Custom SNP Model Architecture
11
+ # ============================================================
class CustomSNPModel(nn.Module):
    """Pretrained encoder followed by a 6-way linear projection.

    ``forward`` takes the encoder's hidden state at sequence position 0 and
    feeds it to ``projection``.
    """

    def __init__(self, base_model="bert-base-uncased"):
        """Load ``base_model`` via ``AutoModel`` and build the heads.

        Args:
            base_model: Model name or path handed to
                ``AutoModel.from_pretrained``.
        """
        super().__init__()
        self.shared_encoder = AutoModel.from_pretrained(base_model)
        hidden_size = self.shared_encoder.config.hidden_size
        # NOTE(review): mirror_head and prism_head are created here but
        # forward() never calls them — confirm whether that is intended.
        # They are kept so the module's parameter names stay unchanged.
        self.mirror_head = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh())
        self.prism_head = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh())
        self.projection = nn.Linear(hidden_size, 6)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Encode the inputs and project the first-position hidden state.

        Args:
            input_ids: Token id tensor passed to the encoder.
            attention_mask: Optional mask passed to the encoder.
            token_type_ids: Optional segment ids; forwarded only when given.

        Returns:
            Output of ``projection`` (last dimension 6).
        """
        encoder_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
        # Only forward token_type_ids when supplied: not every encoder's
        # forward() accepts that keyword, and base_model is configurable.
        if token_type_ids is not None:
            encoder_kwargs["token_type_ids"] = token_type_ids
        outputs = self.shared_encoder(**encoder_kwargs)
        cls = outputs.last_hidden_state[:, 0, :]
        return self.projection(cls)
30
+
print("✅ SNP architecture defined.")

# ============================================================
# Load Checkpoint (optional; comment out if not available)
# ============================================================
ckpt_path = "pytorch_model.bin"
if os.path.exists(ckpt_path):
    print(f"Loading weights from {ckpt_path}")
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint file cannot execute arbitrary code on load.
    state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    # Drop "module." from parameter names (presumably a DataParallel-style
    # wrapper prefix — confirm against how the checkpoint was saved).
    clean_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}

    model = CustomSNPModel(base_model="bert-base-uncased")
    # strict=False tolerates mismatches; report them so a checkpoint that
    # matched few or no parameters does not pass unnoticed.
    load_result = model.load_state_dict(clean_state_dict, strict=False)
    if load_result.missing_keys:
        print("⚠️ Missing keys:", load_result.missing_keys)
    if load_result.unexpected_keys:
        print("⚠️ Unexpected keys:", load_result.unexpected_keys)
    print("✅ Checkpoint loaded successfully.")
else:
    print("⚠️ No checkpoint found, initializing new model.")
    model = CustomSNPModel()

# Switch to inference mode so dropout is disabled for the example below.
model.eval()

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# ============================================================
# Example Inference
# ============================================================
text = "A student must decide between a scholarship and their family."
inputs = tokenizer(text, return_tensors="pt")
# The example runs without segment ids; drop them if the tokenizer added any.
inputs.pop("token_type_ids", None)

with torch.no_grad():
    output = model(**inputs)

print("✅ Embedding generated successfully.")
print("Embedding shape:", output.shape if hasattr(output, "shape") else type(output))