|
|
from typing import Dict, List, Any |
|
|
import json |
|
|
import torch |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
|
|
class EndpointHandler:
    """Inference-endpoint handler for a causal-LM project-staffing model.

    Lazily loads a Hugging Face causal LM and its tokenizer on first call,
    builds a French-language instruction prompt from project features,
    runs greedy generation, and returns the response as
    ``{"generated_text": str}`` (or ``{"error": str}`` on failure).
    """

    def __init__(self, path: str = ""):
        """Record the model location; heavy loading is deferred to initialize().

        Args:
            path: Local directory (or hub id) handed to ``from_pretrained``.
        """
        self.model_dir = path
        self.initialized = False
        self.model = None
        self.tokenizer = None
        # NOTE: with device_map="auto" the model may be sharded across
        # devices; this attribute is only used to place the tokenized inputs.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def __call__(self, data) -> Dict[str, Any]:
        """Main entry point: preprocess -> inference -> postprocess.

        Args:
            data: Request payload; expected to carry an ``"inputs"`` field.

        Returns:
            ``{"generated_text": str}`` on success, ``{"error": str}`` on a
            missing payload or any processing failure.
        """
        # Reject an empty payload BEFORE paying the model-load cost
        # (the original initialized first, then checked).
        if data is None:
            return {"error": "No input data provided"}

        if not self.initialized:
            self.initialize()

        try:
            inputs = self.preprocess(data)
            outputs = self.inference(inputs)
            return self.postprocess(outputs)
        except Exception as e:
            # Surface failures as an error payload rather than a raw 500.
            return {"error": str(e)}

    def initialize(self) -> None:
        """Load the tokenizer and model from ``self.model_dir``.

        Raises:
            RuntimeError: if either component fails to load (original
            exception chained for debuggability).
        """
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
            # Some causal-LM tokenizers ship without a pad token; reuse EOS
            # so padding in the tokenizer/generate path does not fail.
            if not self.tokenizer.pad_token:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_dir,
                device_map="auto",          # place/shard layers automatically
                torch_dtype=torch.float16,  # halve memory; assumes GPU-class hw
            )

            self.initialized = True
        except Exception as e:
            raise RuntimeError(f"Error initializing model: {str(e)}") from e

    def build_prompt(self, project_info: Dict[str, Any]) -> str:
        """Build the instruction prompt from project features.

        Args:
            project_info: Mapping with the French project-feature keys;
                missing keys default to empty strings.

        Returns:
            The formatted prompt, ending with the ``### Réponse:`` marker
            that ``inference`` later splits on.
        """
        nom = project_info.get("Nom du projet", "")
        description = project_info.get("Description", "")
        duree = project_info.get("Durée (mois)", "")
        complexite = project_info.get("Complexité (1-5)", "")
        secteur = project_info.get("Secteur", "")
        taches = project_info.get("Tâches Identifiées", "")

        prompt = (f"Nom du projet: {nom}\n"
                  f"Description: {description}\n"
                  f"Durée (mois): {duree}\n"
                  f"Complexité (1-5): {complexite}\n"
                  f"Secteur: {secteur}\n"
                  f"Tâches Identifiées: {taches}\n\n"
                  "### Instruction:\n"
                  "Fournis les informations en format JSON pour:\n"
                  "- Compétences Requises\n"
                  "- Employés Alloués\n"
                  "- Répartition par Compétences\n\n"
                  "### Réponse:\n")
        return prompt

    def preprocess(self, data) -> Dict[str, Any]:
        """Normalize the request payload into ``{"prompt": <text>}``.

        Accepts a raw prompt string, a JSON-encoded string of project
        features, or a dict of project features.

        Raises:
            Exception: wrapping any preprocessing failure (caught by
            ``__call__``).
        """
        try:
            inputs = data.get("inputs", {})

            # A string may be either a JSON-encoded feature dict or a
            # ready-made prompt; fall back to the latter if it isn't JSON.
            if isinstance(inputs, str):
                try:
                    inputs = json.loads(inputs)
                except json.JSONDecodeError:
                    # Was a bare `except:` — narrowed; json.loads on a str
                    # can only raise JSONDecodeError.
                    return {"prompt": inputs}

            if isinstance(inputs, dict) and "Nom du projet" in inputs:
                prompt = self.build_prompt(inputs)
            else:
                # Pass through unchanged — presumably already a usable prompt.
                # NOTE(review): a dict WITHOUT "Nom du projet" falls through
                # here and will fail later in the tokenizer — confirm callers.
                prompt = inputs

            return {"prompt": prompt}
        except Exception as e:
            raise Exception(f"Error in preprocessing: {str(e)}") from e

    def inference(self, inputs: Dict[str, Any]) -> str:
        """Run greedy generation and strip prompt/code-fence scaffolding.

        Args:
            inputs: ``{"prompt": str}`` produced by ``preprocess``.

        Returns:
            The generated response text (the part after ``### Réponse:``),
            with any surrounding Markdown code fences removed.

        Raises:
            Exception: wrapping any generation failure (caught by
            ``__call__``).
        """
        try:
            prompt = inputs.get("prompt", "")

            tokenized_inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

            with torch.no_grad():
                outputs = self.model.generate(
                    **tokenized_inputs,
                    max_new_tokens=800,
                    do_sample=False,  # deterministic (greedy) decoding
                    eos_token_id=self.tokenizer.eos_token_id,
                    # Explicit pad token avoids the HF "Setting pad_token_id
                    # to eos_token_id" warning; initialize() guarantees it.
                    pad_token_id=self.tokenizer.pad_token_id,
                )

            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # The decoded text echoes the prompt; keep only the answer part.
            if "### Réponse:" in generated_text:
                response = generated_text.split("### Réponse:")[-1].strip()
            else:
                response = generated_text.strip()

            # Strip Markdown code fences the model may wrap its JSON in.
            if response.startswith("```json"):
                response = response.split("```json", 1)[1]
            if response.startswith("```"):
                response = response.split("```", 1)[1]
            if response.endswith("```"):
                response = response.rsplit("```", 1)[0]

            return response.strip()
        except Exception as e:
            raise Exception(f"Error in inference: {str(e)}") from e

    def postprocess(self, inference_output: str) -> Dict[str, str]:
        """Wrap the generated text in the response envelope.

        The original parsed the JSON into an unused variable and returned
        the identical dict on both branches; the parse is kept only as a
        best-effort validation hook, and the raw text is always returned.

        Raises:
            Exception: wrapping any unexpected failure (caught by
            ``__call__``).
        """
        try:
            try:
                json.loads(inference_output)  # validation only; result unused
            except json.JSONDecodeError:
                pass  # non-JSON output is returned verbatim by design
            return {"generated_text": inference_output}
        except Exception as e:
            raise Exception(f"Error in postprocessing: {str(e)}") from e