Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import re | |
| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer | |
| from peft import PeftModel | |
| from threading import Thread | |
# ==========================================
# 1. SETUP & AUTHENTICATION
# ==========================================
# Mistral requires an HF token to download the base model.
# Add your token in the HF Spaces "Settings" -> "Variables and secrets" as 'HF_TOKEN'.
hf_token = os.environ.get("HF_TOKEN")

# Both repo ids can be overridden through environment variables (e.g. Space
# variables) so the app can be reconfigured without a code change; the
# defaults are the original hard-coded values.
_PLACEHOLDER_ADAPTER_REPO = "your-username/Medical-Mistral-7B-LoRA"
BASE_MODEL = os.environ.get("BASE_MODEL", "mistralai/Mistral-7B-Instruct-v0.2")
ADAPTER_REPO = os.environ.get(
    "ADAPTER_REPO", "your-username/Medical-Mistral-7B-LoRA"
)  # <--- CHANGE THIS DEFAULT (or set an 'ADAPTER_REPO' variable) TO YOUR REPO

print("Booting up Clinical AI Server...")

# Fail fast with an actionable message instead of downloading the multi-GB
# base model first and only then crashing on a nonexistent adapter repo.
if ADAPTER_REPO == _PLACEHOLDER_ADAPTER_REPO:
    raise RuntimeError(
        f"ADAPTER_REPO is still the placeholder '{_PLACEHOLDER_ADAPTER_REPO}'. "
        "Set it to your LoRA adapter repo (edit the default above or add an "
        "'ADAPTER_REPO' variable in the Space settings)."
    )
# Not fatal on their own (a cached download or newer bitsandbytes may cope),
# but these are the most likely causes of a failed boot, so say so up front.
if not hf_token:
    print("⚠️ HF_TOKEN is not set; downloading the base model may fail.")
if not torch.cuda.is_available():
    print("⚠️ No CUDA GPU detected; 4-bit bitsandbytes loading may fail on this hardware.")

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, token=hf_token)
# Reuse EOS as the pad token so padding is always defined.
tokenizer.pad_token = tokenizer.eos_token

# Load the base model in 4-bit NF4 precision; matmuls run in float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",
    token=hf_token,
)

# Attach (not merge) the trained LoRA adapter: PeftModel wraps the base model
# with adapter layers while keeping the adapter weights separate.
print("Downloading and attaching LoRA Adapter from Hub...")
model = PeftModel.from_pretrained(base_model, ADAPTER_REPO, token=hf_token)
# Inference mode (disables dropout).
model.eval()
print("✅ Model Ready!")
# ==========================================
# 2. FORMATTING FUNCTION
# ==========================================
# Phrases marking the start of a sign-off the model tends to append despite
# the prompt. The leading \b stops matches inside longer words (e.g.
# "disregards"). Searching the original text case-insensitively avoids the
# index drift that lower()-then-slice can cause when lowercasing changes length.
_END_TRIGGER_RE = re.compile(
    r"\b(?:regards|hope this|hope i|let me know|dr\.|thanks for)",
    re.IGNORECASE,
)

# Section labels rewritten as bold markdown. The leading "\n\n" puts every
# section after the first in its own paragraph. Spaces/tabs after the colon are
# consumed so each label is followed by exactly one space, and \b keeps e.g.
# "reassessment:" from being rewritten.
_SECTION_RULES = (
    (re.compile(r"(?i)\bassessment\s*:[ \t]*"), "**Assessment:** "),
    (re.compile(r"(?i)\banalysis\s*:[ \t]*"), "\n\n**Analysis:** "),
    (re.compile(r"(?i)\brecommended action\s*:[ \t]*"), "\n\n**Recommended Action:** "),
)


