Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch, re | |
| from pymongo import MongoClient | |
| from datetime import datetime | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline,AutoModelForSeq2SeqLM | |
| from sentence_transformers import SentenceTransformer, util | |
| from langchain_core.prompts import PromptTemplate | |
| from langchain_community.llms import HuggingFacePipeline | |
| from IndicTransToolkit.processor import IndicProcessor | |
| from transformers import BitsAndBytesConfig | |
# === MongoDB ===
import os

# SECURITY: the original code hard-coded a connection string with credentials.
# Prefer the MONGO_URI environment variable; the embedded value is kept only
# as a fallback so existing deployments keep working — rotate this credential.
mongo_uri = os.environ.get(
    "MONGO_URI",
    "mongodb+srv://vipplavai:pravip2025@cluster0.zcsijsa.mongodb.net/",
)
client = MongoClient(mongo_uri)
db = client["msme_schemes_db"]
udyam_coll = db["udyam_profiles"]               # Udyam registration profiles
schemes_chunk_coll = db["schemes_chunks_only"]  # RAG chunks (with embeddings)
schemes_info_coll = db["schemes_embedded"]      # per-scheme field lists
query_logs_coll = db["query_logs"]              # interaction audit log
# === LLM ===
# Instruction-tuned Gemma 2B used for query rephrasing and scheme summaries.
model_id = "google/gemma-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# fp16 only when a GPU is present; device_map="auto" lets accelerate place layers.
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
# Greedy decoding (do_sample=False) keeps the generated search queries deterministic.
generator = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=128, do_sample=False)
llm = HuggingFacePipeline(pipeline=generator)
# Sentence embedder for cosine-similarity retrieval over scheme chunks.
embed_model = SentenceTransformer("BAAI/bge-small-en-v1.5", device="cuda" if torch.cuda.is_available() else "cpu")
# === IndicTrans2 ===
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Pre/post-processor for IndicTrans2 translation (tagging, detokenization).
ip = IndicProcessor(inference=True)
def initialize_translator(ckpt_dir):
    """Load an IndicTrans2 seq2seq checkpoint and its tokenizer onto DEVICE.

    Returns a (tokenizer, model) pair with the model switched to eval mode.
    """
    trans_tokenizer = AutoTokenizer.from_pretrained(ckpt_dir, trust_remote_code=True)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(ckpt_dir, trust_remote_code=True).to(DEVICE)
    seq2seq.eval()
    return trans_tokenizer, seq2seq
def translate_to_telugu(text, tokenizer, model):
    """Translate a single English string to Telugu with IndicTrans2 beam search."""
    preprocessed = ip.preprocess_batch([text], src_lang="eng_Latn", tgt_lang="tel_Telu")
    encoded = tokenizer(preprocessed, return_tensors="pt", padding=True).to(DEVICE)
    with torch.no_grad():
        generated = model.generate(**encoded, max_length=256, num_beams=5)
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    # postprocess_batch restores script/detokenization; single input -> single output.
    return ip.postprocess_batch(decoded, lang="tel_Telu")[0]
translator_tokenizer, translator_model = initialize_translator("ai4bharat/indictrans2-en-indic-1B")
# === Prompt ===
# Template that turns a profile summary into a one-line retrieval query
# (filled in by generate_search_query, sent to the Gemma pipeline).
rephrase_template = PromptTemplate.from_template("""
You're a helpful assistant guiding Indian MSMEs to the best-matching government schemes.
Based on the enterprise profile, generate a clear, short one-line search query with keywords like state, sector, size, gender, and investment.
Only return the query. Avoid comments.
Enterprise Profile:
{profile_summary}
""")
| # === Utilities === | |
def normalize_udyam(uid):
    """Canonicalize a Udyam ID: remove internal spaces, trim edges, uppercase."""
    despaced = uid.replace(" ", "")
    return despaced.strip().upper()
def is_valid_udyam(uid):
    """Return True when *uid* matches the UDYAM-XX-NN-NNNNNN(N) format."""
    pattern = r"^UDYAM-[A-Z]{2}-\d{2}-\d{6,7}$"
    return re.match(pattern, uid) is not None
def get_profile_by_uid(uid):
    """Fetch the Udyam profile document for *uid*; None if invalid or absent."""
    canonical = normalize_udyam(uid)
    if not is_valid_udyam(canonical):
        return None
    # _id excluded so the document can be serialized / summarized directly.
    return udyam_coll.find_one({"Udyam_ID": canonical}, {"_id": 0})
