mohamed-amine49 committed
Commit d1735ec · verified · 1 Parent(s): 3755199

Update handler.py

Files changed (1):
  handler.py +160 -165
handler.py CHANGED
@@ -1,165 +1,160 @@
- from typing import Dict, List, Any
- import json
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer
-
-
- # Default handler class
- class Handler:
-     def __init__(self):
-         self.initialized = False
-         self.model = None
-         self.tokenizer = None
-         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-
-     def initialize(self, context):
-         """
-         Initialize model and tokenizer
-         """
-         model_dir = context.model_dir
-
-         # Load tokenizer and model
-         self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
-         if not self.tokenizer.pad_token:
-             self.tokenizer.pad_token = self.tokenizer.eos_token
-
-         self.model = AutoModelForCausalLM.from_pretrained(
-             model_dir,
-             device_map="auto",
-             torch_dtype=torch.float16
-         )
-
-         self.initialized = True
-
-     def build_prompt(self, project_info):
-         """
-         Build an input prompt from project features
-         """
-         nom = project_info.get("Nom du projet", "")
-         description = project_info.get("Description", "")
-         duree = project_info.get("Durée (mois)", "")
-         complexite = project_info.get("Complexité (1-5)", "")
-         secteur = project_info.get("Secteur", "")
-         taches = project_info.get("Tâches Identifiées", "")
-
-         prompt = (f"Nom du projet: {nom}\n"
-                   f"Description: {description}\n"
-                   f"Durée (mois): {duree}\n"
-                   f"Complexité (1-5): {complexite}\n"
-                   f"Secteur: {secteur}\n"
-                   f"Tâches Identifiées: {taches}\n\n"
-                   "### Instruction:\n"
-                   "Fournis les informations en format JSON pour:\n"
-                   "- Compétences Requises\n"
-                   "- Employés Alloués\n"
-                   "- Répartition par Compétences\n\n"
-                   "### Réponse:\n")
-         return prompt
-
-     def preprocess(self, data):
-         """
-         Preprocess the input data
-         """
-         try:
-             inputs = data.get("inputs", {})
-             if isinstance(inputs, str):
-                 # If inputs is a string, try to parse it as JSON
-                 try:
-                     inputs = json.loads(inputs)
-                 except:
-                     # If parsing fails, assume it's a direct prompt
-                     return {"prompt": inputs}
-
-             # Build prompt if project info is provided
-             if isinstance(inputs, dict) and "Nom du projet" in inputs:
-                 prompt = self.build_prompt(inputs)
-             else:
-                 prompt = inputs
-
-             return {"prompt": prompt}
-         except Exception as e:
-             raise Exception(f"Error in preprocessing: {str(e)}")
-
-     def inference(self, inputs):
-         """
-         Generate text based on the input prompt
-         """
-         try:
-             prompt = inputs.get("prompt", "")
-
-             # Tokenize input
-             tokenized_inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
-
-             # Generate output
-             with torch.no_grad():
-                 outputs = self.model.generate(
-                     **tokenized_inputs,
-                     max_new_tokens=800,
-                     do_sample=False,  # Deterministic generation
-                     eos_token_id=self.tokenizer.eos_token_id
-                 )
-
-             # Decode output
-             generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-             # Extract response part (after "### Réponse:")
-             if "### Réponse:" in generated_text:
-                 response = generated_text.split("### Réponse:")[-1].strip()
-             else:
-                 response = generated_text.strip()
-
-             # Clean up response (remove markdown code block markers)
-             if response.startswith("```json"):
-                 response = response.split("```json", 1)[1]
-             if response.startswith("```"):
-                 response = response.split("```", 1)[1]
-             if response.endswith("```"):
-                 response = response.rsplit("```", 1)[0]
-
-             return response.strip()
-         except Exception as e:
-             raise Exception(f"Error in inference: {str(e)}")
-
-     def postprocess(self, inference_output):
-         """
-         Post-process the model output
-         """
-         try:
-             # Try to parse as JSON to ensure it's valid
-             try:
-                 parsed_json = json.loads(inference_output)
-                 # Return the parsed JSON if successful
-                 return {"generated_text": inference_output}
-             except json.JSONDecodeError:
-                 # If not valid JSON, return as is
-                 return {"generated_text": inference_output}
-         except Exception as e:
-             raise Exception(f"Error in postprocessing: {str(e)}")
-
-     def handle(self, data, context):
-         """
-         Handle the complete inference process
-         """
-         try:
-             if not self.initialized:
-                 self.initialize(context)
-
-             preprocessed_data = self.preprocess(data)
-             inference_output = self.inference(preprocessed_data)
-             return self.postprocess(inference_output)
-         except Exception as e:
-             raise Exception(f"Error in handler: {str(e)}")
-
-
- # The handler function
- _service = Handler()
-
-
- def handle(data, context):
-     if not _service.initialized:
-         _service.initialize(context)
-
-     if data is None:
-         return None
-
-     return _service.handle(data, context)
 
