# Hugging Face Space app — clinical multi-agent consultation demo.
# (The Space status page reported "Runtime error"; fixes are noted inline below.)
import os
import pathlib  # FIX: hoisted from mid-script; duplicate mid-script `import os` removed

import torch
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr

# --- Config ---
REPO_ID = "CodCodingCode/llama-3.1-8b-clinical"
SUBFOLDER = "checkpoint-45000"  # model shards live in this checkpoint subfolder
HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
if not HF_TOKEN:
    raise RuntimeError("Missing HUGGINGFACE_HUB_TOKEN in env")

# --- 1) Download the full repo snapshot (tokenizer files at the repo root) ---
local_cache = snapshot_download(
    repo_id=REPO_ID,
    token=HF_TOKEN,
)
print("[DEBUG] snapshot_download β local_cache:", local_cache)
print(
    "[DEBUG] MODEL root contents:",
    list(pathlib.Path(local_cache).glob(f"{SUBFOLDER}/*")),
)

# --- 2) Repo root holds tokenizer.json; weights live in the checkpoint subfolder ---
MODEL_DIR = local_cache
MODEL_SUBFOLDER = SUBFOLDER
print("[DEBUG] MODEL_DIR:", MODEL_DIR)
print("[DEBUG] MODEL_DIR files:", os.listdir(MODEL_DIR))
print("[DEBUG] Checkpoint files:", os.listdir(os.path.join(MODEL_DIR, MODEL_SUBFOLDER)))

# --- 3) Load tokenizer & model from disk ---
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_DIR,
    use_fast=True,
)
print("[DEBUG] Loaded fast tokenizer object:", tokenizer, "type:", type(tokenizer))
# Confirm tokenizer files are present.
print("[DEBUG] Files in MODEL_DIR for tokenizer:", os.listdir(MODEL_DIR))
# Inspect the tokenizer's initialization arguments (not all tokenizers expose them).
try:
    print("[DEBUG] Tokenizer init_kwargs:", tokenizer.init_kwargs)
except AttributeError:
    print("[DEBUG] No init_kwargs attribute on tokenizer.")

model = AutoModelForCausalLM.from_pretrained(
    MODEL_DIR,
    subfolder=MODEL_SUBFOLDER,
    device_map="auto",  # spreads layers across available devices
    torch_dtype=torch.float16,
)
model.eval()  # inference only — disable dropout etc.
print(
    "[DEBUG] Loaded model object:",
    model.__class__.__name__,
    "device:",
    next(model.parameters()).device,
)
# === Role Agent with instruction/input/output format ===
class RoleAgent:
    """A single LLM "role" (summarizer, diagnoser, ...) bound to a shared
    tokenizer/model pair.

    Each ``act`` call builds an ``instruction:/input:/output:`` prompt,
    samples a completion, and parses the answer (plus an optional
    THINKING section when the model emits tagged output).
    """

    def __init__(self, role_instruction, tokenizer, model):
        self.tokenizer = tokenizer
        self.model = model
        self.role_instruction = role_instruction

    def act(self, input_text):
        """Generate one response for *input_text*.

        Returns:
            dict with keys ``"thinking"`` (may be empty) and ``"output"``.
        """
        prompt = (
            f"instruction: {self.role_instruction}\n"
            f"input: {input_text}\n"
            f"output:"
        )
        encoding = self.tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to(self.model.device) for k, v in encoding.items()}
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        # FIX: decode only the newly generated tokens. The original decoded
        # the whole sequence (prompt echo included) and then split on a
        # capitalized "Output:"; since the prompt ends with lowercase
        # "output:", the echoed prompt leaked into the answer whenever the
        # model did not re-emit a literal "Output:" marker.
        prompt_len = inputs["input_ids"].shape[1]
        response = self.tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=True
        )

        thinking = ""
        answer = response.strip()
        if "Output:" in response:
            # Split on the last occurrence of 'Output:' in case it's repeated.
            answer = response.rsplit("Output:", 1)[-1].strip()
        else:
            # Fallback: if thinking/answer/end tags exist, use previous logic.
            tags = ("THINKING:", "ANSWER:", "END")
            if all(tag in response for tag in tags):
                print("[FIX] tagged response detected:", response)
                block = response.split("THINKING:", 1)[1].split("END", 1)[0]
                thinking = block.split("ANSWER:", 1)[0].strip()
                answer = block.split("ANSWER:", 1)[1].strip()
        print(
            "[THINKING ANSWER SPLIT] thinking/answer split:",
            response,
            "β",
            "[THINKING] thinking:",
            thinking,
            "[ANSWER] answer:",
            answer,
        )
        return {"thinking": thinking, "output": answer}
# === Agents ===
def _make_agent(instruction):
    """Bind one role instruction to the shared tokenizer/model pair."""
    return RoleAgent(role_instruction=instruction, tokenizer=tokenizer, model=model)


summarizer = _make_agent(
    "You are a clinical summarizer trained to extract structured vignettes from doctorβpatient dialogues."
)
diagnoser = _make_agent(
    "You are a board-certified diagnostician that diagnoses patients."
)
questioner = _make_agent(
    "You are a physician asking questions to diagnose a patient."
)
treatment_agent = _make_agent(
    "You are a board-certified clinician. Based on the diagnosis and patient vignette provided below, suggest a concise treatment plan that could realistically be initiated by a primary care physician or psychiatrist."
)

