from typing import Any, Dict
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
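
# A sketch of the surrounding contract, based on how Hugging Face Inference
# Endpoints consume a custom handler.py: the service instantiates
# EndpointHandler(path) with the local model directory, then calls the
# instance once per request with the deserialized JSON body, e.g.
# {"inputs": {...}} or {"inputs": "raw prompt"}, and serializes the dict
# the call returns.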

class EndpointHandler:
    def __init__(self, path=""):
        """
        Initialize model and tokenizer
        """
        self.model_dir = path
        self.initialized = False
        self.model = None
        self.tokenizer = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Main entry point: lazily initialize, then run the
        preprocess -> inference -> postprocess pipeline.
        """
        if not self.initialized:
            self.initialize()
        
        if data is None:
            return {"error": "No input data provided"}
        
        try:
            inputs = self.preprocess(data)
            outputs = self.inference(inputs)
            return self.postprocess(outputs)
        except Exception as e:
            return {"error": str(e)}

    def initialize(self):
        """
        Load the model and tokenizer
        """
        try:
            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
            # Some causal LM tokenizers ship without a pad token; fall back to EOS
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                
            # Load model
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_dir,
                device_map="auto",
                torch_dtype=torch.float16
            )
            
            self.initialized = True
        except Exception as e:
            raise RuntimeError(f"Error initializing model: {str(e)}") from e

    def build_prompt(self, project_info):
        """
        Build an input prompt from project features
        """
        nom = project_info.get("Nom du projet", "")
        description = project_info.get("Description", "")
        duree = project_info.get("Durée (mois)", "")
        complexite = project_info.get("Complexité (1-5)", "")
        secteur = project_info.get("Secteur", "")
        taches = project_info.get("Tâches Identifiées", "")

        prompt = (f"Nom du projet: {nom}\n"
                f"Description: {description}\n"
                f"Durée (mois): {duree}\n"
                f"Complexité (1-5): {complexite}\n"
                f"Secteur: {secteur}\n"
                f"Tâches Identifiées: {taches}\n\n"
                "### Instruction:\n"
                "Fournis les informations en format JSON pour:\n"
                "- Compétences Requises\n"
                "- Employés Alloués\n"
                "- Répartition par Compétences\n\n"
                "### Réponse:\n")
        return prompt
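
    # For reference, a well-formed model response to this prompt would be a
    # JSON object with the three requested keys (values here are hypothetical):
    #   {
    #     "Compétences Requises": ["Python", "Gestion de projet"],
    #     "Employés Alloués": 4,
    #     "Répartition par Compétences": {"Python": 2, "Gestion de projet": 2}
    #   }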

    def preprocess(self, data):
        """
        Preprocess the input data
        """
        try:
            inputs = data.get("inputs", {})
            
            # Handle string inputs (could be JSON string or direct prompt)
            if isinstance(inputs, str):
                try:
                    # Try to parse as a JSON object of project fields
                    inputs = json.loads(inputs)
                except json.JSONDecodeError:
                    # Not JSON; treat the string as a direct prompt
                    return {"prompt": inputs}
            
            # Build prompt if project info is provided
            if isinstance(inputs, dict) and "Nom du projet" in inputs:
                prompt = self.build_prompt(inputs)
            else:
                prompt = inputs
                
            return {"prompt": prompt}
        except Exception as e:
            raise ValueError(f"Error in preprocessing: {str(e)}") from e

    def inference(self, inputs):
        """
        Generate text based on the input prompt
        """
        try:
            prompt = inputs.get("prompt", "")
            
            # Tokenize and move inputs to the device where device_map placed the model
            tokenized_inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            
            # Generate output
            with torch.no_grad():
                outputs = self.model.generate(
                    **tokenized_inputs,
                    max_new_tokens=800,
                    do_sample=False,  # Deterministic generation
                    eos_token_id=self.tokenizer.eos_token_id,
                    pad_token_id=self.tokenizer.pad_token_id  # avoids the pad-token warning from generate()
                )
            
            # Decode output
            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            
            # Extract response part (after "### Réponse:")
            if "### Réponse:" in generated_text:
                response = generated_text.split("### Réponse:")[-1].strip()
            else:
                response = generated_text.strip()
                
            # Clean up response (remove markdown code block markers)
            if response.startswith("```json"):
                response = response.split("```json", 1)[1]
            if response.startswith("```"):
                response = response.split("```", 1)[1]
            if response.endswith("```"):
                response = response.rsplit("```", 1)[0]
                
            return response.strip()
        except Exception as e:
            raise RuntimeError(f"Error in inference: {str(e)}") from e

    def postprocess(self, inference_output):
        """
        Wrap the generated text in the response payload. The output is
        checked for JSON validity but returned unchanged either way, so
        the response schema stays stable.
        """
        try:
            json.loads(inference_output)  # validation only; the parsed value is discarded
        except json.JSONDecodeError:
            # Not valid JSON; still return the raw text as-is
            pass
        return {"generated_text": inference_output}
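

# Minimal local smoke test: a sketch, not part of the Inference Endpoints
# contract. The path "./model" and the sample values below are hypothetical
# placeholders; point them at a real fine-tuned checkpoint to run this.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")
    sample_request = {
        "inputs": {
            "Nom du projet": "Refonte du portail client",
            "Description": "Migration du portail existant vers une architecture web moderne",
            "Durée (mois)": 6,
            "Complexité (1-5)": 3,
            "Secteur": "Banque",
            "Tâches Identifiées": "Audit, développement, tests, déploiement"
        }
    }
    print(handler(sample_request))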