+ from typing import Dict, List, Any
+ import json
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         """
+         Initialize model and tokenizer
+         """
+         self.model_dir = path
+         self.initialized = False
+         self.model = None
+         self.tokenizer = None
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         Main entry point for the handler
+         """
+         if not self.initialized:
+             self.initialize()
+
+         if data is None:
+             return {"error": "No input data provided"}
+
+         try:
+             inputs = self.preprocess(data)
+             outputs = self.inference(inputs)
+             return self.postprocess(outputs)
+         except Exception as e:
+             return {"error": str(e)}
+
+     def initialize(self):
+         """
+         Load the model and tokenizer
+         """
+         try:
+             # Load tokenizer
+             self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
+             if not self.tokenizer.pad_token:
+                 self.tokenizer.pad_token = self.tokenizer.eos_token
+
+             # Load model
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 self.model_dir,
+                 device_map="auto",
+                 torch_dtype=torch.float16
+             )
+
+             self.initialized = True
+         except Exception as e:
+             raise RuntimeError(f"Error initializing model: {str(e)}")
+
+     def build_prompt(self, project_info):
+         """
+         Build an input prompt from project features
+         """
+         nom = project_info.get("Nom du projet", "")
+         description = project_info.get("Description", "")
+         duree = project_info.get("Durée (mois)", "")
+         complexite = project_info.get("Complexité (1-5)", "")
+         secteur = project_info.get("Secteur", "")
+         taches = project_info.get("Tâches Identifiées", "")
+
+         prompt = (f"Nom du projet: {nom}\n"
+                   f"Description: {description}\n"
+                   f"Durée (mois): {duree}\n"
+                   f"Complexité (1-5): {complexite}\n"
+                   f"Secteur: {secteur}\n"
+                   f"Tâches Identifiées: {taches}\n\n"
+                   "### Instruction:\n"
+                   "Fournis les informations en format JSON pour:\n"
+                   "- Compétences Requises\n"
+                   "- Employés Alloués\n"
+                   "- Répartition par Compétences\n\n"
+                   "### Réponse:\n")
+         return prompt
+
+     def preprocess(self, data):
+         """
+         Preprocess the input data
+         """
+         try:
+             inputs = data.get("inputs", {})
+
+             # Handle string inputs (could be a JSON string or a direct prompt)
+             if isinstance(inputs, str):
+                 try:
+                     # Try to parse as JSON
+                     inputs = json.loads(inputs)
+                 except json.JSONDecodeError:
+                     # If parsing fails, assume it's a direct prompt
+                     return {"prompt": inputs}
+
+             # Build prompt if project info is provided
+             if isinstance(inputs, dict) and "Nom du projet" in inputs:
+                 prompt = self.build_prompt(inputs)
+             else:
+                 prompt = inputs
+
+             return {"prompt": prompt}
+         except Exception as e:
+             raise Exception(f"Error in preprocessing: {str(e)}")
+
+     def inference(self, inputs):
+         """
+         Generate text based on the input prompt
+         """
+         try:
+             prompt = inputs.get("prompt", "")
+
+             # Tokenize input
+             tokenized_inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
+
+             # Generate output
+             with torch.no_grad():
+                 outputs = self.model.generate(
+                     **tokenized_inputs,
+                     max_new_tokens=800,
+                     do_sample=False,  # Deterministic generation
+                     eos_token_id=self.tokenizer.eos_token_id
+                 )
+
+             # Decode output
+             generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+             # Extract the response part (after "### Réponse:")
+             if "### Réponse:" in generated_text:
+                 response = generated_text.split("### Réponse:")[-1].strip()
+             else:
+                 response = generated_text.strip()
+
+             # Clean up response (remove markdown code block markers)
+             if response.startswith("```json"):
+                 response = response.split("```json", 1)[1]
+             if response.startswith("```"):
+                 response = response.split("```", 1)[1]
+             if response.endswith("```"):
+                 response = response.rsplit("```", 1)[0]
+
+             return response.strip()
+         except Exception as e:
+             raise Exception(f"Error in inference: {str(e)}")
+
+     def postprocess(self, inference_output):
+         """
+         Post-process the model output
+         """
+         try:
+             # Check that the output parses as JSON; the raw text is
+             # returned either way
+             try:
+                 json.loads(inference_output)
+                 return {"generated_text": inference_output}
+             except json.JSONDecodeError:
+                 # Not valid JSON; return the text as-is
+                 return {"generated_text": inference_output}
+         except Exception as e:
+             raise Exception(f"Error in postprocessing: {str(e)}")