mistral_deploy / handler.py
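"""
Custom inference handler (handler.py) for Hugging Face Inference Endpoints.

EndpointHandler follows the endpoints custom-handler contract: a class
exposing __init__(path) and __call__(data). It lazily loads a causal
language model on the first request, builds a French-language prompt from
project features, and returns the model's JSON staffing answer under a
"generated_text" key.
"""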
import json
from typing import Any, Dict

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class EndpointHandler:
    def __init__(self, path=""):
        """
        Store the model directory. Loading the model and tokenizer is
        deferred to initialize(), so the first request triggers it lazily.
        """
        self.model_dir = path
        self.initialized = False
        self.model = None
        self.tokenizer = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Main entry point for the handler. Loads the model lazily on the
        first request, then runs preprocess -> inference -> postprocess.
        """
        if not self.initialized:
            self.initialize()
if data is None:
return {"error": "No input data provided"}
try:
inputs = self.preprocess(data)
outputs = self.inference(inputs)
return self.postprocess(outputs)
except Exception as e:
return {"error": str(e)}

    def initialize(self):
        """
        Load the model and tokenizer
        """
        try:
            # Load tokenizer; fall back to the EOS token for padding, since
            # many causal-LM tokenizers ship without a dedicated pad token
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
            if not self.tokenizer.pad_token:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            # Load model; use fp16 only on GPU, since half-precision
            # inference is poorly supported on CPU
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_dir,
                device_map="auto",
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            )
            self.initialized = True
        except Exception as e:
            raise RuntimeError(f"Error initializing model: {e}") from e

    def build_prompt(self, project_info):
        """
        Build the input prompt from project features. Field labels stay in
        French to match the prompt format the model expects.
        """
nom = project_info.get("Nom du projet", "")
description = project_info.get("Description", "")
duree = project_info.get("Durée (mois)", "")
complexite = project_info.get("Complexité (1-5)", "")
secteur = project_info.get("Secteur", "")
taches = project_info.get("Tâches Identifiées", "")
prompt = (f"Nom du projet: {nom}\n"
f"Description: {description}\n"
f"Durée (mois): {duree}\n"
f"Complexité (1-5): {complexite}\n"
f"Secteur: {secteur}\n"
f"Tâches Identifiées: {taches}\n\n"
"### Instruction:\n"
"Fournis les informations en format JSON pour:\n"
"- Compétences Requises\n"
"- Employés Alloués\n"
"- Répartition par Compétences\n\n"
"### Réponse:\n")
return prompt

    def preprocess(self, data):
        """
        Preprocess the input data
        """
        try:
            inputs = data.get("inputs", {})
            # Handle string inputs (could be a JSON string or a direct prompt)
            if isinstance(inputs, str):
                try:
                    # Try to parse as JSON
                    inputs = json.loads(inputs)
                except json.JSONDecodeError:
                    # If parsing fails, assume it's a direct prompt
                    return {"prompt": inputs}
            # Build the structured prompt when project info is provided;
            # otherwise serialize whatever was sent so the tokenizer
            # always receives a string
            if isinstance(inputs, dict) and "Nom du projet" in inputs:
                prompt = self.build_prompt(inputs)
            elif isinstance(inputs, str):
                prompt = inputs
            else:
                prompt = json.dumps(inputs, ensure_ascii=False)
            return {"prompt": prompt}
        except Exception as e:
            raise RuntimeError(f"Error in preprocessing: {e}") from e

    def inference(self, inputs):
        """
        Generate text based on the input prompt
        """
        try:
            prompt = inputs.get("prompt", "")
            # Tokenize input and move it to wherever device_map placed the model
            tokenized_inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            # Generate output
            with torch.no_grad():
                outputs = self.model.generate(
                    **tokenized_inputs,
                    max_new_tokens=800,
                    do_sample=False,  # Deterministic (greedy) generation
                    eos_token_id=self.tokenizer.eos_token_id,
                    pad_token_id=self.tokenizer.pad_token_id,
                )
            # Decode output
            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Keep only the completion (after "### Réponse:"); the decoded
            # text also contains the echoed prompt
            if "### Réponse:" in generated_text:
                response = generated_text.split("### Réponse:")[-1].strip()
            else:
                response = generated_text.strip()
            # Strip markdown code-fence markers the model may emit
            if response.startswith("```json"):
                response = response.split("```json", 1)[1]
            if response.startswith("```"):
                response = response.split("```", 1)[1]
            if response.endswith("```"):
                response = response.rsplit("```", 1)[0]
            return response.strip()
        except Exception as e:
            raise RuntimeError(f"Error in inference: {e}") from e

    def postprocess(self, inference_output):
        """
        Post-process the model output
        """
        try:
            # Validate that the output parses as JSON, but return the raw
            # text either way so the response shape stays consistent
            try:
                json.loads(inference_output)
            except json.JSONDecodeError:
                # Not valid JSON; return the raw text as-is
                pass
            return {"generated_text": inference_output}
        except Exception as e:
            raise RuntimeError(f"Error in postprocessing: {e}") from e
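

# --- Local smoke test (illustrative sketch, not part of the endpoint contract) ---
# Hugging Face Inference Endpoints only needs the EndpointHandler class above;
# this guard is a hedged example of exercising the handler locally. The
# "./model" path and the sample project values below are assumptions chosen
# for illustration, not values from the deployment.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")
    sample = {
        "inputs": {
            "Nom du projet": "Refonte du site web",
            "Description": "Migration du site vitrine vers un CMS moderne",
            "Durée (mois)": 6,
            "Complexité (1-5)": 3,
            "Secteur": "E-commerce",
            "Tâches Identifiées": "Design, Développement, Tests",
        }
    }
    # Prints {"generated_text": ...} on success or {"error": ...} on failure
    print(handler(sample))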