366degrees commited on
Commit
0527b41
·
verified ·
1 Parent(s): 3795fda

Delete api_inference.py

Browse files
Files changed (1) hide show
  1. api_inference.py +0 -84
api_inference.py DELETED
@@ -1,84 +0,0 @@
1
- # api_inference.py
2
- import torch
3
- from transformers import AutoTokenizer, AutoModel, AutoConfig
4
- from flask import Flask, request, jsonify
5
- import os, json
6
-
7
- app = Flask(__name__)
8
-
9
- # === Load Model ===
10
- MODEL_DIR = os.path.dirname(os.path.abspath(__file__))
11
-
12
- print(f"🔍 Loading model from {MODEL_DIR} ...")
13
-
14
- try:
15
- # --- Register your custom model class ---
16
- from transformers.models.auto.modeling_auto import MODEL_MAPPING
17
- from snp_universal_embedding import CustomSNPModel
18
-
19
- # Register custom class to handle 'custom_snp' type
20
- class DummyConfig(AutoConfig):
21
- model_type = "custom_snp"
22
-
23
- MODEL_MAPPING.register(DummyConfig, CustomSNPModel)
24
-
25
- # Load model and tokenizer
26
- config = AutoConfig.from_pretrained(MODEL_DIR, trust_remote_code=True)
27
- tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
28
- model = AutoModel.from_pretrained(MODEL_DIR, config=config, trust_remote_code=True)
29
- model.eval()
30
-
31
- print("✅ Custom SNP model loaded successfully.")
32
- except Exception as e:
33
- print("❌ Error loading custom model:", e)
34
- raise e
35
-
36
-
37
- # === Define Endpoints ===
38
- @app.route("/")
39
- def index():
40
- return jsonify({
41
- "status": "SNP Universal Embedding API running",
42
- "endpoints": ["/embed", "/reason"]
43
- })
44
-
45
-
46
- @app.route("/embed", methods=["POST"])
47
- def embed():
48
- try:
49
- data = request.get_json()
50
- text = data.get("text", "")
51
- if not text:
52
- return jsonify({"error": "No text provided"}), 400
53
-
54
- inputs = tokenizer(text, return_tensors="pt")
55
- with torch.no_grad():
56
- outputs = model(**inputs)
57
- if isinstance(outputs, dict) and "last_hidden_state" in outputs:
58
- embedding = outputs["last_hidden_state"].mean(dim=1).squeeze().tolist()
59
- else:
60
- embedding = outputs.mean(dim=1).squeeze().tolist()
61
-
62
- return jsonify({"embedding": embedding})
63
- except Exception as e:
64
- return jsonify({"error": str(e)}), 500
65
-
66
-
67
- @app.route("/health")
68
- def health():
69
- return "ok", 200
70
-
71
-
72
- @app.route("/reason", methods=["POST"])
73
- def reason():
74
- data = request.get_json()
75
- text = data.get("text", "")
76
- return jsonify({
77
- "text": text,
78
- "reasoning_status": "Feature in development for SNP reasoning structure"
79
- })
80
-
81
-
82
- if __name__ == "__main__":
83
- port = int(os.environ.get("PORT", 8080))
84
- app.run(host="0.0.0.0", port=port)