import streamlit as st
import time
import numpy as np
import torch
import os
from dotenv import load_dotenv

# Redirect config and cache directories to /tmp, which is writable on
# HuggingFace Spaces (the container's home directory may be read-only)
os.environ.setdefault("STREAMLIT_CONFIG_DIR", "/tmp/.streamlit")
os.environ.setdefault("STREAMLIT_DATA_DIR", "/tmp/.streamlit")
os.environ.setdefault("NLTK_DATA", "/tmp/nltk_data")
os.environ.setdefault("HF_HOME", "/tmp/huggingface")
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/huggingface/transformers")
os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", "/tmp/huggingface/sentence_transformers")
os.environ.setdefault("HF_HUB_CACHE", "/tmp/huggingface/hub")

# Create the cache directories up front so first writes don't fail
for _dir in (
    os.environ["STREAMLIT_CONFIG_DIR"],
    os.environ["STREAMLIT_DATA_DIR"],
    os.environ["NLTK_DATA"],
    os.environ["HF_HOME"],
    os.environ["TRANSFORMERS_CACHE"],
    os.environ["SENTENCE_TRANSFORMERS_HOME"],
    os.environ["HF_HUB_CACHE"],
):
    os.makedirs(_dir, exist_ok=True)

# st.set_page_config must be the first Streamlit command in the script, so call
# it here, before any st.error/st.warning below can run
st.set_page_config(page_title="Finance QA Assistant", layout="centered")

try:
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel
except ImportError as e:
    st.error(f"Error importing transformers/peft: {e}")
    st.error("Please ensure the transformers and peft libraries are properly installed.")
    st.stop()
from search_final import rag_pipeline

# Load environment variables from a local .env file (e.g. HF_API_KEY)
load_dotenv()

@st.cache_resource
def load_fine_tuned_model():
    """Load the fine-tuned TinyLlama adapter from the Hugging Face Hub (cached by st.cache_resource)."""
    try:
        # Get HuggingFace token from environment
        hf_token = os.getenv("HF_API_KEY")
        if not hf_token:
            st.error("HuggingFace API token not found. Please set HF_API_KEY in your environment.")
            return None, None
        
        # Fine-tuned adapter repository on the Hugging Face Hub
        model_name = "kundan621/tinyllama-makemytrip-financial-qa"
        
        # Load tokenizer with authentication
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, 
            token=hf_token,
            trust_remote_code=True
        )
        
        # Load the base model on CPU in full precision (fp32)
        base_model = AutoModelForCausalLM.from_pretrained(
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            torch_dtype=torch.float32,
            device_map="cpu",
            trust_remote_code=True,
        )
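        # Note: passing device_map to from_pretrained requires the `accelerate` package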
        
        # Load the fine-tuned PEFT model with authentication
        model = PeftModel.from_pretrained(
            base_model, 
            model_name,
            token=hf_token
        )
        
        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading fine-tuned model: {e}")
        st.info("Make sure your model repository is public or you have the correct access permissions.")
        return None, None
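
# Optional: merging the LoRA adapter into the base weights can speed up CPU
# inference. A minimal sketch (assuming the adapter was trained with PEFT):
#     merged_model = model.merge_and_unload()
# This trades adapter hot-swapping for a single fused model.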

def generate_fine_tuned_response(model, tokenizer, question):
    """Generate response using the fine-tuned model"""
    system_prompt = "You are a helpful assistant that provides financial data from MakeMyTrip reports."
    
    # Create the message list for the chat template
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question},
    ]
    
    # Apply the chat template to format the input
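    # add_generation_prompt=True appends the assistant marker so the model
    # generates the assistant's turn next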
    input_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    
    # Tokenize the formatted input
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    
    # Generate response
    with torch.no_grad():
        outputs = model.generate(
            **inputs, 
            max_new_tokens=100,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )
    
    # Decode the entire generated output
    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # Extract only the generated answer: TinyLlama's chat template marks the
    # assistant turn with '<|assistant|>', so take the text after its last occurrence
    try:
        answer_start_token = '<|assistant|>'
        answer_start_index = decoded_output.rfind(answer_start_token)
        
        if answer_start_index != -1:
            generated_answer = decoded_output[answer_start_index + len(answer_start_token):].strip()
            if generated_answer.endswith('</s>'):
                generated_answer = generated_answer[:-len('</s>')].strip()
        else:
            generated_answer = "Could not extract answer from model output."
    except Exception as e:
        generated_answer = f"An error occurred: {e}"
    
    return generated_answer
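
# Note: instead of splitting on the '<|assistant|>' marker, the newly generated
# tokens can be sliced off by prompt length, which avoids depending on the chat
# template's marker strings:
#     new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
#     answer = tokenizer.decode(new_tokens, skip_special_tokens=True)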

# --- UI Layout ---
# (st.set_page_config is called at the top of the script, as Streamlit requires)
st.title("Finance QA Assistant")

# Add information about the modes
with st.expander("ℹ️ About the Modes"):
    st.markdown("""
    **RAG Mode**: Uses Retrieval-Augmented Generation with a vector database and external LLM API.
    
    **Fine-Tuned Mode**: Uses a custom fine-tuned TinyLlama model (requires authentication).
    
    *Note: Fine-tuned mode requires a HuggingFace API token and access to the private model repository.*
    """)

# Placeholders for the fine-tuned model; it is loaded lazily once Fine-Tuned mode is selected
fine_tuned_model, fine_tuned_tokenizer = None, None

# Check if HuggingFace token is available
hf_token = os.getenv("HF_API_KEY")
if hf_token:
    available_modes = ["RAG", "Fine-Tuned"]
else:
    available_modes = ["RAG"]
    st.warning("⚠️ Fine-Tuned mode is not available. HuggingFace API token is required for accessing private models.")

mode = st.radio("Choose Answering Mode:", available_modes, horizontal=True)

if mode == "Fine-Tuned":
    if fine_tuned_model is None or fine_tuned_tokenizer is None:
        with st.spinner("Loading fine-tuned model..."):
            fine_tuned_model, fine_tuned_tokenizer = load_fine_tuned_model()
            
        # If model loading failed, fall back to RAG
        if fine_tuned_model is None or fine_tuned_tokenizer is None:
            st.error("Failed to load fine-tuned model. Falling back to RAG mode.")
            mode = "RAG"

query = st.text_input("Enter your question:")

if st.button("Get Answer") and query:
    start_time = time.time()
    docs = None
    confidence = None
    answer = ""
    method = ""
    
    if mode == "RAG":
        answer, docs = rag_pipeline(query)
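        # Placeholder confidence (randomly sampled for display; not a real model score)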
        confidence = np.random.uniform(0.7, 0.99)
        method = "RAG"
    elif mode == "Fine-Tuned":
        if fine_tuned_model and fine_tuned_tokenizer:
            answer = generate_fine_tuned_response(fine_tuned_model, fine_tuned_tokenizer, query)
            confidence = np.random.uniform(0.8, 0.95)  # Placeholder confidence (randomly sampled)
            method = "Fine-Tuned TinyLlama"
        else:
            answer = "Fine-tuned model failed to load. Please check the model repository."
            confidence = 0.0
            method = "Error"
    
    response_time = time.time() - start_time

    st.markdown(f"**Answer:** {answer}")
    if confidence is not None:
        st.markdown(f"**Confidence Score:** {confidence:.2f}")
    st.markdown(f"**Method Used:** {method}")
    st.markdown(f"**Response Time:** {response_time:.2f} seconds")

    if mode == "RAG" and docs:
        st.markdown("---")
        st.markdown("**Supporting Documents:**")
        for doc in docs:
            st.markdown(f"- {doc['content'][:120]}...")
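
# To run locally (assuming this file is saved as app.py):
#     streamlit run app.py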