# FAQ Chatbot -- Hugging Face Space serving a Gradio UI (port 7860)
# and a FastAPI JSON API (port 8000) side by side.
# --------------- Imports ---------------
import pickle
import threading

import gradio as gr
import nest_asyncio
import torch
import uvicorn
from fastapi import FastAPI, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from sentence_transformers import SentenceTransformer, util

# --------------- Load model & embeddings ---------------
# NOTE(security): pickle.load can execute arbitrary code -- chatbot.pkl
# must come from a trusted build step, never from user upload.
with open("chatbot.pkl", "rb") as f:
    data = pickle.load(f)

# Parallel lists: questions[i] answers to answers[i]; embeddings are
# precomputed vectors for questions (same order) -- assumed, confirm
# against the script that wrote chatbot.pkl.
questions = data["questions"]
answers = data["answers"]
question_embeddings = data["embeddings"]

# Encoder used to embed incoming user questions at query time.
model = SentenceTransformer('all-MiniLM-L6-v2')
def chat(user_question, threshold=0.4, top_k=1):
    """Return the best-matching FAQ answer for *user_question*.

    Embeds the question, scores it against the precomputed FAQ
    embeddings with cosine similarity, and returns the first of the
    top-k candidates whose score clears *threshold*.

    Args:
        user_question: Free-text question from the user.
        threshold: Minimum cosine similarity (0..1) to accept a match.
        top_k: Number of top-scoring candidates to inspect.

    Returns:
        dict with keys "matched_question", "answer", "score"; when no
        candidate clears the threshold, a fallback dict with None for
        "matched_question" and "score".
    """
    user_embedding = model.encode(user_question, convert_to_tensor=True)
    cos_scores = util.cos_sim(user_embedding, question_embeddings)[0]
    # Clamp k: torch.topk raises RuntimeError if k exceeds the number
    # of FAQ entries.
    k = min(top_k, cos_scores.shape[0])
    top_results = torch.topk(cos_scores, k=k)
    for score, idx in zip(top_results.values, top_results.indices):
        if score.item() >= threshold:
            i = int(idx)  # plain int for indexing the Python lists
            return {
                "matched_question": questions[i],
                "answer": answers[i],
                "score": score.item(),
            }
    return {
        "matched_question": None,
        "answer": "Sorry, I am not able to answer that.",
        "score": None,
    }
# --------------- FastAPI API ---------------
api = FastAPI()

# Wide-open CORS so the API can be called from any web page embedding
# the chatbot.
api.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
@api.get("/chat")
def get_chat(question: str = Query(..., description="Your question here")):
    """GET /chat?question=... -> FAQ match as JSON.

    Bug fix: the original function was never registered as a route
    (no @api.get decorator), so the API served by uvicorn returned
    404 for every path.
    """
    response = chat(question)
    return JSONResponse(response)
# --------------- Gradio UI ---------------
def _answer_only(user_question):
    """Adapter for Gradio: return only the answer text from chat()."""
    return chat(user_question)["answer"]


iface = gr.Interface(
    fn=_answer_only,
    inputs=gr.Textbox(lines=2, placeholder="Ask a question..."),
    outputs=gr.Textbox(),
    title="FAQ Chatbot",
    description="Ask any question and get answers from the FAQ.",
)
def run_gradio():
    """Launch the Gradio UI (blocking) on all interfaces, port 7860."""
    iface.launch(server_name="0.0.0.0", server_port=7860)
# --------------- Run both FastAPI + Gradio ---------------
if __name__ == "__main__":
    # Patch any already-running event loop so uvicorn can start inside
    # environments that have one (notebooks, some hosting runtimes).
    nest_asyncio.apply()
    # Gradio serves from a daemon thread; uvicorn blocks the main thread.
    threading.Thread(target=run_gradio, daemon=True).start()
    # FastAPI JSON API on port 8000.
    uvicorn.run(api, host="0.0.0.0", port=8000)