Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import spaces | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
# --- Model loading ------------------------------------------------------------
# Hugging Face Hub repository id of the MedScholar checkpoint.
model_name = "yasserrmd/MedScholar-1.5B"

# trust_remote_code=True allows the Hub repository to run its own Python code
# while loading; keep it only as long as that repository is trusted.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# device_map="auto" leaves weight placement to transformers/accelerate, and
# .eval() (which returns the module itself) switches off training-only layers
# such as dropout.
# NOTE(review): float16 is requested unconditionally; if this ever falls back to
# CPU-only hardware, half-precision inference may be slow or unsupported —
# confirm the target hardware.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map="auto",
).eval()
# --- Chat function --------------------------------------------------------------
# Role names used by the hand-rolled prompt format below, and their tag strings.
_ROLE_NAMES = ("system", "user", "assistant")
_ROLE_TAGS = tuple(f"<|{role}|>" for role in _ROLE_NAMES)


def _history_to_messages(history):
    """Normalize Gradio chat history into ``{"role", "content"}`` dicts.

    Accepts both Gradio history formats:
    * "tuples": ``(user_message, assistant_reply)`` pairs (lists also work);
      falsy entries are skipped, as before.
    * "messages": ``{"role": ..., "content": ...}`` dicts; only user/assistant
      turns with non-empty string content are kept (file/media turns are
      dropped because the model is text-only).
    """
    messages = []
    for item in history or []:
        if isinstance(item, dict):
            role, content = item.get("role"), item.get("content")
            if role in ("user", "assistant") and isinstance(content, str) and content:
                messages.append({"role": role, "content": content})
            continue
        user_msg, bot_reply = item
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_reply:
            messages.append({"role": "assistant", "content": bot_reply})
    return messages


def _build_prompt(conversation):
    """Render the conversation in the ``<|role|>\\ncontent\\n`` tag format.

    Ends with an open ``<|assistant|>`` tag so the model continues as the
    assistant. Turns with roles outside ``_ROLE_NAMES`` are ignored.
    """
    # NOTE(review): this tag format is kept from the original app. If the
    # tokenizer ships a chat template, tokenizer.apply_chat_template may match
    # the fine-tuning format better — confirm against the model card.
    parts = [
        f"<|{turn['role']}|>\n{turn['content']}\n"
        for turn in conversation
        if turn["role"] in _ROLE_NAMES
    ]
    parts.append("<|assistant|>\n")
    return "".join(parts)


def _extract_reply(generated_text):
    """Return the assistant reply, cut at the first role tag the model emitted.

    Without per-turn stop handling the model may continue the dialogue
    (e.g. write a ``<|user|>`` turn itself); only the text before that is
    the answer to the current message.
    """
    cut = len(generated_text)
    for tag in _ROLE_TAGS:
        idx = generated_text.find(tag)
        if idx != -1:
            cut = min(cut, idx)
    return generated_text[:cut].strip()


# On ZeroGPU Spaces, spaces.GPU is what attaches a GPU to this call; on other
# hardware the Hugging Face docs describe it as having no effect.
@spaces.GPU
def respond(message, history: list, system_message, max_tokens, temperature, top_p):
    """Generate one assistant reply for ``gr.ChatInterface``.

    Args:
        message: The user's newest message.
        history: Earlier turns, as ``(user, assistant)`` pairs or as
            ``{"role", "content"}`` dicts (both Gradio history formats).
        system_message: Text placed in the leading system turn.
        max_tokens: Maximum number of newly generated tokens (cast to int,
            since slider values may arrive as floats).
        temperature: Sampling temperature; sampling is always enabled.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The complete reply text, exactly once, after generation finishes
        (this is not token-by-token streaming).
    """
    conversation = [{"role": "system", "content": system_message}]
    conversation.extend(_history_to_messages(history))
    conversation.append({"role": "user", "content": message})

    inputs = tokenizer(_build_prompt(conversation), return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]

    # Fall back to EOS for padding when the tokenizer defines no pad token.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    output_ids = model.generate(
        **inputs,
        max_new_tokens=int(max_tokens),
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=pad_id,
    )

    # generate() returns prompt + continuation for a causal LM. Decode only the
    # continuation: decoding everything and splitting on "<|assistant|>" echoes
    # the whole prompt when that tag is a special token (skip_special_tokens
    # removes it), and picks a hallucinated later turn otherwise.
    new_tokens = output_ids[0, prompt_len:]
    yield _extract_reply(tokenizer.decode(new_tokens, skip_special_tokens=True))
# --- Gradio UI ------------------------------------------------------------------
# Extra controls shown under the chat box. ChatInterface passes their values to
# `respond` positionally after (message, history), in exactly this order.
_extra_controls = [
    gr.Textbox(label="System message", value="You are a helpful medical assistant."),
    gr.Slider(label="Max new tokens", minimum=1, maximum=1024, step=1, value=512),
    gr.Slider(label="Temperature", minimum=0.1, maximum=1.5, step=0.1, value=0.7),
    gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, step=0.05, value=0.95),
]

demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=_extra_controls,
    title="🩺 MedScholar-1.5B: Medical Chatbot",
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()