# MSME_Chat_bot / app.py
# Author: Vipplav
# Update app.py — commit e710c18 (verified)
import os
import re
from datetime import datetime

import gradio as gr
import torch
from pymongo import MongoClient
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)
from sentence_transformers import SentenceTransformer, util
from langchain_core.prompts import PromptTemplate
from langchain_community.llms import HuggingFacePipeline
from IndicTransToolkit.processor import IndicProcessor
# === MongoDB ===
# SECURITY NOTE: the connection string previously embedded a username/password
# directly in source. Prefer the MONGO_URI environment variable; the hard-coded
# value is kept only as a fallback so existing deployments keep working —
# rotate these credentials and remove the fallback as soon as possible.
mongo_uri = os.environ.get(
    "MONGO_URI",
    "mongodb+srv://vipplavai:pravip2025@cluster0.zcsijsa.mongodb.net/",
)
client = MongoClient(mongo_uri)
db = client["msme_schemes_db"]
udyam_coll = db["udyam_profiles"]               # registered Udyam enterprise profiles
schemes_chunk_coll = db["schemes_chunks_only"]  # RAG chunks with precomputed embeddings
schemes_info_coll = db["schemes_embedded"]      # structured per-scheme field lists
query_logs_coll = db["query_logs"]              # audit log of user queries
# === LLM ===
# Instruction-tuned Gemma 2B used both for query rephrasing and for
# summarizing scheme fields; fp16 on GPU, fp32 on CPU.
model_id = "google/gemma-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
# Greedy decoding (do_sample=False) keeps rephrased queries deterministic.
generator = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=128, do_sample=False)
# LangChain wrapper so the pipeline can be called via llm.invoke(prompt).
llm = HuggingFacePipeline(pipeline=generator)
# Sentence embedder used to score scheme chunks against the search query.
embed_model = SentenceTransformer("BAAI/bge-small-en-v1.5", device="cuda" if torch.cuda.is_available() else "cpu")
# === IndicTrans2 ===
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Pre/post-processor shared by all IndicTrans2 translation calls.
ip = IndicProcessor(inference=True)
def initialize_translator(ckpt_dir):
    """Load an IndicTrans2 seq2seq checkpoint and return (tokenizer, model).

    The model is moved to the module-level DEVICE and put in eval mode so it
    is ready for inference-only generate() calls.
    """
    trans_tokenizer = AutoTokenizer.from_pretrained(ckpt_dir, trust_remote_code=True)
    trans_model = AutoModelForSeq2SeqLM.from_pretrained(ckpt_dir, trust_remote_code=True)
    trans_model = trans_model.to(DEVICE)
    trans_model.eval()
    return trans_tokenizer, trans_model
def translate_to_telugu(text, tokenizer, model):
    """Translate a single English string to Telugu via IndicTrans2.

    Runs the shared IndicProcessor pre/post-processing around a beam-search
    generate() call and returns the first (only) translated string.
    """
    preprocessed = ip.preprocess_batch([text], src_lang="eng_Latn", tgt_lang="tel_Telu")
    encoded = tokenizer(preprocessed, return_tensors="pt", padding=True).to(DEVICE)
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        generated = model.generate(**encoded, max_length=256, num_beams=5)
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return ip.postprocess_batch(decoded, lang="tel_Telu")[0]
# English -> Indic 1B checkpoint, loaded once at startup (slow on first run).
translator_tokenizer, translator_model = initialize_translator("ai4bharat/indictrans2-en-indic-1B")
# === Prompt ===
# Template used to turn a profile summary into a one-line scheme search query.
rephrase_template = PromptTemplate.from_template("""
You're a helpful assistant guiding Indian MSMEs to the best-matching government schemes.
Based on the enterprise profile, generate a clear, short one-line search query with keywords like state, sector, size, gender, and investment.
Only return the query. Avoid comments.
Enterprise Profile:
{profile_summary}
""")
# === Utilities ===
def normalize_udyam(uid):
    """Canonicalize a Udyam ID: trim, uppercase, and drop every space."""
    cleaned = uid.strip().upper()
    return cleaned.replace(" ", "")
def is_valid_udyam(uid): return bool(re.match(r"^UDYAM-[A-Z]{2}-\d{2}-\d{6,7}$", uid))
def get_profile_by_uid(uid):
    """Fetch a Udyam profile document (without _id), or None if invalid/absent."""
    canonical = normalize_udyam(uid)
    if not is_valid_udyam(canonical):
        return None
    return udyam_coll.find_one({"Udyam_ID": canonical}, {"_id": 0})
def summarize_profile(profile):
    """Render a one-paragraph English summary of an enterprise profile dict.

    Expects the Udyam / manual-entry keys ('Enterprise Name', 'State', ...);
    the two monetary fields and 'Employment' must be numeric so they can be
    comma-formatted.
    """
    identity = (
        f"The user represents an enterprise named '{profile['Enterprise Name']}', "
        f"based in {profile['State']}, operating in the {profile['Major Activity']} sector. "
    )
    structure = (
        f"They identify as {profile['Gender']}, run a {profile['Enterprise Type']} sized "
        f"{profile['Organisation Type'].lower()} organization. The enterprise has "
    )
    financials = (
        f"{profile['Employment']} employees, with an investment of "
        f"₹{profile['Investment Cost (In Rs.)']:,} and a turnover of "
        f"₹{profile['Net Turnover (In Rs.)']:,}."
    )
    return identity + structure + financials
