from typing import Any, Dict
import json

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
    def __init__(self, path=""):
        """
        Initialize the handler; the model and tokenizer are loaded
        lazily on the first call.
        """
        self.model_dir = path
        self.initialized = False
        self.model = None
        self.tokenizer = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Main entry point for the handler.
        """
        if not self.initialized:
            self.initialize()
        if data is None:
            return {"error": "No input data provided"}
        try:
            inputs = self.preprocess(data)
            outputs = self.inference(inputs)
            return self.postprocess(outputs)
        except Exception as e:
            return {"error": str(e)}
    def initialize(self):
        """
        Load the model and tokenizer.
        """
        try:
            # Load the tokenizer; fall back to the EOS token when no pad token is set
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
            if not self.tokenizer.pad_token:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            # Load the model in half precision and let accelerate place it
            # on the available device(s)
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_dir,
                device_map="auto",
                torch_dtype=torch.float16,
            )
            self.initialized = True
        except Exception as e:
            raise RuntimeError(f"Error initializing model: {e}") from e
    def build_prompt(self, project_info):
        """
        Build an input prompt from project features. The field names stay
        in French because they match both the expected input schema and
        the instruction format the model expects.
        """
        nom = project_info.get("Nom du projet", "")
        description = project_info.get("Description", "")
        duree = project_info.get("Durée (mois)", "")
        complexite = project_info.get("Complexité (1-5)", "")
        secteur = project_info.get("Secteur", "")
        taches = project_info.get("Tâches Identifiées", "")
        prompt = (
            f"Nom du projet: {nom}\n"
            f"Description: {description}\n"
            f"Durée (mois): {duree}\n"
            f"Complexité (1-5): {complexite}\n"
            f"Secteur: {secteur}\n"
            f"Tâches Identifiées: {taches}\n\n"
            "### Instruction:\n"
            "Fournis les informations en format JSON pour:\n"
            "- Compétences Requises\n"
            "- Employés Alloués\n"
            "- Répartition par Compétences\n\n"
            "### Réponse:\n"
        )
        return prompt
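    # Rendered-prompt sketch for a hypothetical project (all values are
    # placeholders, shown for illustration only):
    #
    #   Nom du projet: Refonte du portail client
    #   Description: Migration du portail vers une stack web moderne
    #   Durée (mois): 6
    #   Complexité (1-5): 3
    #   Secteur: Banque
    #   Tâches Identifiées: Audit, développement, tests
    #
    #   ### Instruction:
    #   Fournis les informations en format JSON pour:
    #   - Compétences Requises
    #   - Employés Alloués
    #   - Répartition par Compétences
    #
    #   ### Réponse: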
    def preprocess(self, data):
        """
        Turn the raw request payload into a prompt.
        """
        try:
            inputs = data.get("inputs", {})
            # Handle string inputs (either a JSON-encoded dict or a raw prompt)
            if isinstance(inputs, str):
                try:
                    inputs = json.loads(inputs)
                except json.JSONDecodeError:
                    # Not JSON, so treat the string as a ready-made prompt
                    return {"prompt": inputs}
            # Build a structured prompt when project info is provided
            if isinstance(inputs, dict) and "Nom du projet" in inputs:
                prompt = self.build_prompt(inputs)
            else:
                prompt = inputs
            return {"prompt": prompt}
        except Exception as e:
            raise RuntimeError(f"Error in preprocessing: {e}") from e
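    # Accepted payload shapes (values are placeholders, for illustration only):
    #   {"inputs": {"Nom du projet": "...", "Description": "...", ...}}
    #   {"inputs": "{\"Nom du projet\": \"...\"}"}  # the same dict as a JSON string
    #   {"inputs": "raw prompt text"}               # forwarded to the model as-is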
    def inference(self, inputs):
        """
        Generate text from the input prompt.
        """
        try:
            prompt = inputs.get("prompt", "")
            # Tokenize the input and move it to the model's device
            tokenized_inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            # Deterministic (greedy) generation
            with torch.no_grad():
                outputs = self.model.generate(
                    **tokenized_inputs,
                    max_new_tokens=800,
                    do_sample=False,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )
            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Keep only the part after the "### Réponse:" marker
            if "### Réponse:" in generated_text:
                response = generated_text.split("### Réponse:")[-1].strip()
            else:
                response = generated_text.strip()
            # Strip markdown code-fence markers the model may wrap around the JSON
            if response.startswith("```json"):
                response = response.split("```json", 1)[1]
            if response.startswith("```"):
                response = response.split("```", 1)[1]
            if response.endswith("```"):
                response = response.rsplit("```", 1)[0]
            return response.strip()
        except Exception as e:
            raise RuntimeError(f"Error in inference: {e}") from e
    def postprocess(self, inference_output):
        """
        Post-process the model output into the response payload.
        """
        try:
            # Validity check only: the generated text is returned unchanged
            # whether or not it parses as JSON
            json.loads(inference_output)
        except json.JSONDecodeError:
            pass
        return {"generated_text": inference_output}
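

# ----------------------------------------------------------------------
# Minimal local smoke test (a sketch, not part of the Inference Endpoints
# contract): "./model" is a placeholder path and the field values below
# are illustrative only.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")
    payload = {
        "inputs": {
            "Nom du projet": "Refonte du portail client",
            "Description": "Migration du portail vers une stack web moderne",
            "Durée (mois)": 6,
            "Complexité (1-5)": 3,
            "Secteur": "Banque",
            "Tâches Identifiées": "Audit, développement, tests",
        }
    }
    print(handler(payload))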