366degrees commited on
Commit
d2d2c63
·
verified ·
1 Parent(s): 3f56ea7

Update api_inference.py

Browse files
Files changed (1) hide show
  1. api_inference.py +73 -49
api_inference.py CHANGED
@@ -11,7 +11,8 @@ from transformers import (
11
  )
12
 
13
  # ============================================================
14
- # Redirect Hugging Face cache to /app/hf_cache (always writable)
 
15
  CACHE_DIR = "/app/hf_cache"
16
  os.makedirs(CACHE_DIR, exist_ok=True)
17
  os.environ["HF_HOME"] = CACHE_DIR
@@ -23,33 +24,44 @@ PORT = int(os.environ.get("PORT", 7860))
23
 
24
  app = Flask(__name__)
25
 
26
-
27
  # ============================================================
28
- # Register Custom SNP Architecture
29
  # ============================================================
30
- class CustomSNPConfig(PretrainedConfig):
 
31
  model_type = "custom_snp"
32
 
33
-
34
  class CustomSNPModel(PreTrainedModel):
35
  config_class = CustomSNPConfig
36
 
37
  def __init__(self, config):
38
  super().__init__(config)
39
- hidden_size = getattr(config, "hidden_size", 768)
40
- # Mirror and Prism heads
41
- self.encoder = nn.Linear(hidden_size, hidden_size)
 
 
 
42
  self.mirror_head = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh())
43
  self.prism_head = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh())
44
  self.projection = nn.Linear(hidden_size, 6)
45
 
46
- def forward(self, input_ids=None, attention_mask=None, **kwargs):
47
- # Simulate encoded representations
48
- x = self.encoder(input_ids.float()) if input_ids is not None else None
49
- x = self.mirror_head(x)
 
 
 
 
 
 
 
 
 
50
  x = self.prism_head(x)
51
- return self.projection(x)
52
-
53
 
54
  # Register model so AutoModel recognizes it
55
  AutoConfig.register("custom_snp", CustomSNPConfig)
@@ -61,27 +73,23 @@ AutoModel.register(CustomSNPConfig, CustomSNPModel)
61
  # ============================================================
62
  try:
63
  print("Loading model from:", MODEL_DIR)
 
 
64
  config = AutoConfig.from_pretrained(MODEL_DIR, trust_remote_code=True)
65
-
66
- # Try loading tokenizer; fallback if not mapped
67
- from transformers import RobertaTokenizer
68
- try:
69
- tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
70
- except Exception:
71
- print("⚠️ Falling back to default RoBERTa tokenizer.")
72
- tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
73
-
74
  model = AutoModel.from_pretrained(MODEL_DIR, config=config, trust_remote_code=True)
 
75
  model.eval()
76
  print("✅ Custom SNP model loaded successfully.")
77
 
78
  except Exception as e:
79
- print("❌ Error loading custom model:", e)
 
80
  raise e
81
 
82
 
83
  # ============================================================
84
- # Flask API Routes
85
  # ============================================================
86
  @app.route("/", methods=["GET"])
87
  def home():
@@ -95,32 +103,49 @@ def health():
95
 
96
  @app.route("/embed", methods=["POST"])
97
  def embed():
98
- data = request.get_json(force=True)
99
- text = data.get("text", "")
100
- if not text:
101
- return jsonify({"error": "Text is required"}), 400
102
-
103
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
104
- with torch.no_grad():
105
- embeddings = model(**inputs)
106
- if hasattr(embeddings, "last_hidden_state"):
107
- embeddings = embeddings.last_hidden_state.mean(dim=1)
108
- elif isinstance(embeddings, tuple):
109
- embeddings = embeddings[0]
110
- return jsonify({"embedding": embeddings.tolist()})
 
 
 
 
 
 
111
 
112
 
113
  @app.route("/reason", methods=["POST"])
114
  def reason():
115
- data = request.get_json(force=True)
116
- premise = data.get("premise", "")
117
- hypothesis = data.get("hypothesis", "")
118
- combined = f"{premise} {hypothesis}"
119
- inputs = tokenizer(combined, return_tensors="pt", truncation=True, padding=True)
120
- with torch.no_grad():
121
- output = model(**inputs)
122
- score = float(output.mean().item())
123
- return jsonify({"reasoning_score": score})
 
 
 
 
 
 
 
 
 
 
 
124
 
125
 
126
  # ============================================================
@@ -128,5 +153,4 @@ def reason():
128
  # ============================================================
129
  if __name__ == "__main__":
130
  print(f"🚀 Starting SNP Universal Embedding API on port {PORT}")
131
- app.run(host="0.0.0.0", port=PORT)
132
-
 
11
  )
12
 
13
  # ============================================================
