likhonsheikh commited on
Commit
11ec1c6
·
verified ·
1 Parent(s): 1f3223c

Fix device mapping and pipeline creation for production use

Browse files
Files changed (1) hide show
  1. production_inference.py +30 -14
production_inference.py CHANGED
@@ -82,20 +82,36 @@ class ProthomAloModel:
82
  self.tokenizer.pad_token = self.tokenizer.eos_token
83
 
84
  # Load model
85
- self.model = AutoModelForCausalLM.from_pretrained(
86
- self.model_name,
87
- torch_dtype=torch.float16 if device == "cuda" else torch.float32,
88
- device_map=device,
89
- trust_remote_code=True
90
- )
91
-
92
- # Create pipeline for easier use
93
- self.pipeline = pipeline(
94
- "text-generation",
95
- model=self.model,
96
- tokenizer=self.tokenizer,
97
- device=0 if device == "cuda" else -1
98
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
  logger.info("Model loaded successfully with production optimizations")
101
  return True
 
82
  self.tokenizer.pad_token = self.tokenizer.eos_token
83
 
84
  # Load model
85
+ if device == "auto":
86
+ # Use device_map="auto" for automatic device placement
87
+ self.model = AutoModelForCausalLM.from_pretrained(
88
+ self.model_name,
89
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
90
+ device_map="auto",
91
+ trust_remote_code=True
92
+ )
93
+ # Create pipeline without device specification when using device_map
94
+ self.pipeline = pipeline(
95
+ "text-generation",
96
+ model=self.model,
97
+ tokenizer=self.tokenizer
98
+ )
99
+ else:
100
+ # Specific device handling
101
+ device_obj = torch.device("cuda" if device == "cuda" and torch.cuda.is_available() else "cpu")
102
+ self.model = AutoModelForCausalLM.from_pretrained(
103
+ self.model_name,
104
+ torch_dtype=torch.float16 if device_obj.type == "cuda" else torch.float32,
105
+ trust_remote_code=True
106
+ ).to(device_obj)
107
+
108
+ # Create pipeline with device specification
109
+ self.pipeline = pipeline(
110
+ "text-generation",
111
+ model=self.model,
112
+ tokenizer=self.tokenizer,
113
+ device=0 if device_obj.type == "cuda" else -1
114
+ )
115
 
116
  logger.info("Model loaded successfully with production optimizations")
117
  return True