Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -51,6 +51,8 @@ def generate_sql(prompt: str):
|
|
| 51 |
return_tensors="pt"
|
| 52 |
)
|
| 53 |
|
|
|
|
|
|
|
| 54 |
with torch.inference_mode():
|
| 55 |
outputs = model.generate(
|
| 56 |
input_ids=inputs,
|
|
@@ -61,13 +63,10 @@ def generate_sql(prompt: str):
|
|
| 61 |
pad_token_id=tokenizer.eos_token_id,
|
| 62 |
)
|
| 63 |
|
| 64 |
-
|
|
|
|
| 65 |
|
| 66 |
-
|
| 67 |
-
if "<|assistant|>" in response:
|
| 68 |
-
response = response.split("<|assistant|>", 1)[-1].strip()
|
| 69 |
-
if "<|end|>" in response:
|
| 70 |
-
response = response.split("<|end|>")[0].strip()
|
| 71 |
|
| 72 |
return response
|
| 73 |
|
|
|
|
| 51 |
return_tensors="pt"
|
| 52 |
)
|
| 53 |
|
| 54 |
+
input_length = inputs.shape[-1] # length of prompt tokens
|
| 55 |
+
|
| 56 |
with torch.inference_mode():
|
| 57 |
outputs = model.generate(
|
| 58 |
input_ids=inputs,
|
|
|
|
| 63 |
pad_token_id=tokenizer.eos_token_id,
|
| 64 |
)
|
| 65 |
|
| 66 |
+
# 🔑 Slice off the prompt tokens so that only the newly generated tokens are decoded
|
| 67 |
+
generated_tokens = outputs[0][input_length:]
|
| 68 |
|
| 69 |
+
response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
return response
|
| 72 |
|