14
+ # Cache and Port Configuration
15
+ # ============================================================
16
  CACHE_DIR = "/app/hf_cache"
17
  os.makedirs(CACHE_DIR, exist_ok=True)
18
  os.environ["HF_HOME"] = CACHE_DIR
 
24
 
25
  app = Flask(__name__)
26
 
 
27
  # ============================================================
28
+ # Register Custom SNP Architecture (THE FIX IS HERE)
29
  # ============================================================
30
+ class CustomSNPConfig(AutoConfig):
31
+ # This will correctly inherit 'custom_snp' from your config.json
32
  model_type = "custom_snp"
33
 
 
34
  class CustomSNPModel(PreTrainedModel):
35
  config_class = CustomSNPConfig
36
 
37
  def __init__(self, config):
38
  super().__init__(config)
39
+ # This is the correct way to load the base transformer
40
+ self.shared_encoder = AutoModel.from_config(config)
41
+
42
+ hidden_size = self.shared_encoder.config.hidden_size
43
+
44
+ # Your custom heads
45
  self.mirror_head = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh())
46
  self.prism_head = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh())
47
  self.projection = nn.Linear(hidden_size, 6)
48
 
49
+ def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, **kwargs):
50
+ # Pass inputs through the transformer
51
+ outputs = self.shared_encoder(
52
+ input_ids=input_ids,
53
+ attention_mask=attention_mask,
54
+ token_type_ids=token_type_ids
55
+ )
56
+
57
+ # Get the [CLS] token embedding
58
+ cls_embedding = outputs.last_hidden_state[:, 0, :]
59
+
60
+ # Pass through your custom heads
61
+ x = self.mirror_head(cls_embedding)
62
  x = self.prism_head(x)
63
+ proj = self.projection(x)
64
+ return proj # Return the final projection
65
 
66
  # Register model so AutoModel recognizes it
67
  AutoConfig.register("custom_snp", CustomSNPConfig)
 
73
  # ============================================================
74
  try:
75
  print("Loading model from:", MODEL_DIR)
76
+
77
+ # trust_remote_code=True is essential for this to work
78
  config = AutoConfig.from_pretrained(MODEL_DIR, trust_remote_code=True)
79
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True)
 
 
 
 
 
 
 
 
80
  model = AutoModel.from_pretrained(MODEL_DIR, config=config, trust_remote_code=True)
81
+
82
  model.eval()
83
  print("✅ Custom SNP model loaded successfully.")
84
 
85
  except Exception as e:
86
+ print(f"❌ Error loading custom model: {e}")
87
+ # This will print the detailed error to your Space logs
88
  raise e
89
 
90
 
91
  # ============================================================
92
+ # Flask API Routes (Your routes are correct)
93
  # ============================================================
94
  @app.route("/", methods=["GET"])
95
  def home():
 
103
 
104
  @app.route("/embed", methods=["POST"])
105
  def embed():
106
+ try:
107
+ data = request.get_json(force=True)
108
+ text = data.get("text", "")
109
+ if not text:
110
+ return jsonify({"error": "Text is required"}), 400
111
+
112
+ # Tokenize the text
113
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
114
+
115
+ # Run inference
116
+ with torch.no_grad():
117
+ embeddings = model(**inputs)
118
+
119
+ # The model's forward() method now directly returns the projection
120
+ return jsonify({"embedding": embeddings.tolist()})
121
+
122
+ except Exception as e:
123
+ print(f"ERROR in /embed: {e}")
124
+ return jsonify({"error": "Internal Server Error", "message": str(e)}), 500
125
 
126
 
127
  @app.route("/reason", methods=["POST"])
128
  def reason():
129
+ try:
130
+ data = request.get_json(force=True)
131
+ premise = data.get("premise", "")
132
+ hypothesis = data.get("hypothesis", "")
133
+ combined = f"{premise} {hypothesis}"
134
+
135
+ # Tokenize
136
+ inputs = tokenizer(combined, return_tensors="pt", truncation=True, padding=True)
137
+
138
+ # Run inference
139
+ with torch.no_grad():
140
+ output = model(**inputs)
141
+
142
+ # Calculate a score (e.g., mean of the projection)
143
+ score = float(output.mean().item())
144
+ return jsonify({"reasoning_score": score})
145
+
146
+ except Exception as e:
147
+ print(f"ERROR in /reason: {e}")
148
+ return jsonify({"error": "Internal Server Error", "message": str(e)}), 500
149
 
150
 
151
  # ============================================================
 
153
  # ============================================================
154
  if __name__ == "__main__":
155
  print(f"🚀 Starting SNP Universal Embedding API on port {PORT}")
156
+ app.run(host="0.0.0.0", port=PORT)