jdesiree commited on
Commit
c6b736c
·
verified ·
1 Parent(s): 82d9923

Create compile_model.py

Browse files
Files changed (1) hide show
  1. compile_model.py +73 -0
compile_model.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# compile_model.py
"""
With transformers, model downloading and caching is automatic.
This script just performs a warmup to:
1. Download model during Docker build
2. Compile CUDA kernels
3. Verify installation
"""
import os
import logging
from datetime import datetime  # NOTE(review): not referenced in this file — confirm before removing

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Import shared model manager
from shared_models import get_shared_llama

# Hugging Face auth token, read from either common env-var spelling.
# NOTE(review): never referenced below — presumably transformers/shared_models
# picks the token up from the environment on its own; verify before removing.
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
+ def warmup_model():
22
+ """
23
+ Warmup Llama-3.2-3B model:
24
+ - Downloads from HuggingFace Hub
25
+ - Loads with 4-bit quantization
26
+ - Runs test inference
27
+ """
28
+ logger.info("="*60)
29
+ logger.info("LLAMA-3.2-3B WARMUP")
30
+ logger.info("="*60)
31
+
32
+ try:
33
+ # Get shared model instance
34
+ llama = get_shared_llama()
35
+
36
+ # This triggers model download and loading
37
+ logger.info("Running warmup inference...")
38
+
39
+ test_response = llama.generate(
40
+ system_prompt="You are a helpful educational assistant.",
41
+ user_message="Hello, this is a test warmup message.",
42
+ max_tokens=20,
43
+ temperature=0.7,
44
+ )
45
+
46
+ logger.info(f"✅ Warmup successful")
47
+ logger.info(f" Response preview: {test_response[:80]}...")
48
+
49
+ # Get model info
50
+ info = llama.get_model_info()
51
+ logger.info("="*60)
52
+ logger.info("MODEL INFO")
53
+ logger.info("="*60)
54
+ for key, value in info.items():
55
+ logger.info(f" {key}: {value}")
56
+
57
+ logger.info("="*60)
58
+ logger.info("✅ MODEL READY FOR PRODUCTION")
59
+ logger.info("="*60)
60
+
61
+ return True
62
+
63
+ except Exception as e:
64
+ logger.error(f"❌ Warmup failed: {e}")
65
+ import traceback
66
+ traceback.print_exc()
67
+ return False
68
+
69
+
70
+ if __name__ == "__main__":
71
+ success = warmup_model()
72
+ if not success:
73
+ exit(1)