File size: 1,995 Bytes
34367da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import sys
import json
import torch
import os
from sentence_transformers import SentenceTransformer

def log(msg):
    """Emit a diagnostic line on stderr, tagged with the bridge prefix.

    The text is written as ``"[GPU-Bridge] <msg>"`` plus a newline and the
    stream is flushed right away. stderr is used because stdout carries the
    JSON messages that ``main`` prints.
    """
    stream = sys.stderr
    line = f"[GPU-Bridge] {msg}\n"
    stream.write(line)
    stream.flush()

def _handle_request(model, payload):
    """Compute the response object for one decoded request.

    Args:
        model: The loaded embedding model; only its ``encode`` method is used.
        payload: The value produced by ``json.loads`` on one stdin line.

    Returns:
        A JSON-serializable dict: ``{"embedding": ...}`` for a ``text``
        request, ``{"embeddings": ...}`` for a ``texts`` request,
        ``{"pong": True}`` for a ``ping`` request, or ``{"error": ...}`` when
        the payload is not a JSON object or contains none of those keys.
        Keys are checked in that order, so the first match wins.

    Exceptions raised by ``model.encode`` propagate to the caller.
    """
    # Without this check a JSON string would pass the ``'text' in payload``
    # substring test and then fail with a confusing indexing error.
    if not isinstance(payload, dict):
        return {"error": "Request must be a JSON object"}
    if 'text' in payload:
        # Single embedding
        vector = model.encode(payload['text'], convert_to_numpy=True)
        return {"embedding": vector.tolist()}
    if 'texts' in payload:
        # Batch embedding
        vectors = model.encode(payload['texts'], convert_to_numpy=True)
        return {"embeddings": vectors.tolist()}
    if 'ping' in payload:
        return {"pong": True}
    # Answer every request line. The caller presumably reads one reply per
    # request (confirm), so silently ignoring an unknown request would stall it.
    return {"error": "Unknown request: expected 'text', 'texts', or 'ping'"}


def main():
    """Run the embedding bridge: load a model, then serve requests over stdio.

    Protocol (one JSON document per line):
      * Once the model is loaded, stdout receives
        ``{"status": "ready", "device": ...}``.
      * After that, every non-blank stdin line gets exactly one response line
        on stdout (see ``_handle_request``). A failing request is answered
        with ``{"error": ...}`` and the loop continues.
      * Diagnostics go to stderr via ``log``.
      * The model name comes from the ``EMBEDDING_MODEL`` environment
        variable and defaults to ``sentence-transformers/all-MiniLM-L6-v2``.

    Exits with status 1 if startup fails, or if the loop itself fails
    (e.g. reading stdin or writing stdout raises).
    """
    try:
        # Use CUDA when torch reports it available, otherwise the CPU.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        log(f"Initializing on device: {device}")

        model_name = os.environ.get('EMBEDDING_MODEL', 'sentence-transformers/all-MiniLM-L6-v2')
        log(f"Loading model: {model_name}...")
        model = SentenceTransformer(model_name, device=device)
        log("✅ Model loaded successfully.")

        # Signal readiness
        print(json.dumps({"status": "ready", "device": device}))
        sys.stdout.flush()
    except Exception as e:
        log(f"Fatal startup error: {str(e)}")
        sys.exit(1)

    try:
        # Processing Loop
        for line in sys.stdin:
            if not line.strip():
                continue
            try:
                response = _handle_request(model, json.loads(line))
                if "error" in response:
                    log(f"Error processing request: {response['error']}")
                out = json.dumps(response)
            except Exception as e:
                log(f"Error processing request: {str(e)}")
                out = json.dumps({"error": str(e)})
            # Write outside the per-request try. A broken stdout cannot be
            # reported on stdout, so it falls through to the fatal handler below.
            print(out)
            sys.stdout.flush()
    except Exception as e:
        log(f"Fatal error in request loop: {str(e)}")
        sys.exit(1)

if __name__ == '__main__':
    # Start the stdio bridge only when run as a script, not when imported.
    main()