import json

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
    def __init__(self, path=""):
        """Set up state; the model and tokenizer are loaded lazily in initialize()."""
        self.model_dir = path
        self.initialized = False
        self.model = None
        self.tokenizer = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def __call__(self, data):
        """Main entry point for the handler."""
        if not self.initialized:
            self.initialize()

        if data is None:
            return {"error": "No input data provided"}

        try:
            inputs = self.preprocess(data)
            outputs = self.inference(inputs)
            return self.postprocess(outputs)
        except Exception as e:
            return {"error": str(e)}

    def initialize(self):
        """Load the model and tokenizer."""
        try:
            # Load tokenizer; fall back to the EOS token when no pad token is defined
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Load model in half precision, letting accelerate place it on available devices
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_dir,
                device_map="auto",
                torch_dtype=torch.float16,
            )
            self.initialized = True
        except Exception as e:
            raise RuntimeError(f"Error initializing model: {str(e)}")

    def build_prompt(self, project_info):
        """Build an input prompt from project features."""
        nom = project_info.get("Nom du projet", "")
        description = project_info.get("Description", "")
        duree = project_info.get("Durée (mois)", "")
        complexite = project_info.get("Complexité (1-5)", "")
        secteur = project_info.get("Secteur", "")
        taches = project_info.get("Tâches Identifiées", "")

        prompt = (
            f"Nom du projet: {nom}\n"
            f"Description: {description}\n"
            f"Durée (mois): {duree}\n"
            f"Complexité (1-5): {complexite}\n"
            f"Secteur: {secteur}\n"
            f"Tâches Identifiées: {taches}\n\n"
            "### Instruction:\n"
            "Fournis les informations en format JSON pour:\n"
            "- Compétences Requises\n"
            "- Employés Alloués\n"
            "- Répartition par Compétences\n\n"
            "### Réponse:\n"
        )
        return prompt

    def preprocess(self, data):
        """Preprocess the input data."""
        try:
            inputs = data.get("inputs", {})

            # Handle string inputs (could be a JSON string or a direct prompt)
            if isinstance(inputs, str):
                try:
                    inputs = json.loads(inputs)
                except json.JSONDecodeError:
                    # If parsing fails, assume it's a direct prompt
                    return {"prompt": inputs}

            # Build a structured prompt if project info is provided
            if isinstance(inputs, dict) and "Nom du projet" in inputs:
                prompt = self.build_prompt(inputs)
            elif isinstance(inputs, str):
                prompt = inputs
            else:
                # Serialize anything else so the tokenizer always receives text
                prompt = json.dumps(inputs, ensure_ascii=False)

            return {"prompt": prompt}
        except Exception as e:
            raise Exception(f"Error in preprocessing: {str(e)}")

    def inference(self, inputs):
        """Generate text based on the input prompt."""
        try:
            prompt = inputs.get("prompt", "")

            # Tokenize input and move tensors to the target device
            tokenized_inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

            # Generate output deterministically (greedy decoding)
            with torch.no_grad():
                outputs = self.model.generate(
                    **tokenized_inputs,
                    max_new_tokens=800,
                    do_sample=False,
                    eos_token_id=self.tokenizer.eos_token_id,
                )

            # Decode output
            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Extract the response part (after "### Réponse:")
            if "### Réponse:" in generated_text:
                response = generated_text.split("### Réponse:")[-1].strip()
            else:
                response = generated_text.strip()

            # Clean up the response (strip Markdown code-block markers)
            if response.startswith("```json"):
                response = response.split("```json", 1)[1]
            if response.startswith("```"):
                response = response.split("```", 1)[1]
            if response.endswith("```"):
                response = response.rsplit("```", 1)[0]

            return response.strip()
        except Exception as e:
            raise Exception(f"Error in inference: {str(e)}")

    def postprocess(self, inference_output):
        """Post-process the model output."""
        try:
            # Validate that the output parses as JSON; return the raw text either way
            try:
                json.loads(inference_output)
            except json.JSONDecodeError:
                pass  # Not valid JSON; return the text as-is
            return {"generated_text": inference_output}
        except Exception as e:
            raise Exception(f"Error in postprocessing: {str(e)}")
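

# --- Usage sketch (illustrative only) ---
# A minimal local smoke test, assuming "./model" is a placeholder directory
# containing a compatible fine-tuned causal LM. The payload shape mirrors what
# Inference Endpoints passes to EndpointHandler.__call__; the field values
# below are hypothetical examples, not part of the handler's contract.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")  # placeholder model path
    payload = {
        "inputs": {
            "Nom du projet": "Refonte du portail client",
            "Description": "Migration du portail vers une architecture web moderne",
            "Durée (mois)": 6,
            "Complexité (1-5)": 3,
            "Secteur": "Banque",
            "Tâches Identifiées": "Cadrage, développement, tests, déploiement",
        }
    }
    # Expected result: {"generated_text": "..."} or {"error": "..."} on failure
    print(handler(payload))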