joseififif commited on
Commit
cf64161
·
verified ·
1 Parent(s): 7de1379

Create triton_inference/inference_engine.py

Browse files
triton_inference/inference_engine.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tritonclient.grpc as grpcclient
2
+ import numpy as np
3
+ from typing import Optional
4
+ import logging
5
+ import os
6
+
7
+ class TritonInferenceEngine:
8
+ def __init__(self):
9
+ self.triton_url = os.getenv("TRITON_URL", "localhost:8001")
10
+ self.model_name = os.getenv("MODEL_NAME", "wizardlm-7b")
11
+ self.client = None
12
+
13
+ async def initialize(self):
14
+ try:
15
+ self.client = grpcclient.InferenceServerClient(
16
+ url=self.triton_url,
17
+ verbose=False
18
+ )
19
+ if not self.client.is_model_ready(self.model_name):
20
+ raise RuntimeError(f"Model {self.model_name} is not ready")
21
+ logging.info(f"Connected to Triton Inference Server at {self.triton_url}")
22
+ except Exception as e:
23
+ logging.error(f"Error connecting to Triton: {e}")
24
+ raise
25
+
26
+ async def generate(
27
+ self,
28
+ prompt: str,
29
+ max_tokens: int = 100,
30
+ temperature: float = 0.7,
31
+ top_p: float = 0.9
32
+ ) -> str:
33
+ # Preprocesar el prompt (aquí necesitarías tokenizar con el tokenizer adecuado)
34
+ # Para simplificar, asumimos que el modelo espera input_ids y attention_mask
35
+ # En un caso real, usarías un tokenizer como en el ejemplo de HuggingFace
36
+
37
+ # Este es un ejemplo simplificado. En producción, necesitarás adaptarlo a tu modelo.
38
+ inputs = self._prepare_inputs(prompt)
39
+
40
+ # Configurar inputs para Triton
41
+ triton_inputs = [
42
+ grpcclient.InferInput("input_ids", inputs["input_ids"].shape, "INT64"),
43
+ grpcclient.InferInput("attention_mask", inputs["attention_mask"].shape, "INT64"),
44
+ ]
45
+ triton_inputs[0].set_data_from_numpy(inputs["input_ids"])
46
+ triton_inputs[1].set_data_from_numpy(inputs["attention_mask"])
47
+
48
+ # Configurar outputs
49
+ outputs = [grpcclient.InferRequestedOutput("output_ids")]
50
+
51
+ # Realizar inferencia
52
+ response = self.client.infer(
53
+ model_name=self.model_name,
54
+ inputs=triton_inputs,
55
+ outputs=outputs
56
+ )
57
+
58
+ # Postprocesar la respuesta
59
+ output_ids = response.as_numpy("output_ids")
60
+ generated_text = self._decode_output(output_ids)
61
+
62
+ return generated_text
63
+
64
+ def _prepare_inputs(self, prompt: str):
65
+ # Aquí deberías tokenizar el prompt usando el tokenizer de WizardLM-7B
66
+ # Por ahora, devolvemos un ejemplo dummy
67
+ # En producción, carga el tokenizer y úsalo para tokenizar
68
+ return {
69
+ "input_ids": np.array([[1, 2, 3, 4, 5]], dtype=np.int64),
70
+ "attention_mask": np.array([[1, 1, 1, 1, 1]], dtype=np.int64)
71
+ }
72
+
73
+ def _decode_output(self, output_ids):
74
+ # Decodificar los output_ids a texto
75
+ # Por ahora, devolvemos un texto dummy
76
+ return "This is a dummy response from the AI model."
77
+
78
+ async def close(self):
79
+ if self.client:
80
+ self.client.close()