def summarize_profile(profile):
    """Render a one-paragraph English summary of an enterprise profile dict.

    Expects the Udyam-style keys used throughout this app: 'Enterprise Name',
    'State', 'Major Activity', 'Gender', 'Enterprise Type', 'Organisation Type',
    'Employment', 'Investment Cost (In Rs.)' and 'Net Turnover (In Rs.)'.
    The two monetary fields must be numeric (they are comma-formatted).

    FIX: the rupee sign had been mojibake-corrupted ('βΉ' is the misdecoded
    UTF-8 of U+20B9); restored to '₹'.
    """
    return (
        f"The user represents an enterprise named '{profile['Enterprise Name']}', based in {profile['State']}, operating in the {profile['Major Activity']} sector. "
        f"They identify as {profile['Gender']}, run a {profile['Enterprise Type']} sized {profile['Organisation Type'].lower()} organization. The enterprise has "
        f"{profile['Employment']} employees, with an investment of ₹{profile['Investment Cost (In Rs.)']:,} and a turnover of ₹{profile['Net Turnover (In Rs.)']:,}."
    )
def generate_search_query(profile):
    """Ask the LLM for a one-line scheme search query.

    Returns a (query, profile_summary) tuple; the query is the first
    non-empty line of the model's reply.
    """
    profile_summary = summarize_profile(profile)
    rendered_prompt = rephrase_template.format(profile_summary=profile_summary)
    raw_reply = llm.invoke(rendered_prompt)
    first_line = raw_reply.strip().split("\n")[0].strip()
    return first_line, profile_summary
def get_top_matching_schemes(query_text, top_k=5):
    """Rank scheme chunks by cosine similarity to *query_text*.

    Scores every stored chunk embedding, then returns up to *top_k* distinct
    schemes (first/highest-scoring chunk per scheme_id) as dicts with keys
    'score', 'scheme_id' and 'scheme_name'.
    """
    query_vec = embed_model.encode(query_text, convert_to_tensor=True)
    scored = []
    for doc in schemes_chunk_coll.find({"rag_chunks": {"$exists": True}}):
        for chunk in doc["rag_chunks"]:
            vector = chunk.get("embedding")
            if not vector:
                continue  # skip chunks with missing/empty embeddings
            candidate = torch.tensor(vector).to(query_vec.device)
            similarity = util.cos_sim(query_vec, candidate)[0][0].item()
            scored.append({
                "score": similarity,
                "scheme_id": doc.get("scheme_id"),
                "scheme_name": doc.get("scheme_name"),
            })
    # Best-first, keeping only the first occurrence of each scheme.
    scored.sort(key=lambda item: item["score"], reverse=True)
    top_results, seen_ids = [], set()
    for match in scored:
        if match["scheme_id"] in seen_ids:
            continue
        seen_ids.add(match["scheme_id"])
        top_results.append(match)
        if len(top_results) == top_k:
            break
    return top_results
def fetch_scheme_field_llm(scheme_id, field_input):
    """Summarize one section of a scheme document with the LLM.

    Maps a keyword found in *field_input* (eligibility/benefits/assistance/
    apply/documents) to the corresponding list field of the scheme document,
    then asks the LLM to summarize up to the first five entries.
    """
    field_map = {
        "eligibility": "eligibility_list",
        "benefits": "key_benefits_list",
        "assistance": "assistance_list",
        "apply": "how_to_apply_list",
        "documents": "required_documents_list",
    }
    lowered = field_input.lower()
    matched_field = None
    for keyword, field_name in field_map.items():
        if keyword in lowered:
            matched_field = field_name
            break
    if matched_field is None:
        return "β Try asking about eligibility, benefits, how to apply, or documents."
    doc = schemes_info_coll.find_one({"scheme_id": scheme_id})
    if not doc or matched_field not in doc:
        return "β οΈ Couldnβt find that information for the selected scheme."
    raw_text = "\n".join(doc[matched_field][:5])
    prompt = f"""
Summarize the following information into a clear and professional explanation for business owners:
Scheme: {doc['scheme_name']}
Section: {matched_field.replace('_list','').title()}
{raw_text}
"""
    return llm.invoke(prompt).strip()
