PunchNFIT commited on
Commit
2127a29
·
1 Parent(s): 5c20eb3

Fix tokenizer mapping for CustomSNPConfig

Browse files
Files changed (1) hide show
  1. api_inference.py +41 -47
api_inference.py CHANGED
@@ -3,15 +3,27 @@ import torch
3
  import torch.nn as nn
4
  from flask import Flask, request, jsonify
5
  from transformers import (
6
- AutoConfig,
7
  AutoModel,
 
8
  PretrainedConfig,
9
  PreTrainedModel,
10
  )
11
- from transformers import RobertaTokenizerFast as RobertaTokenizer
12
 
13
  # ============================================================
14
- # Custom SNP Architecture (no Gunicorn complications)
 
 
 
 
 
 
 
 
 
 
 
 
15
  # ============================================================
16
  class CustomSNPConfig(PretrainedConfig):
17
  model_type = "custom_snp"
@@ -22,72 +34,70 @@ class CustomSNPModel(PreTrainedModel):
22
 
23
  def __init__(self, config):
24
  super().__init__(config)
25
- hidden = getattr(config, "hidden_size", 768)
26
- self.encoder = nn.Linear(hidden, hidden)
27
- self.mirror_head = nn.Sequential(nn.Linear(hidden, hidden), nn.Tanh())
28
- self.prism_head = nn.Sequential(nn.Linear(hidden, hidden), nn.Tanh())
29
- self.projection = nn.Linear(hidden, 6)
 
30
 
31
  def forward(self, input_ids=None, attention_mask=None, **kwargs):
32
- if input_ids is None:
33
- raise ValueError("input_ids required")
34
- x = self.encoder(input_ids.float())
35
  x = self.mirror_head(x)
36
  x = self.prism_head(x)
37
  return self.projection(x)
38
 
39
 
40
- # ============================================================
41
- # Environment
42
- # ============================================================
43
- os.environ["HF_HOME"] = "/tmp/huggingface"
44
- MODEL_DIR = "./"
45
- PORT = int(os.environ.get("PORT", 7860))
46
- app = Flask(__name__)
47
 
48
  # ============================================================
49
- # Load Model & Tokenizer (direct tokenizer, no AutoTokenizer)
50
  # ============================================================
51
  try:
52
  print("Loading model from:", MODEL_DIR)
53
  config = AutoConfig.from_pretrained(MODEL_DIR, trust_remote_code=True)
54
 
55
- # Use concrete tokenizer to avoid mapping issues
 
56
  try:
57
- tokenizer = RobertaTokenizer.from_pretrained(MODEL_DIR)
58
- print("✅ Loaded tokenizer from model directory.")
59
  except Exception:
60
- print("⚠️ Falling back to default roberta-base tokenizer.")
61
  tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
62
 
63
- model = CustomSNPModel(config)
64
- if os.path.exists(os.path.join(MODEL_DIR, "pytorch_model.bin")):
65
- state = torch.load(os.path.join(MODEL_DIR, "pytorch_model.bin"), map_location="cpu")
66
- model.load_state_dict(state, strict=False)
67
  model.eval()
68
  print("✅ Custom SNP model loaded successfully.")
 
69
  except Exception as e:
70
  print("❌ Error loading custom model:", e)
71
  raise e
72
 
73
 
74
  # ============================================================
75
- # Routes
76
  # ============================================================
77
  @app.route("/", methods=["GET"])
78
  def home():
79
  return jsonify({"status": "SNP Universal Embedding API running"})
80
 
 
81
  @app.route("/health", methods=["GET"])
82
  def health():
83
  return jsonify({"status": "healthy"})
84
 
 
85
  @app.route("/embed", methods=["POST"])
86
  def embed():
87
  data = request.get_json(force=True)
88
  text = data.get("text", "")
89
  if not text:
90
  return jsonify({"error": "Text is required"}), 400
 
91
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
92
  with torch.no_grad():
93
  embeddings = model(**inputs)
@@ -97,6 +107,7 @@ def embed():
97
  embeddings = embeddings[0]
98
  return jsonify({"embedding": embeddings.tolist()})
99
 
 
100
  @app.route("/reason", methods=["POST"])
101
  def reason():
102
  data = request.get_json(force=True)
@@ -109,28 +120,11 @@ def reason():
109
  score = float(output.mean().item())
110
  return jsonify({"reasoning_score": score})
111
 
