saadkhi committed on
Commit
81f0e97
·
verified ·
1 Parent(s): 75da654

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -6
app.py CHANGED
@@ -51,6 +51,8 @@ def generate_sql(prompt: str):
51
  return_tensors="pt"
52
  )
53
 
 
 
54
  with torch.inference_mode():
55
  outputs = model.generate(
56
  input_ids=inputs,
@@ -61,13 +63,10 @@ def generate_sql(prompt: str):
61
  pad_token_id=tokenizer.eos_token_id,
62
  )
63
 
64
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
65
 
66
- # Cleanup
67
- if "<|assistant|>" in response:
68
- response = response.split("<|assistant|>", 1)[-1].strip()
69
- if "<|end|>" in response:
70
- response = response.split("<|end|>")[0].strip()
71
 
72
  return response
73
 
 
51
  return_tensors="pt"
52
  )
53
 
54
+ input_length = inputs.shape[-1] # length of prompt tokens
55
+
56
  with torch.inference_mode():
57
  outputs = model.generate(
58
  input_ids=inputs,
 
63
  pad_token_id=tokenizer.eos_token_id,
64
  )
65
 
66
+ # 🔑 Remove the prompt tokens from the output
67
+ generated_tokens = outputs[0][input_length:]
68
 
69
+ response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
 
 
 
 
70
 
71
  return response
72