Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModel | |
| from openai import OpenAI | |
| import os | |
| import numpy as np | |
| from sklearn.metrics.pairwise import cosine_similarity | |
# Bi-encoder (embedding) checkpoint on the Hugging Face Hub. Its tokenizer and
# model are loaded once at import time and used by encode_text().
bi_encoder_model_name = "nasa-impact/nasa-smd-ibm-st-v2"
bi_tokenizer, bi_model = (
    AutoTokenizer.from_pretrained(bi_encoder_model_name),
    AutoModel.from_pretrained(bi_encoder_model_name),
)

# OpenAI client used for answer generation; the key is read from the
# OPENAI_API_KEY environment variable (None when it is not set).
api_key = os.environ.get("OPENAI_API_KEY")
client = OpenAI(api_key=api_key)
def encode_text(text):
    """Embed *text* with the bi-encoder using attention-mask-aware mean pooling.

    Args:
        text: A string, or a list of strings to embed as one padded batch.

    Returns:
        A NumPy array with one row per input text, i.e. shape
        ``(1, hidden_size)`` for a single string.
    """
    import torch  # local import: only needed for the inference guard

    inputs = bi_tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=128)
    # Inference only: skip building the autograd graph (saves memory and time).
    with torch.no_grad():
        outputs = bi_model(**inputs)
        hidden = outputs.last_hidden_state  # (batch, seq_len, hidden_size)
        mask = inputs.get('attention_mask')
        if mask is None:
            # Tokenizer produced no mask: fall back to a plain mean over tokens.
            return hidden.mean(dim=1).cpu().numpy()
        # Average only over real tokens so padding added for batched input does
        # not dilute the embedding. For a single string (no padding) this is
        # the same as the plain mean over the sequence axis.
        mask = mask.unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return (summed / counts).cpu().numpy()
def retrieve_relevant_context(user_input, context_texts):
    """Return the entry of *context_texts* most similar to *user_input*.

    Similarity is the cosine similarity between embeddings from
    ``encode_text``. Empty and whitespace-only entries are ignored.

    Args:
        user_input: The user's question.
        context_texts: Sequence of candidate context strings.

    Returns:
        The best-matching candidate string, or ``""`` when there is no
        non-blank candidate.
    """
    candidates = [text for text in context_texts if text and text.strip()]
    if not candidates:
        # Nothing to rank; np.argmax would raise on an empty array.
        return ""
    user_embedding = encode_text(user_input)  # shape (1, dim)
    # encode_text returns a (1, dim) array per text. np.array() over that list
    # would build an (n, 1, dim) 3-D array, which cosine_similarity rejects;
    # vstack gives the required (n, dim) matrix.
    context_embeddings = np.vstack([encode_text(text) for text in candidates])
    similarities = cosine_similarity(user_embedding, context_embeddings).flatten()
    return candidates[int(np.argmax(similarities))]
def generate_response(user_input, relevant_context, *, model="gpt-4"):
    """Ask the chat model to answer *user_input*, grounded on *relevant_context*.

    Args:
        user_input: The user's question.
        relevant_context: Context text placed before the question in the prompt
            (may be ``""``).
        model: Chat-completions model name; defaults to ``"gpt-4"``.

    Returns:
        The model's reply with surrounding whitespace stripped, or ``""`` when
        the response carries no text content.
    """
    combined_input = f"Context: {relevant_context}\nQuestion: {user_input}\nAnswer:"
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "user", "content": combined_input}
        ],
        max_tokens=150,
        temperature=0.7,
        top_p=0.9,
        frequency_penalty=0.5,
        presence_penalty=0.0,
    )
    # openai>=1.0 returns typed objects: the message is not subscriptable, so
    # `content` must be read as an attribute. The SDK types it as Optional[str].
    content = response.choices[0].message.content
    return (content or "").strip()
def chatbot(user_input, context=""):
    """Gradio handler: answer *user_input*, optionally grounded on *context*.

    Args:
        user_input: The user's message.
        context: Optional free text holding one candidate context passage per
            line. Blank lines are ignored; ``None`` is treated like ``""``.

    Returns:
        The generated reply text.
    """
    # splitlines() also strips "\r" from CRLF input. Blank lines (e.g. from a
    # trailing newline) are dropped so they are never embedded or selected.
    context_texts = [line for line in (context or "").splitlines() if line.strip()]
    relevant_context = (
        retrieve_relevant_context(user_input, context_texts) if context_texts else ""
    )
    return generate_response(user_input, relevant_context)
# Gradio UI: a short message box and an optional multi-line context box, passed
# positionally to chatbot(); the reply is shown as plain text.
_message_input = gr.Textbox(lines=2, placeholder="Enter your message here...")
_context_input = gr.Textbox(lines=5, placeholder="Enter context here, separated by new lines...")

iface = gr.Interface(
    fn=chatbot,
    inputs=[_message_input, _context_input],
    outputs="text",
    title="Context-Aware Dynamic Response Chatbot",
    description=(
        "A chatbot using a NASA-specific bi-encoder model to understand the "
        "input context and GPT-4 to generate dynamic responses. Enter context "
        "to get more refined and relevant responses."
    ),
)

# Start the Gradio web server.
iface.launch()