112
- @app.route("/test", methods=["GET"])
113
- def test():
114
- sample_text = "She knows he cheats but stays anyway."
115
- inputs = tokenizer(sample_text, return_tensors="pt")
116
- with torch.no_grad():
117
- output = model(**inputs)
118
- if hasattr(output, "last_hidden_state"):
119
- vector = output.last_hidden_state.mean(dim=1).tolist()
120
- elif isinstance(output, tuple):
121
- vector = output[0].tolist()
122
- else:
123
- vector = output.tolist()
124
- return jsonify({
125
- "message": "SNP Universal Embedding model is active.",
126
- "sample_text": sample_text,
127
- "embedding_preview": vector[0][:6]
128
- })
129
-
130
 
131
  # ============================================================
132
- # Run Flask directly (no Gunicorn)
133
  # ============================================================
134
  if __name__ == "__main__":
135
  print(f"🚀 Starting SNP Universal Embedding API on port {PORT}")
136
  app.run(host="0.0.0.0", port=PORT)
 
 
3
  import torch.nn as nn
4
  from flask import Flask, request, jsonify
5
  from transformers import (
6
+ AutoTokenizer,
7
  AutoModel,
8
+ AutoConfig,
9
  PretrainedConfig,
10
  PreTrainedModel,
11
  )
 
12
 
13
  # ============================================================
14
+ # Environment Configuration
15
+ # ============================================================
16
+ os.environ["HF_HOME"] = "/tmp/huggingface"
17
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
18
+
19
+ MODEL_DIR = "./"
20
+ PORT = int(os.environ.get("PORT", 7860))
21
+
22
+ app = Flask(__name__)
23
+
24
+
25
+ # ============================================================
26
+ # Register Custom SNP Architecture
27
  # ============================================================
28
  class CustomSNPConfig(PretrainedConfig):
29
  model_type = "custom_snp"
 
34
 
35
  def __init__(self, config):
36
  super().__init__(config)
37
+ hidden_size = getattr(config, "hidden_size", 768)
38
+ # Mirror and Prism heads
39
+ self.encoder = nn.Linear(hidden_size, hidden_size)
40
+ self.mirror_head = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh())
41
+ self.prism_head = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh())
42
+ self.projection = nn.Linear(hidden_size, 6)
43
 
44
  def forward(self, input_ids=None, attention_mask=None, **kwargs):
45
+ # Simulate encoded representations
46
+ x = self.encoder(input_ids.float()) if input_ids is not None else None
 
47
  x = self.mirror_head(x)
48
  x = self.prism_head(x)
49
  return self.projection(x)
50
 
51
 
52
+ # Register model so AutoModel recognizes it
53
+ AutoConfig.register("custom_snp", CustomSNPConfig)
54
+ AutoModel.register(CustomSNPConfig, CustomSNPModel)
55
+
 
 
 
56
 
57
  # ============================================================
58
+ # Load Model & Tokenizer
59
  # ============================================================
60
  try:
61
  print("Loading model from:", MODEL_DIR)
62
  config = AutoConfig.from_pretrained(MODEL_DIR, trust_remote_code=True)
63
 
64
+ # Try loading tokenizer; fallback if not mapped
65
+ from transformers import RobertaTokenizer
66
  try:
67
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
 
68
  except Exception:
69
+ print("⚠️ Falling back to default RoBERTa tokenizer.")
70
  tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
71
 
72
+ model = AutoModel.from_pretrained(MODEL_DIR, config=config, trust_remote_code=True)
 
 
 
73
  model.eval()
74
  print("✅ Custom SNP model loaded successfully.")
75
+
76
  except Exception as e:
77
  print("❌ Error loading custom model:", e)
78
  raise e
79
 
80
 
81
  # ============================================================
82
+ # Flask API Routes
83
  # ============================================================
84
  @app.route("/", methods=["GET"])
85
  def home():
86
  return jsonify({"status": "SNP Universal Embedding API running"})
87
 
88
+
89
  @app.route("/health", methods=["GET"])
90
  def health():
91
  return jsonify({"status": "healthy"})
92
 
93
+
94
  @app.route("/embed", methods=["POST"])
95
  def embed():
96
  data = request.get_json(force=True)
97
  text = data.get("text", "")
98
  if not text:
99
  return jsonify({"error": "Text is required"}), 400
100
+
101
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
102
  with torch.no_grad():
103
  embeddings = model(**inputs)
 
107
  embeddings = embeddings[0]
108
  return jsonify({"embedding": embeddings.tolist()})
109
 
110
+
111
  @app.route("/reason", methods=["POST"])
112
  def reason():
113
  data = request.get_json(force=True)
 
120
  score = float(output.mean().item())
121
  return jsonify({"reasoning_score": score})
122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
  # ============================================================
125
+ # Run Server
126
  # ============================================================
127
  if __name__ == "__main__":
128
  print(f"🚀 Starting SNP Universal Embedding API on port {PORT}")
129
  app.run(host="0.0.0.0", port=PORT)
130
+