def generate_search_query(profile):
    """Ask the LLM to compress a profile summary into a one-line search query.

    Returns a (query, summary) tuple; the query is the first line of the
    model's response, stripped of surrounding whitespace.
    """
    profile_summary = summarize_profile(profile)
    rendered_prompt = rephrase_template.format(profile_summary=profile_summary)
    raw_response = llm.invoke(rendered_prompt)
    first_line = raw_response.strip().split("\n")[0].strip()
    return first_line, profile_summary
def get_top_matching_schemes(query_text, top_k=5):
    """Rank scheme chunks by cosine similarity against *query_text*.

    Scans every document carrying precomputed 'rag_chunks' embeddings, scores
    each chunk against the query embedding, and returns up to *top_k* results
    deduplicated by scheme_id (the best-scoring chunk per scheme wins).
    """
    query_vec = embed_model.encode(query_text, convert_to_tensor=True)
    scored = []
    for scheme_doc in schemes_chunk_coll.find({"rag_chunks": {"$exists": True}}):
        for chunk in scheme_doc["rag_chunks"]:
            if "embedding" not in chunk or not chunk["embedding"]:
                continue  # skip chunks with missing/empty vectors
            chunk_vec = torch.tensor(chunk["embedding"]).to(query_vec.device)
            similarity = util.cos_sim(query_vec, chunk_vec)[0][0].item()
            scored.append({
                "score": similarity,
                "scheme_id": scheme_doc.get("scheme_id"),
                "scheme_name": scheme_doc.get("scheme_name"),
            })
    # Walk matches best-first, keeping only the first hit per scheme.
    top_results, seen_ids = [], set()
    for match in sorted(scored, key=lambda m: m["score"], reverse=True):
        if match["scheme_id"] in seen_ids:
            continue
        top_results.append(match)
        seen_ids.add(match["scheme_id"])
        if len(top_results) == top_k:
            break
    return top_results
def fetch_scheme_field_llm(scheme_id, field_input):
    """Answer a follow-up question about one field of a scheme via the LLM.

    Maps free-text keywords in *field_input* (e.g. "eligibility", "apply") to
    the corresponding list field stored on the scheme document, then has the
    LLM summarize up to the first five entries. Returns a user-facing string
    on every path, including the two error cases.
    """
    # Keyword -> field name in the schemes_embedded collection.
    field_map = {
        "eligibility": "eligibility_list",
        "benefits": "key_benefits_list",
        "assistance": "assistance_list",
        "apply": "how_to_apply_list",
        "documents": "required_documents_list"
    }
    # First keyword found anywhere in the user's message wins.
    matched_field = next((v for k, v in field_map.items() if k in field_input.lower()), None)
    if not matched_field:
        return "❌ Try asking about eligibility, benefits, how to apply, or documents."
    doc = schemes_info_coll.find_one({"scheme_id": scheme_id})
    if doc and matched_field in doc:
        # Cap at five entries to keep the prompt (and latency) small.
        raw_text = "\n".join(doc[matched_field][:5])
        prompt = f"""
Summarize the following information into a clear and professional explanation for business owners:
Scheme: {doc['scheme_name']}
Section: {matched_field.replace('_list','').title()}
{raw_text}
"""
        return llm.invoke(prompt).strip()
    return "⚠️ Couldn’t find that information for the selected scheme."