def format_and_clean(text):
    """Trim sign-offs from model output and render its sections as markdown.

    Everything from the first sign-off phrase onward ("regards", "hope this",
    "hope i", "let me know", "dr.", "thanks for", case-insensitive, starting at
    a word boundary) is discarded. Asterisks are stripped. The first character
    of every ". "-separated sentence is upper-cased, and the rest of each
    sentence is left untouched. Finally the Assessment / Analysis /
    Recommended Action labels are re-emitted in bold.

    Args:
        text: Raw (possibly partial, mid-stream) generated text.

    Returns:
        The cleaned, markdown-formatted text ("" if nothing precedes a sign-off).
    """
    match = _END_TRIGGER_RE.search(text)
    if match:
        text = text[: match.start()]

    # Drop the model's own emphasis markers; bold labels are re-added below.
    clean = text.strip().replace("*", "")

    # Upper-case only the first character: str.capitalize() would also
    # lower-case the rest of the sentence and mangle acronyms (HIV -> hiv).
    sentences = [s[:1].upper() + s[1:] for s in clean.split(". ") if s]
    clean = ". ".join(sentences)

    for pattern, replacement in _SECTION_RULES:
        clean = pattern.sub(replacement, clean)
    return clean
# ==========================================
# 3. CHAT ENGINE & UI
# ==========================================
def clinical_chat(message, history):
    """Stream a formatted clinical answer to a single user message.

    Gradio ``ChatInterface`` callback. It wraps ``message`` in a Mistral
    ``[INST]`` prompt and runs ``model.generate`` on a background thread. After
    every streamed chunk it yields the accumulated reply, passed through
    ``format_and_clean``, so the UI updates incrementally.

    Args:
        message: The latest user message.
        history: Earlier turns supplied by Gradio. Currently unused, so each
            message is answered independently of the conversation so far.

    Yields:
        The cleaned, markdown-formatted reply generated so far.
    """
    prompt = f"""<s>[INST] You are a highly intelligent clinical AI. The user says: "{message}"
If the user provides symptoms, diagnose them. If the user asks about a known condition, provide treatment advice.
You MUST format your response exactly like this:
**Assessment:** [Name of the condition]
**Analysis:** [Brief explanation]
**Recommended Action:** [Medical advice]
Do not include greetings, sign-offs, or links. [/INST]"""
    # The prompt already begins with the literal "<s>" BOS marker, so stop the
    # tokenizer from prepending a second BOS token. Move the inputs to the
    # model's own device rather than assuming "cuda", because placement was
    # chosen by device_map="auto".
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

    # timeout: iteration raises if no new text arrives within 10 s, so a crash
    # inside generate() on the worker thread cannot hang this generator forever.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=10.0)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=300,
        temperature=0.2,
        repetition_penalty=1.15,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() blocks, so it runs on a worker thread while this generator
    # consumes decoded text from the streamer.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    partial_text = ""
    for new_token in streamer:
        partial_text += new_token
        # Re-format the full text each time: sign-off trimming and section
        # labels depend on everything generated so far.
        yield format_and_clean(partial_text)
    # The streamer is exhausted once generation finishes; reap the worker.
    thread.join()
# ==========================================
# 4. GRADIO UI & LAUNCH
# ==========================================
# Sample queries offered by the chat UI.
EXAMPLE_QUERIES = [
    "I am having severe headache, body pain and strain in my neck.",
    "What medicine should I take if I have high cholesterol?",
    "I have been sneezing and coughing for 3 days.",
]

theme = gr.themes.Soft(neutral_hue="slate", primary_hue="blue")

# Streaming chat front-end: each message is routed to clinical_chat.
demo = gr.ChatInterface(
    fn=clinical_chat,
    theme=theme,
    examples=EXAMPLE_QUERIES,
    title="⚕️ Clinical AI Diagnostic Assistant",
    description="**Enter your symptoms or medical queries below for a professional analysis.**",
)

if __name__ == "__main__":
    # The Docker SDK on Hugging Face Spaces expects the app on port 7860,
    # bound to all interfaces so it is reachable from outside the container.
    demo.launch(server_port=7860, server_name="0.0.0.0")