Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer | |
# Load the fine-tuned model and tokenizer.
# Streamlit re-executes this whole script on every widget interaction, so the
# loader is wrapped in st.cache_resource: the model/tokenizer are loaded once
# and the same objects are reused on later reruns instead of being reloaded.
@st.cache_resource
def load_model(model_directory="Vaishu16/QC_fine_tuned_model"):
    """Load the question-completion causal LM and its tokenizer.

    Args:
        model_directory: Hugging Face Hub repo id or local path passed to
            ``from_pretrained``. Defaults to the fine-tuned QC model.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoModelForCausalLM.from_pretrained(model_directory)
    tokenizer = AutoTokenizer.from_pretrained(model_directory)
    return model, tokenizer
# Obtain the (model, tokenizer) pair and wrap it in a text-generation
# pipeline; device=-1 runs inference on the CPU.
model, tokenizer = load_model()
question_completion_pipeline = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    device=-1,
)
# Starter prompts rendered as one-click buttons in the sidebar; each is an
# unfinished question for the model to complete.
suggested_questions = [
    "What is the impact of",
    "How does the company report",
    "What are the financial risks of",
    "Explain the corporate governance of",
    "What was the revenue growth of",
]
# Sidebar controls that feed the generation call: output length limit,
# how many completions to return, and whether to sample.
sidebar = st.sidebar
sidebar.header("π§ Settings")
max_length = sidebar.slider("Max Length", min_value=30, max_value=100, value=60, step=5)
num_return_sequences = sidebar.slider("Number of Completions", min_value=1, max_value=5, value=3)
sampling = sidebar.checkbox("Enable Sampling", value=True)
# Page heading and a short usage blurb.
st.title("π€ Question Completion Model")
st.write(
    "Enter a partial question related to **financial statements, corporate governance, or company reports**, and the model will intelligently complete it!"
)

# Seed the session-state slot that backs the input box only on the first run,
# so later reruns keep whatever value has been stored there.
if "partial_question" not in st.session_state:
    st.session_state["partial_question"] = ""
# Callback for the suggestion buttons: store the chosen prompt in session state.
# Streamlit runs on_click callbacks before it re-executes the script, so the
# text input further down (whose default is st.session_state.partial_question)
# already sees the new value on that run. An explicit st.rerun() here would only
# force a second, redundant full script execution (and inside a callback it is
# a no-op anyway).
def update_input(selected_question):
    """Make *selected_question* the current partial question."""
    st.session_state.partial_question = selected_question


# Display suggested partial questions in the sidebar; clicking one fills the input.
st.sidebar.subheader("π‘ Suggested Partial Questions")
for i, question in enumerate(suggested_questions):
    st.sidebar.button(
        question,
        key=f"suggested_{i}",
        on_click=update_input,
        args=(question,),
    )
# NOTE(review): the emoji prefixes in the UI strings below look mis-encoded
# (UTF-8 bytes apparently decoded with a Greek code page); confirm the source
# file's encoding before shipping.

# Text input box, pre-filled from session state (e.g. a clicked suggestion).
st.subheader("βοΈ Enter Partial Question")
partial_question = st.text_input("Enter a partial question:", st.session_state.partial_question)

# Generate button
if st.button("π Complete Question"):
    # Use the same stripped text for the emptiness check and for generation;
    # trailing whitespace typically becomes its own token in BPE tokenizers
    # and tends to degrade the continuation.
    prompt = partial_question.strip()
    if prompt:
        generation_kwargs = {
            "max_length": max_length,
            "num_return_sequences": num_return_sequences,
            "do_sample": sampling,
            "truncation": True,
        }
        if not sampling:
            # Greedy decoding can produce only one sequence; transformers raises
            # a ValueError when num_return_sequences > 1 without sampling or
            # beams. Beam search with one beam per requested sequence returns
            # that many distinct completions, and with a single requested
            # sequence (num_beams=1) it is plain greedy decoding again.
            generation_kwargs["num_beams"] = num_return_sequences

        with st.spinner("β³ Generating completed questions... Please wait!"):
            outputs = question_completion_pipeline(prompt, **generation_kwargs)

        st.subheader("β Completed Questions")
        completed_texts = [output["generated_text"] for output in outputs]

        # Display completed questions inside an expander
        with st.expander("π View Generated Questions"):
            for idx, completed_question in enumerate(completed_texts, start=1):
                st.markdown(f"**{idx}. {completed_question}**")

        # Copy option: all completions, one per line, in a single text area.
        st.text_area("π Copy Completed Questions:", "\n".join(completed_texts), height=150)
    else:
        st.warning("β οΈ Please enter a partial question.")