Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from transformers import BertTokenizer, BertForSequenceClassification, GPT2LMHeadModel, GPT2Tokenizer | |
# --- 1. LOAD MODELS FROM HUGGING FACE HUB ---
BERT_REPO = "Sankeerth004/fitbuddy-bert-intent"
GPT2_REPO = "Sankeerth004/fitbuddy-gpt2-chatbot"
print("Loading models from Hugging Face Hub...")
# Run on the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Intent classifier: BERT sequence-classification head whose argmax index is
# mapped to a label by `intent_mapping`.
bert_tokenizer = BertTokenizer.from_pretrained(BERT_REPO)
bert_model = BertForSequenceClassification.from_pretrained(BERT_REPO).to(device)
bert_model.eval()

# Response generator: fine-tuned GPT-2 language model.
gpt2_tokenizer = GPT2Tokenizer.from_pretrained(GPT2_REPO)
gpt2_model = GPT2LMHeadModel.from_pretrained(GPT2_REPO).to(device)
# Inference only: be explicit about eval mode, consistent with the BERT model
# above, instead of relying on the loader's default.
gpt2_model.eval()
# GPT-2 tokenizers typically define no pad token; reuse EOS so tokenizer
# padding and generate()'s pad_token_id have a valid value.
gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token
gpt2_model.config.pad_token_id = gpt2_tokenizer.eos_token_id
print("Models loaded successfully!")
# --- 2. DEFINE HELPER FUNCTIONS ---
# Class index predicted by the BERT classifier -> human-readable intent label.
# The list order defines the index, so keep it aligned with the model's labels.
intent_mapping = dict(enumerate([
    "Body Building",
    "Meal Plan Recommendation",
    "Recommend Meditation or Yoga",
    "Suggest Recovery Exercises",
    "Weight Loss",
]))
def predict_intent(text):
    """Classify ``text`` into a FitBuddy intent with the BERT classifier.

    Returns the ``intent_mapping`` label of the highest-scoring class, or
    ``"Unknown Intent"`` when that class index has no label.
    """
    encoded = bert_tokenizer(
        text,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=512,
    ).to(device)
    # Pure inference: no autograd bookkeeping needed.
    with torch.no_grad():
        logits = bert_model(**encoded).logits
    best_class = logits.argmax(dim=1).item()
    return intent_mapping.get(best_class, "Unknown Intent")
def generate_response(prompt, max_length=200, min_reply_tokens=50):
    """Generate a reply to ``prompt`` with the fine-tuned GPT-2 model.

    Args:
        prompt: Text fed to GPT-2 (the caller builds e.g. ``"[Q] ... [Intent: ...]"``).
        max_length: Total token budget for prompt + reply, as before. If the
            prompt alone uses up (most of) this budget, the reply is still
            allowed ``min_reply_tokens`` new tokens instead of being empty.
        min_reply_tokens: Minimum number of new tokens the model may generate,
            regardless of how long the prompt is.

    Returns:
        The generated reply only (the prompt is never echoed back), with any
        text up to the last ``[A]`` marker removed and whitespace stripped.
    """
    # GPT-2 has a fixed positional window; truncate the prompt so that the
    # reply budget still fits inside it.
    context_window = getattr(gpt2_model.config, "n_positions", 1024)
    prompt_budget = max(1, context_window - min_reply_tokens)
    encoding = gpt2_tokenizer(
        prompt, return_tensors='pt', truncation=True, max_length=prompt_budget
    )
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    prompt_len = input_ids.shape[-1]

    # `max_length` counts prompt tokens too: a long prompt used to leave no room
    # for any reply. Guarantee a minimum, but never exceed the context window.
    max_new_tokens = max(max_length - prompt_len, min_reply_tokens)
    max_new_tokens = min(max_new_tokens, context_window - prompt_len)

    output = gpt2_model.generate(
        input_ids,
        attention_mask=attention_mask,
        do_sample=True,
        max_new_tokens=max_new_tokens,
        pad_token_id=gpt2_tokenizer.eos_token_id,
        top_k=50,
        top_p=0.8,
        temperature=0.7,
    )
    # Decode only the newly generated tokens so the user's prompt is not
    # returned as the answer when the model does not emit an "[A]" marker.
    generated = gpt2_tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
    return generated.split('[A]')[-1].strip()
# --- 3. CREATE THE GRADIO CHATBOT FUNCTION ---
def chat_with_fitbuddy(user_input, history):
    """Chat callback: answer one user message.

    ``history`` is part of the callback signature but is not used; each reply
    is generated from the current message alone. Blank or whitespace-only
    input gets a fixed prompt to ask again.
    """
    question = user_input
    if question.strip():
        intent = predict_intent(question)
        return generate_response(f"[Q] {question} [Intent: {intent}]")
    return "Please enter a valid question."
# --- 4. LAUNCH THE GRADIO INTERFACE ---
# Clickable starter questions shown under the chat box.
example_prompts = [["How do I build muscle?"], ["What is a good diet?"]]

# Every submitted message is routed through chat_with_fitbuddy.
iface = gr.ChatInterface(
    fn=chat_with_fitbuddy,
    examples=example_prompts,
    title="🤖 FitBuddy Chatbot",
    description="Your AI-powered personal fitness assistant.",
)
iface.launch()