# === Chat State ===
# Module-level conversation state (single user per process): 'stage' drives
# the chatbot() state machine (0 greet, 1 ID entry, 2 manual form, 3 profile
# ready, 4 scheme Q&A); 'profile'/'summary' hold enterprise data, 'scheme_id'
# the selected match, 'last_bot_msg' the latest reply (read by
# translate_last_response).
chat_state = {"stage": 0, "profile": {}, "scheme_id": None, "last_bot_msg": "", "summary": ""}
def chatbot(msg, history):
    """Stage-driven chat handler for the Gradio ChatInterface.

    Stages (tracked in module-level chat_state):
      0 -> greet and ask for a Udyam ID (or 'manual')
      1 -> resolve the Udyam ID, or switch to manual entry
      2 -> collect profile fields one question at a time
      3 -> profile ready; wait for 'show related schemes'
      4 -> answer field questions about the selected scheme

    Args:
        msg: the user's latest message.
        history: Gradio chat history (unused; state lives in chat_state).
    Returns:
        The bot's reply string (also cached in chat_state['last_bot_msg']).
    """
    if chat_state["stage"] == 0:
        chat_state["stage"] = 1
        chat_state["last_bot_msg"] = "π Hello! Please enter your Udyam ID or say 'manual' to fill in details yourself."
        return chat_state["last_bot_msg"]
    if chat_state["stage"] == 1:
        if msg.lower().startswith("udyam-"):
            profile = get_profile_by_uid(msg)
            if profile:
                chat_state["profile"] = profile
                chat_state["stage"] = 3
                summary = summarize_profile(profile)
                chat_state["summary"] = summary
                chat_state["last_bot_msg"] = f"β Profile found! Generating recommendations...\n\nπ Based on your profile: {summary}\n\nType 'show related schemes' to view top matches."
                return chat_state["last_bot_msg"]
            chat_state["last_bot_msg"] = "β Invalid or unregistered Udyam ID. Try again or say 'manual'."
            return chat_state["last_bot_msg"]
        elif "manual" in msg.lower():
            chat_state["stage"] = 2
            chat_state["last_bot_msg"] = "π Great! What's your enterprise name?"
            return chat_state["last_bot_msg"]
        chat_state["last_bot_msg"] = "Please enter a valid Udyam ID or type 'manual'."
        return chat_state["last_bot_msg"]
    if chat_state["stage"] == 2:
        steps = [
            "Enterprise Name", "Gender", "Enterprise Type", "Organisation Type",
            "Major Activity", "State", "Investment Cost (In Rs.)", "Net Turnover (In Rs.)", "Employment"
        ]
        curr_index = len(chat_state["profile"])
        key = steps[curr_index]
        # BUGFIX: int(msg) used to raise ValueError on non-numeric input and
        # crash the handler; now re-prompts for the same field instead.
        if "Cost" in key or "Turnover" in key or "Employment" in key:
            try:
                chat_state["profile"][key] = int(msg)
            except ValueError:
                chat_state["last_bot_msg"] = f"Please enter a whole number for {key}."
                return chat_state["last_bot_msg"]
        else:
            chat_state["profile"][key] = msg
        if len(chat_state["profile"]) == len(steps):
            chat_state["stage"] = 3
            summary = summarize_profile(chat_state["profile"])
            chat_state["summary"] = summary
            chat_state["last_bot_msg"] = f"β Thanks! Profile completed.\n\nπ Based on your profile: {summary}\n\nType 'show related schemes' to view top matches."
            return chat_state["last_bot_msg"]
        prompt = f"{steps[curr_index + 1]}?"
        chat_state["last_bot_msg"] = prompt
        return prompt
    if chat_state["stage"] == 3:
        if "show" in msg.lower() and "scheme" in msg.lower():
            query, summary = generate_search_query(chat_state["profile"])
            top_schemes = get_top_matching_schemes(query)
            if not top_schemes:
                chat_state["last_bot_msg"] = "β οΈ No matching schemes found."
                return chat_state["last_bot_msg"]
            chat_state["scheme_id"] = top_schemes[0]["scheme_id"]
            chat_state["stage"] = 4
            schemes_text = "\n".join([f"{i+1}. {s['scheme_name']} (Score: {round(s['score'],4)})" for i, s in enumerate(top_schemes)])
            chat_state["last_bot_msg"] = f"π Recommended Schemes:\n{schemes_text}\n\nYou can now ask about eligibility, apply, documents, etc."
            # Audit log of the recommendation. NOTE(review): utcnow() returns a
            # naive datetime and is deprecated — consider datetime.now(timezone.utc).
            query_logs_coll.insert_one({
                "timestamp": datetime.utcnow(),
                "udyam_id": chat_state["profile"].get("Udyam_ID", "manual_entry"),
                "profile_summary": summary,
                "query": query,
                "top_schemes": top_schemes,
                "selected_scheme": top_schemes[0]["scheme_name"]
            })
            return chat_state["last_bot_msg"]
        chat_state["last_bot_msg"] = "Type 'show related schemes' to proceed."
        return chat_state["last_bot_msg"]
    if chat_state["stage"] == 4:
        response = fetch_scheme_field_llm(chat_state["scheme_id"], msg)
        chat_state["last_bot_msg"] = response
        return response
    # BUGFIX: previously fell off the end (returned None) for an unexpected stage.
    return chat_state["last_bot_msg"]
def translate_last_response():
    """Translate the most recent bot reply to Telugu (button callback)."""
    last = chat_state["last_bot_msg"]
    if not last:
        return "β οΈ No message to translate."
    translated = translate_to_telugu(last, translator_tokenizer, translator_model)
    return "π Telugu Translation:\n\n" + translated
# === UI ===
# Gradio front-end: a chat panel plus a one-click button that translates the
# bot's most recent reply to Telugu into a separate textbox.
with gr.Blocks(title="MSME Chatbot with Telugu Support") as demo:
    chatbot_ui = gr.ChatInterface(fn=chatbot, title="π€ MSME Scheme Assistant", textbox=gr.Textbox(placeholder="Type your message here..."))
    translate_btn = gr.Button("π Translate Last Response to Telugu")
    translation_output = gr.Textbox(label="π£οΈ Telugu Translation", lines=5)
    # No inputs wired: translate_last_response reads chat_state directly.
    translate_btn.click(fn=translate_last_response, outputs=translation_output)
demo.launch()