Hugging Face Spaces — commit view (Space status: Sleeping)

Commit: Update src/model_consumer.py
1 file changed: src/model_consumer.py (+25 −3)
|
@@ -7,7 +7,8 @@ model_id = "prd101-wd/phi1_5-bankingqa-merged"
  7   # Load model only once
  8   @st.cache_resource
  9   def load_model():
 10 -     return pipeline("question-answering", model=model_id)
 11
 12   # Create a text generation pipeline
 13   pipe = load_model()
|
@@ -23,10 +24,31 @@ if st.button("Ask"):
 23       if user_input.strip():
 24           # Format the prompt like Alpaca-style
 25           prompt = f"### Instruction:\n{user_input}\n\n### Response:\n"
 26 -         output = pipe(prompt, max_new_tokens=200, do_sample=True)
 27
 28           # Extract only the model's response (remove prompt part if included in output)
 29 -         answer = output.split("### Response:")[-1].strip()
 30           st.markdown("### HelpdeskBot Answer:")
 31           st.success(answer)
 32       else:
|
|
|
|
  7   # Load model only once
  8   @st.cache_resource
  9   def load_model():
 10 +     #return pipeline("question-answering", model=model_id)
 11 +     return pipeline("text-generation", model=model_id, trust_remote_code=True)
 12
 13   # Create a text generation pipeline
 14   pipe = load_model()
|
|
|
|
 24       if user_input.strip():
 25           # Format the prompt like Alpaca-style
 26           prompt = f"### Instruction:\n{user_input}\n\n### Response:\n"
 27 +         output = pipe(prompt, max_new_tokens=200, do_sample=True, temperature=0.7)
 28 +
 29 +         # Process output
 30 +         if isinstance(output, list) and output:
 31 +             answer = output[0]['generated_text']
 32 +             # Extract only the response part
 33 +             if "### Response:" in answer:
 34 +                 answer = answer.split("### Response:")[-1].strip()
 35 +         else:
 36 +             answer = "Unable to generate a response. Please try again."
 37 +
 38 +
 39 +
 40 +         # if isinstance(output, list) and len(output) > 0 and "generated_text" in output[0]:
 41 +         #     answer = output[0]["generated_text"]
 42 +         # else:
 43 +         #     answer = "Unable to generate a response. Please try again."
 44
 45           # Extract only the model's response (remove prompt part if included in output)
 46 +         #answer = output.split("### Response:")[-1].strip()
 47 +         # if isinstance(output, str):
 48 +         #     answer = output.split("### Response:")[-1].strip()
 49 +         # else:
 50 +         #     answer = "Unexpected output format. Please try again."
 51 +
 52           st.markdown("### HelpdeskBot Answer:")
 53           st.success(answer)
 54       else:
|