# === Chat State ===
# Single module-level conversation state — NOTE(review): shared by every
# visitor to the Space, so concurrent users would trample each other's state.
# Stages (see chatbot()): 0 = greet, 1 = awaiting Udyam ID / 'manual',
# 2 = manual profile entry, 3 = profile ready, 4 = scheme Q&A.
chat_state = {"stage": 0, "profile": {}, "scheme_id": None, "last_bot_msg": "", "summary": ""}
def chatbot(msg, history):
    """Gradio ChatInterface callback implementing a staged conversation.

    Stages (tracked in the module-global chat_state):
      0 -> greet and ask for a Udyam ID (the triggering message is ignored),
      1 -> validate the Udyam ID, or switch to manual entry,
      2 -> collect manual profile fields one answer at a time,
      3 -> recommend schemes when the user asks to see them,
      4 -> answer field-level questions about the top scheme.

    Always returns the bot's reply string; the last reply is also cached in
    chat_state["last_bot_msg"] for the Telugu translation button.
    """
    if chat_state["stage"] == 0:
        chat_state["stage"] = 1
        chat_state["last_bot_msg"] = "👋 Hello! Please enter your Udyam ID or say 'manual' to fill in details yourself."
        return chat_state["last_bot_msg"]
    if chat_state["stage"] == 1:
        if msg.lower().startswith("udyam-"):
            profile = get_profile_by_uid(msg)
            if profile:
                chat_state["profile"] = profile
                chat_state["stage"] = 3
                summary = summarize_profile(profile)
                chat_state["summary"] = summary
                chat_state["last_bot_msg"] = f"✅ Profile found! Generating recommendations...\n\n🔍 Based on your profile: {summary}\n\nType 'show related schemes' to view top matches."
                return chat_state["last_bot_msg"]
            chat_state["last_bot_msg"] = "❌ Invalid or unregistered Udyam ID. Try again or say 'manual'."
            return chat_state["last_bot_msg"]
        elif "manual" in msg.lower():
            chat_state["stage"] = 2
            chat_state["last_bot_msg"] = "📝 Great! What's your enterprise name?"
            return chat_state["last_bot_msg"]
        chat_state["last_bot_msg"] = "Please enter a valid Udyam ID or type 'manual'."
        return chat_state["last_bot_msg"]
    if chat_state["stage"] == 2:
        # Questions are asked in this fixed order; progress is implied by how
        # many answers have been stored so far.
        steps = [
            "Enterprise Name", "Gender", "Enterprise Type", "Organisation Type",
            "Major Activity", "State", "Investment Cost (In Rs.)", "Net Turnover (In Rs.)", "Employment"
        ]
        curr_index = len(chat_state["profile"])
        key = steps[curr_index]
        # BUGFIX: numeric fields previously used a bare int(msg), so any
        # non-numeric answer raised ValueError and crashed the chat. Now we
        # re-prompt for the same field instead.
        if "Cost" in key or "Turnover" in key or "Employment" in key:
            try:
                chat_state["profile"][key] = int(msg)
            except ValueError:
                chat_state["last_bot_msg"] = f"⚠️ Please enter a whole number for {key}."
                return chat_state["last_bot_msg"]
        else:
            chat_state["profile"][key] = msg
        if len(chat_state["profile"]) == len(steps):
            chat_state["stage"] = 3
            summary = summarize_profile(chat_state["profile"])
            chat_state["summary"] = summary
            chat_state["last_bot_msg"] = f"✅ Thanks! Profile completed.\n\n🔍 Based on your profile: {summary}\n\nType 'show related schemes' to view top matches."
            return chat_state["last_bot_msg"]
        prompt = f"{steps[curr_index + 1]}?"
        chat_state["last_bot_msg"] = prompt
        return prompt
    if chat_state["stage"] == 3:
        if "show" in msg.lower() and "scheme" in msg.lower():
            query, summary = generate_search_query(chat_state["profile"])
            top_schemes = get_top_matching_schemes(query)
            if not top_schemes:
                chat_state["last_bot_msg"] = "⚠️ No matching schemes found."
                return chat_state["last_bot_msg"]
            # Follow-up questions (stage 4) target the top-ranked scheme only.
            chat_state["scheme_id"] = top_schemes[0]["scheme_id"]
            chat_state["stage"] = 4
            schemes_text = "\n".join([f"{i+1}. {s['scheme_name']} (Score: {round(s['score'],4)})" for i, s in enumerate(top_schemes)])
            chat_state["last_bot_msg"] = f"📈 Recommended Schemes:\n{schemes_text}\n\nYou can now ask about eligibility, apply, documents, etc."
            # Audit log of what was recommended and why.
            query_logs_coll.insert_one({
                "timestamp": datetime.utcnow(),
                "udyam_id": chat_state["profile"].get("Udyam_ID", "manual_entry"),
                "profile_summary": summary,
                "query": query,
                "top_schemes": top_schemes,
                "selected_scheme": top_schemes[0]["scheme_name"]
            })
            return chat_state["last_bot_msg"]
        chat_state["last_bot_msg"] = "Type 'show related schemes' to proceed."
        return chat_state["last_bot_msg"]
    if chat_state["stage"] == 4:
        response = fetch_scheme_field_llm(chat_state["scheme_id"], msg)
        chat_state["last_bot_msg"] = response
        return response
def translate_last_response():
    """Translate the most recent bot reply to Telugu, or warn if there is none."""
    last_reply = chat_state["last_bot_msg"]
    if not last_reply:
        return "⚠️ No message to translate."
    telugu_text = translate_to_telugu(last_reply, translator_tokenizer, translator_model)
    return "📄 Telugu Translation:\n\n" + telugu_text
# === UI ===
with gr.Blocks(title="MSME Chatbot with Telugu Support") as demo:
    # Main chat surface; chatbot() keeps its own stage machine in chat_state.
    chatbot_ui = gr.ChatInterface(fn=chatbot, title="🤖 MSME Scheme Assistant", textbox=gr.Textbox(placeholder="Type your message here..."))
    # One-click Telugu translation of the bot's latest cached reply.
    translate_btn = gr.Button("🌐 Translate Last Response to Telugu")
    translation_output = gr.Textbox(label="🗣️ Telugu Translation", lines=5)
    translate_btn.click(fn=translate_last_response, outputs=translation_output)
# Launched unconditionally — standard for a Hugging Face Spaces app.py.
demo.launch()