# Sample transcript kept for reference (module-level no-op string literal).
"""[DEBUG] prompt: Instruction: You are a clinical summarizer trained to extract structured vignettes from doctorβpatient dialogues.
Input: Doctor: What brings you in today?
Patient: I am a male. I am 15. My knee hurts. What may be the issue with my knee?
Previous Vignette:
Output:
Instruction: You are a clinical summarizer trained to extract structured vignettes from doctorβpatient dialogues.
Input: Doctor: What brings you in today?
Patient: I am a male. I am 15. My knee hurts. What may be the issue with my knee?
Previous Vignette:
Output: The patient is a 15-year-old male presenting with knee pain."""
# === Inference State ===
# NOTE(review): these globals are never read or updated by the handler
# functions below (each builds its own local summary/diagnosis and threads
# history through gr.State) — presumably leftovers from an earlier design.
conversation_history = []
summary = ""
diagnosis = ""
# === Gradio Inference ===
def simulate_interaction(user_input, conversation_history=None):
    """Run a single consultation turn over the accumulated history.

    Args:
        user_input: The patient's latest message.
        conversation_history: Prior "Doctor:"/"Patient:" lines, or
            None/empty for a brand-new conversation.

    Returns:
        dict with the raw agent results ("summary", "diagnosis",
        "question", "treatment") and the updated "conversation" list.
    """
    # FIX: treat an *empty* history like None. The stateful handlers seed
    # gr.State with [], and the original `is None` check skipped the
    # doctor's opening line on the first turn in that path.
    if not conversation_history:
        history = ["Doctor: What brings you in today?", f"Patient: {user_input}"]
    else:
        history = conversation_history.copy()
        history.append(f"Patient: {user_input}")

    # Summarize the full conversation history into a vignette.
    sum_out = summarizer.act("\n".join(history))
    summary = sum_out["output"]

    # Diagnose based on the summary.
    diag_out = diagnoser.act(summary)
    diagnosis = diag_out["output"]

    # Generate the next question from the current understanding.
    q_out = questioner.act(
        f"Vignette: {summary}\nCurrent Estimated Diagnosis: {diagnosis}"
    )
    history.append(f"Doctor: {q_out['output']}")

    # Draft a treatment plan (the conversation continues regardless).
    treatment_out = treatment_agent.act(f"Diagnosis: {diagnosis}\nVignette: {summary}")

    return {
        "summary": sum_out,
        "diagnosis": diag_out,
        "question": q_out,
        "treatment": treatment_out,
        "conversation": history,  # full updated history list
    }
# === Gradio UI ===
def ui_fn(user_input):
    """Non-stateful single-shot wrapper: one full turn rendered as text."""
    res = simulate_interaction(user_input)
    transcript = "\n".join(res["conversation"])
    sections = [
        "π Vignette Summary:",
        f"π THINKING: {res['summary']['thinking']}",
        f"π SUMMARY: {res['summary']['output']}",
        "π©Ί Diagnosis:",
        f"π THINKING: {res['diagnosis']['thinking']}",
        f"π DIAGNOSIS: {res['diagnosis']['output']}",
        "β Follow-up Question:",
        f"π THINKING: {res['question']['thinking']}",
        f"π¨ββοΈ DOCTOR: {res['question']['output']}",
        "π Treatment Plan:",
        f"π THINKING: {res['treatment']['thinking']}",
        f"π TREATMENT: {res['treatment']['output']}",
        "π¬ Full Conversation:",
        transcript,
    ]
    # Trailing newline matches the original triple-quoted template.
    return "\n".join(sections) + "\n"
# === Stateful Gradio UI ===
def stateful_ui_fn(user_input, history):
    """Stateful handler: threads the conversation list through gr.State."""
    # Gradio passes None for the State component on the very first call.
    if history is None:
        history = []

    # One consultation turn over the accumulated history.
    res = simulate_interaction(user_input, history)
    updated_history = res["conversation"]
    transcript = "\n".join(updated_history)

    # Same rendered text as the original triple-quoted template.
    display_output = (
        "π¬ Conversation:\n"
        f"{transcript}\n"
        "π Current Assessment:\n"
        f"π Diagnosis: {res['diagnosis']['output']}\n"
        f"π Treatment Plan: {res['treatment']['output']}\n"
    )
    return display_output, updated_history
def chat_interface(user_input, history):
    """Chat-style handler: emit only the doctor's next line plus state."""
    # `history or []` covers both the first-call None and an empty list.
    res = simulate_interaction(user_input, history or [])
    return res["question"]["output"], res["conversation"]
# Create two different interfaces over the same agent pipeline.
_consult_input = gr.Textbox(
    label="Patient Response",
    placeholder="Describe your symptoms or answer the doctor's question...",
)

# Full-consultation view: shows transcript, diagnosis, and treatment plan.
demo_stateful = gr.Interface(
    fn=stateful_ui_fn,
    inputs=[
        _consult_input,
        gr.State(),  # holds the conversation history between turns
    ],
    outputs=[
        gr.Textbox(label="Medical Consultation", lines=15),
        gr.State(),  # updated history fed back in on the next turn
    ],
    title="π§ AI Doctor - Full Medical Consultation",
    description="Have a conversation with an AI doctor. Each response builds on the previous conversation.",
)

# Minimal chat view: only the doctor's next question is shown.
demo_chat = gr.Interface(
    fn=chat_interface,
    inputs=[
        gr.Textbox(label="Your Response", placeholder="Tell me about your symptoms..."),
        gr.State(),
    ],
    outputs=[
        gr.Textbox(label="Doctor", lines=5),
        gr.State(),
    ],
    title="π©Ί AI Doctor Chat",
    description="Simple chat interface with the AI doctor.",
)

if __name__ == "__main__":
    # Launch the stateful version by default.
    demo_stateful.launch(share=True)
    # Uncomment the line below to use the chat version instead:
    # demo_chat.launch(share=True)