Spaces:
Running
Running
| """ | |
| File: app.py | |
| Usage: Hugging Face Spaces Deployment | |
| Description: Domain-Specific Assistant via LLMs Fine-Tuning. | |
| """ | |
import re

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hugging Face Hub repository containing the fine-tuned model
MODEL_REPO = "degide/tinyllama-medical-assistant"


def _load_tokenizer_and_model(repo_id):
    """Fetch the tokenizer and causal-LM weights for *repo_id* from the Hub.

    The model is placed on CPU in float32 and put into eval mode before it is
    returned, so it is ready for inference.

    Returns:
        A ``(tokenizer, model)`` pair.
    """
    print("Downloading and loading the fine-tuned medical chatbot...")
    tok = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    # CPU-only, full-precision load; low_cpu_mem_usage asks transformers to
    # keep peak RAM down while the weights are being materialised.
    lm = AutoModelForCausalLM.from_pretrained(
        repo_id,
        trust_remote_code=True,
        device_map="cpu",
        low_cpu_mem_usage=True,
        torch_dtype=torch.float32,
    )
    lm.eval()  # inference only: switches layers such as dropout to eval behaviour
    print("Model loaded successfully!")
    return tok, lm


# Module-level handles used by the chat handler below.
tokenizer, model = _load_tokenizer_and_model(MODEL_REPO)
# Stems that mark a query as medical. Matched as plain substrings so that
# inflected forms are accepted as well ("symptoms", "medication",
# "chemotherapy", "diagnosed"); wrongly accepting a query is cheaper here
# than refusing a genuine health question. Every original keyword is still
# covered ('medic' subsumes 'medicine', 'diagnos' subsumes 'diagnosis').
_MEDICAL_KEYWORDS = (
    'symptom', 'disease', 'treatment', 'medic', 'doctor', 'health',
    'diabetes', 'blood', 'pressure', 'heart', 'pain', 'sick', 'hospital',
    'care', 'diagnos', 'patient', 'clinic', 'drug', 'therapy', 'cancer',
    'syndrome', 'infection', 'virus', 'bacteria', 'pill', 'dosage', 'dose',
    'prognosis', 'side effect', 'mechanism of action',
)

# Off-topic terms, matched as whole words (allowing common inflections) so
# that medical vocabulary which merely contains them -- e.g. "codeine"
# ('code') or "gamete" ('game') -- is not rejected.
_NON_MEDICAL_PATTERN = re.compile(
    r"\b(?:cook\w*|recipes?|weather|capital\w*|pythons?|codes?|movies?"
    r"|songs?|games?|sports?|programming|math\w*)\b"
)


def detect_ood(query):
    """Heuristic-based Out-Of-Domain (OOD) detection.

    A query counts as in domain only when it mentions at least one medical
    keyword and none of the off-topic terms. Anything else -- including an
    empty or missing query -- is treated as out of domain.

    Args:
        query: The raw user message (``None`` is treated as empty).

    Returns:
        bool: True when the query should be refused as out of domain.
    """
    query_lower = (query or "").lower()
    has_medical = any(kw in query_lower for kw in _MEDICAL_KEYWORDS)
    is_non_medical = _NON_MEDICAL_PATTERN.search(query_lower) is not None
    # An off-topic term vetoes the query even if medical words are present
    # (e.g. "write python code for a hospital app").
    return is_non_medical or not has_medical
def generate_medical_response(message, history):
    """Generates the chatbot response with OOD handling.

    Args:
        message: The latest user message from the Gradio chat box.
        history: Earlier turns supplied by ``gr.ChatInterface``. Currently
            unused: each question is answered without conversational context.

    Returns:
        Markdown text: either an out-of-domain refusal, or the model's answer
        followed by a medical disclaimer.
    """
    if detect_ood(message):
        return (
            "**Out of Domain Detected:** I apologize, but I am a specialized medical "
            "assistant and can only answer health-related questions. Could you please "
            "ask me about medical symptoms, conditions, or treatments?\n\n"
            "*Examples:*\n"
            "- What are the symptoms of asthma?\n"
            "- How is high blood pressure diagnosed?"
        )

    # Zephyr/TinyLlama-style role tags. The trailing <|assistant|> tag cues
    # the model to start its answer. The prompt must not begin with "</s>"
    # (the end-of-sequence marker in Llama-style vocabularies).
    # NOTE(review): Llama tokenizers normally prepend BOS themselves — confirm
    # add_bos_token for this repo's tokenizer.
    prompt = (
        "<|system|>You are a highly accurate and helpful medical assistant.</s>"
        f"<|user|>{message}</s>"
        "<|assistant|>"
    )
    inputs = tokenizer(prompt, return_tensors="pt")
    prompt_length = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            temperature=0.3,
            repetition_penalty=1.0,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # generate() returns the prompt tokens followed by the continuation, so
    # decode only the new tokens. Otherwise the prompt would be echoed back
    # whenever the role tags survive skip_special_tokens as plain text.
    generated_text = tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    )
    # Model output is intentionally not printed: Space logs would otherwise
    # collect users' health questions and answers.

    # Keep only the first assistant turn: cut at any user/system turn the
    # model starts inventing, then drop a re-emitted <|assistant|> tag.
    for stop_tag in ("<|user|>", "<|system|>"):
        generated_text = generated_text.split(stop_tag, 1)[0]
    final_answer = generated_text.split("<|assistant|>")[-1].strip()

    if not final_answer:
        final_answer = (
            "I apologize, but I am unable to generate a confident medical "
            "response to that exact phrasing. Could you please rephrase your "
            "question?"
        )

    disclaimer = (
        "\n\n---\n"
        "**Medical Disclaimer:** *This chatbot provides general health information "
        "only based on fine-tuned data. It is not a replacement for professional "
        "medical advice. Always consult a qualified healthcare provider.*"
    )
    return final_answer + disclaimer
# --- USER INTERFACE ---
# Sample questions offered as one-click prompts in the chat UI.
_EXAMPLE_QUESTIONS = [
    "What are the common symptoms of type 2 diabetes?",
    "Explain the mechanism of action of metformin.",
    "What is the prognosis for patients with stage 3 chronic kidney disease?",
    "Describe the side effects of chemotherapy for breast cancer.",
]

# Subtitle text rendered under the title.
_DESCRIPTION = (
    "An LLM fine-tuned via LoRA on the Medical Meadow Flashcards dataset. "
    "Ask questions about medical symptoms, conditions, and treatments."
)

# Chat window component, 600 px tall.
_CHAT_WINDOW = gr.Chatbot(height=600)

# Chat app that routes every user turn to generate_medical_response.
demo = gr.ChatInterface(
    fn=generate_medical_response,
    chatbot=_CHAT_WINDOW,
    title="Domain-Specific Medical Assistant (TinyLlama)",
    description=_DESCRIPTION,
    examples=_EXAMPLE_QUESTIONS,
    submit_btn="Ask",
    stop_btn="Stop",
    save_history=True,
    fill_height=True,
    fill_width=True,
)
if __name__ == "__main__":
    # Serve the app on port 7860 without creating a public share link.
    launch_options = {"share": False, "server_port": 7860}
    demo.launch(**launch_options)