Update app.py
Browse files
app.py
CHANGED
|
@@ -72,8 +72,19 @@ def generate_response(text: str, model_gemma, tokenizer_gemma, device) -> str: #
|
|
| 72 |
)
|
| 73 |
|
| 74 |
generated_output = model_gemma.generate(input, generation_config=generation_config)
|
| 75 |
-
|
| 76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
#input_text = "Reapond to the users prompt: " + text
|
| 78 |
#input = tokenizer_gemma(input_text, return_tensors="pt").to(device)
|
| 79 |
#generated_output = model_gemma.generate(**input, max_length=MAX_GEMMA_LENGTH, early_stopping=True)
|
|
|
|
| 72 |
)
|
| 73 |
|
| 74 |
generated_output = model_gemma.generate(input, generation_config=generation_config)
|
| 75 |
+
decoded_output = tokenizer_gemma.decode(generated_output[0], skip_special_tokens=False)
|
| 76 |
+
|
| 77 |
+
# Extract the assistant's response (Gemma specific)
|
| 78 |
+
start_token = "<start_of_turn>model"
|
| 79 |
+
end_token = "<end_of_turn>"
|
| 80 |
+
|
| 81 |
+
start_index = decoded_output.find(start_token)
|
| 82 |
+
if start_index != -1:
|
| 83 |
+
start_index += len(start_token)
|
| 84 |
+
end_index = decoded_output.find(end_token, start_index)
|
| 85 |
+
assistant_response = decoded_output[start_index:].strip()
|
| 86 |
+
return assistant_response
|
| 87 |
+
return decoded_output
|
| 88 |
#input_text = "Reapond to the users prompt: " + text
|
| 89 |
#input = tokenizer_gemma(input_text, return_tensors="pt").to(device)
|
| 90 |
#generated_output = model_gemma.generate(**input, max_length=MAX_GEMMA_LENGTH, early_stopping=True)
|