Garabatos commited on
Commit
49606a1
·
1 Parent(s): 5b9030b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -6
app.py CHANGED
@@ -8,7 +8,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
8
  import torch
9
 
10
  # ======== Cargar el modelo DialoGPT =========
11
- MODEL_NAME = "microsoft/DialoGPT-medium"
12
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
13
  model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
14
 
@@ -21,13 +21,24 @@ class Message(BaseModel):
21
  @app.post("/chat")
22
  def chat(msg: Message):
23
  """Genera respuesta basada en el input del usuario."""
24
- input_text = msg.text
25
- print(msg.text)
 
 
26
  inputs = tokenizer.encode(input_text + tokenizer.eos_token, return_tensors="pt")
27
- response_ids = model.generate(inputs, max_length=100, pad_token_id=tokenizer.eos_token_id)
 
 
 
 
 
 
 
 
 
28
  response_text = tokenizer.decode(response_ids[:, inputs.shape[-1]:][0], skip_special_tokens=True)
29
-
30
- print(response_text)
31
 
32
  return {"response": response_text}
33
 
 
8
  import torch
9
 
10
  # ======== Cargar el modelo DialoGPT =========
11
+ MODEL_NAME = "gpt2"
12
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
13
  model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
14
 
 
21
  @app.post("/chat")
22
  def chat(msg: Message):
23
  """Genera respuesta basada en el input del usuario."""
24
+ input_text = msg.text # Texto de entrada
25
+ print(f"Mensaje recibido: {input_text}")
26
+
27
+ # Codificar el texto de entrada y agregar el token de fin de secuencia
28
  inputs = tokenizer.encode(input_text + tokenizer.eos_token, return_tensors="pt")
29
+
30
+ # Generar la respuesta
31
+ response_ids = model.generate(inputs,
32
+ max_length=100, # Longitud máxima de la respuesta
33
+ pad_token_id=tokenizer.eos_token_id,
34
+ no_repeat_ngram_size=2, # Evitar repeticiones
35
+ top_p=0.95, # Top-p sampling para mayor diversidad
36
+ top_k=60) # Top-k sampling
37
+
38
+ # Decodificar la respuesta generada
39
  response_text = tokenizer.decode(response_ids[:, inputs.shape[-1]:][0], skip_special_tokens=True)
40
+
41
+ print(f"Respuesta generada: {response_text}")
42
 
43
  return {"response": response_text}
44