# Rookus FAQ chatbot: matches user questions against canned answers using
# BERT [CLS] embeddings, logs unmatched questions to an Excel backlog, and
# serves the bot through a Gradio UI and a FastAPI endpoint.
# Standard library
import os

# Third-party
import numpy as np
import pandas as pd
import torch
import gradio as gr
import uvicorn
from fastapi import FastAPI, Request
from pydantic import BaseModel
from transformers import BertModel, BertTokenizer, GPT2LMHeadModel, GPT2Tokenizer

# Web application instance (routes are registered further below).
app = FastAPI()

# BERT encoder used to embed FAQ questions and user queries for
# cosine-similarity matching.
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')
def get_bert_embeddings(texts):
    """Encode a list of strings into BERT [CLS] embeddings.

    Returns a ``(len(texts), hidden_size)`` numpy array taken from the
    final hidden layer at the [CLS] position of each input.
    """
    encoded = bert_tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
    # Inference only — no gradients needed.
    with torch.no_grad():
        model_output = bert_model(**encoded)
    cls_vectors = model_output.last_hidden_state[:, 0, :]
    return cls_vectors.numpy()
def get_closest_question(user_query, questions, threshold=0.95):
    """Find the canned question most similar to *user_query*.

    Embeds every candidate question plus the query with BERT and ranks
    candidates by cosine similarity to the query embedding.

    Args:
        user_query: Free-form text from the user.
        questions: List of candidate question strings.
        threshold: Minimum cosine similarity required to accept a match.

    Returns:
        ``(matched_question, similarity)`` when the best candidate clears
        the threshold, otherwise ``(None, best_similarity)``. With an
        empty candidate list, returns ``(None, 0.0)`` instead of raising
        (the original crashed on ``np.max`` of an empty array).
    """
    if not questions:
        # Guard: np.max/np.argmax on an empty similarity vector raises.
        return None, 0.0
    embeddings = get_bert_embeddings(questions + [user_query])
    query_vec = embeddings[-1]
    candidate_vecs = embeddings[:-1]
    cosine_similarities = np.dot(query_vec, candidate_vecs.T) / (
        np.linalg.norm(query_vec) * np.linalg.norm(candidate_vecs, axis=1)
    )
    # Single argmax gives both the winner and its score.
    best_index = int(np.argmax(cosine_similarities))
    best_similarity = float(cosine_similarities[best_index])
    if best_similarity >= threshold:
        return questions[best_index], best_similarity
    return None, best_similarity
def generate_gpt2_response(prompt, model, tokenizer, max_length=100):
    """Generate a single GPT-2 continuation of *prompt* as plain text.

    Args:
        prompt: Seed text to continue.
        model: A causal language model exposing ``generate``.
        tokenizer: Matching tokenizer exposing ``encode``/``decode``.
        max_length: Upper bound on the generated sequence length.
    """
    encoded_prompt = tokenizer.encode(prompt, return_tensors='pt')
    sequences = model.generate(encoded_prompt, max_length=max_length, num_return_sequences=1)
    first_sequence = sequences[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)
# Canned FAQ content: each entry pairs a question with its answer.
_FAQ_PAIRS = [
    ("What is Rookus?",
     "Rookus is a startup that leverages AI to create unique designs for various products such as clothes, posters, and different arts and crafts."),
    ("How does Rookus use AI in its designs?",
     "Rookus uses advanced AI algorithms to generate innovative and aesthetically pleasing designs. These AI models are trained on vast datasets of art and design to produce high-quality mockups."),
    ("What products does Rookus offer?",
     "Rookus offers a variety of products, including clothing, posters, and a range of arts and crafts items, all featuring AI-generated designs."),
    ("Can I see samples of Rookus' designs?",
     "Yes, Rookus provides samples of its designs on its website. You can view a gallery of products showcasing the AI-generated artwork."),
    ("How can I join the waitlist for Rookus?",
     "To join the waitlist for Rookus, visit our website and sign up with your email. You'll receive updates on our launch and exclusive early access opportunities."),
    ("How does Rookus ensure the quality of its AI-generated designs?",
     "Rookus ensures the quality of its AI-generated designs through rigorous testing and refinement. Each design goes through multiple review stages to ensure it meets our high standards."),
    ("Is there a custom design option available at Rookus?",
     "Yes, Rookus offers custom design options. You can submit your preferences, and our AI will generate a design tailored to your specifications."),
    ("How long does it take to receive a product from Rookus?",
     "The delivery time for products from Rookus varies based on the product type and location. Typically, it takes 2-4 weeks for production and delivery."),
]

# Parallel question/answer lists plus a fallback reply, in the shape the
# rest of the module expects.
data_dict = {
    "questions": [question for question, _ in _FAQ_PAIRS],
    "answers": [answer for _, answer in _FAQ_PAIRS],
    "default_answer": "I'm sorry, I cannot answer this right now. Your question has been saved, and we will get back to you with a response soon."
}
# GPT-2 generator used by generate_gpt2_response (loaded eagerly at import).
gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
gpt2_model = GPT2LMHeadModel.from_pretrained('gpt2')

# Spreadsheet where unanswered user questions are accumulated; create it
# with the expected single-column layout if it does not exist yet.
excel_file = 'data.xlsx'
if not os.path.isfile(excel_file):
    pd.DataFrame(columns=['question']).to_excel(excel_file, index=False)
def chatbot(user_query, threshold=0.95):
    """Answer *user_query* from the FAQ, or log it and return a fallback.

    Args:
        user_query: Text typed by the user.
        threshold: Minimum BERT cosine similarity for an FAQ match.
            Previously hard-coded in two places; exposed here with the
            same default so existing callers are unaffected.

    Returns:
        The matched FAQ answer, or the default apology after appending
        the unmatched question to the Excel backlog.
    """
    closest_question, _similarity = get_closest_question(
        user_query, data_dict['questions'], threshold=threshold
    )
    if closest_question is not None:
        # get_closest_question only returns a question when its similarity
        # already cleared the threshold, so no second check is needed.
        answer_index = data_dict['questions'].index(closest_question)
        return data_dict['answers'][answer_index]
    # No confident match: persist the question for later human follow-up.
    new_data = pd.DataFrame({'question': [user_query]})
    df = pd.read_excel(excel_file)
    df = pd.concat([df, new_data], ignore_index=True)
    with pd.ExcelWriter(excel_file, engine='openpyxl', mode='w') as writer:
        df.to_excel(writer, index=False)
    return data_dict['default_answer']
# Gradio text-in/text-out UI wrapping the chatbot function.
iface = gr.Interface(fn=chatbot, inputs="text", outputs="text")
| # FastAPI endpoint | |
class Query(BaseModel):
    """Request body for the question-answering endpoint."""

    # The user's free-form question text.
    user_query: str
# NOTE(review): the original function had no route decorator, so FastAPI
# never registered it and the endpoint was unreachable. "/answer" is a
# reasonable path — confirm against any client code.
@app.post("/answer")
async def get_answer(query: Query):
    """FastAPI endpoint: answer a user question via the chatbot."""
    user_query = query.user_query
    return {"answer": chatbot(user_query)}
# Run the app with Uvicorn
if __name__ == "__main__":
    # prevent_thread_lock lets launch() return immediately so the
    # uvicorn.run() call below is actually reached; without it the Gradio
    # call blocks the main thread and the FastAPI server never starts.
    iface.launch(share=True, prevent_thread_lock=True)
    uvicorn.run(app, host="0.0.0.0", port=8000)