kundan621 committed
Commit 442f18c · verified · 1 Parent(s): 20e8b98

Upload streamlit_app.py

Files changed (1)
  1. streamlit_app.py +138 -0
streamlit_app.py ADDED
@@ -0,0 +1,138 @@
+ import streamlit as st
+ import time
+ import numpy as np
+ import torch
+ import os
+ from dotenv import load_dotenv
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from peft import PeftModel
+ from search_final import rag_pipeline
+
+ # Load environment variables
+ load_dotenv()
+
+ @st.cache_resource
+ def load_fine_tuned_model():
+     """Load the fine-tuned model from Hugging Face Hub"""
+     try:
+         # Replace with your actual repository name
+         model_name = "kundan621/tinyllama-makemytrip-financial-qa"
+
+         # Load tokenizer
+         tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+         # Load base model
+         base_model = AutoModelForCausalLM.from_pretrained(
+             "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+             torch_dtype=torch.float32,
+             device_map="cpu",
+             trust_remote_code=True,
+         )
+
+         # Load the fine-tuned PEFT adapter on top of the base model
+         model = PeftModel.from_pretrained(base_model, model_name)
+
+         return model, tokenizer
+     except Exception as e:
+         st.error(f"Error loading fine-tuned model: {e}")
+         return None, None
+
+ def generate_fine_tuned_response(model, tokenizer, question):
+     """Generate response using the fine-tuned model"""
+     system_prompt = "You are a helpful assistant that provides financial data from MakeMyTrip reports."
+
+     # Create the message list for the chat template
+     messages = [
+         {"role": "system", "content": system_prompt},
+         {"role": "user", "content": question},
+     ]
+
+     # Apply the chat template to format the input
+     input_text = tokenizer.apply_chat_template(
+         messages,
+         tokenize=False,
+         add_generation_prompt=True
+     )
+
+     # Tokenize the formatted input
+     inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
+
+     # Generate response
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=100,
+             temperature=0.7,
+             do_sample=True,
+             pad_token_id=tokenizer.eos_token_id
+         )
+
+     # Decode the entire generated output
+     decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+     # Extract only the generated answer part
+     try:
+         answer_start_token = '<|assistant|>'
+         answer_start_index = decoded_output.rfind(answer_start_token)
+
+         if answer_start_index != -1:
+             generated_answer = decoded_output[answer_start_index + len(answer_start_token):].strip()
+             if generated_answer.endswith('</s>'):
+                 generated_answer = generated_answer[:-len('</s>')].strip()
+         else:
+             generated_answer = "Could not extract answer from model output."
+     except Exception as e:
+         generated_answer = f"An error occurred: {e}"
+
+     return generated_answer
+
+ # --- UI Layout ---
+ st.set_page_config(page_title="Finance QA Assistant", layout="centered")
+ st.title("Finance QA Assistant")
+
+ # Placeholders for the fine-tuned model; it is loaded lazily when that mode is selected
+ fine_tuned_model, fine_tuned_tokenizer = None, None
+
+ mode = st.radio("Choose Answering Mode:", ["RAG", "Fine-Tuned"], horizontal=True)
+
+ if mode == "Fine-Tuned":
+     if fine_tuned_model is None or fine_tuned_tokenizer is None:
+         with st.spinner("Loading fine-tuned model..."):
+             fine_tuned_model, fine_tuned_tokenizer = load_fine_tuned_model()
+
+ query = st.text_input("Enter your question:")
+
+ if st.button("Get Answer") and query:
+     start_time = time.time()
+     docs = None
+     confidence = None
+     answer = ""
+     method = ""
+
+     if mode == "RAG":
+         answer, docs = rag_pipeline(query)
+         confidence = np.random.uniform(0.7, 0.99)  # placeholder confidence, randomly sampled for the demo
+         method = "RAG"
+     elif mode == "Fine-Tuned":
+         if fine_tuned_model and fine_tuned_tokenizer:
+             answer = generate_fine_tuned_response(fine_tuned_model, fine_tuned_tokenizer, query)
+             confidence = np.random.uniform(0.8, 0.95)  # placeholder confidence, randomly sampled for the demo
+             method = "Fine-Tuned TinyLlama"
+         else:
+             answer = "Fine-tuned model failed to load. Please check the model repository."
+             confidence = 0.0
+             method = "Error"
+
+     response_time = time.time() - start_time
+
+     st.markdown(f"**Answer:** {answer}")
+     if confidence is not None:
+         st.markdown(f"**Confidence Score:** {confidence:.2f}")
+     st.markdown(f"**Method Used:** {method}")
+     st.markdown(f"**Response Time:** {response_time:.2f} seconds")
+
+     if mode == "RAG" and docs:
+         st.markdown("---")
+         st.markdown("**Supporting Documents:**")
+         for doc in docs:
+             st.markdown(f"- {doc['content'][:120]}...")
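Note: streamlit_app.py imports rag_pipeline from search_final, which is not part of this commit. Based on the call sites above (answer, docs = rag_pipeline(query), and doc['content'] in the rendering loop), the module is expected to expose roughly the interface sketched below. This is a hypothetical stub for local testing, not the actual implementation; the names and return shapes are assumptions inferred from the diff.

# search_final.py — hypothetical stub of the interface streamlit_app.py expects.
# The real module is not in this commit; the return shape below is an
# assumption inferred from how rag_pipeline's results are consumed above.

def rag_pipeline(query):
    """Return (answer, docs): an answer string and a list of dicts with a 'content' key."""
    docs = [
        {"content": f"Placeholder retrieved passage for: {query}"},
    ]
    answer = "Placeholder answer (replace with real retrieval and generation)."
    return answer, docs

With such a stub in place, the app can be launched locally with streamlit run streamlit_app.py.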