| import torch |
| from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoModelForTokenClassification, AutoTokenizer, pipeline |
| import gradio as gr |
|
|
| |
# --- GPT-2 text generator -------------------------------------------------
MODEL_NAME = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME)


# --- BioBERT token-classification pipeline for entity extraction ----------
# NOTE(review): this is the *base* BioBERT checkpoint, not one fine-tuned
# for NER, so the token-classification head is presumably randomly
# initialized — verify entity quality before relying on its output.
BIOBERT_MODEL = "dmis-lab/biobert-base-cased-v1.1"
biobert_tokenizer = AutoTokenizer.from_pretrained(BIOBERT_MODEL)
biobert_model = AutoModelForTokenClassification.from_pretrained(BIOBERT_MODEL)
nlp_pipeline = pipeline("ner", model=biobert_model, tokenizer=biobert_tokenizer)


# Running chat context shared across calls to generate_text:
# alternating entries [prompt, response, prompt, response, ...].
conversation_history = []
|
|
| |
def generate_text(prompt, max_length=100, temperature=0.7, top_k=50, top_p=0.9, clean_output=False, stopwords="", num_responses=1, extract_entities=False):
    """Generate one or more GPT-2 continuations for a medical prompt.

    The prompt is prefixed with the accumulated conversation history, so
    each call sees prior turns as context. The first generated response and
    the prompt are appended to that history as a side effect.

    Args:
        prompt: User query text.
        max_length: Total sequence length cap passed to ``model.generate``
            (includes the prompt tokens, not just new ones).
        temperature: Sampling temperature.
        top_k: Top-k sampling cutoff.
        top_p: Nucleus sampling cutoff.
        clean_output: Collapse newlines in the decoded text.
        stopwords: Comma-separated substrings to strip from each response.
        num_responses: How many independent samples to generate.
        extract_entities: Append BioBERT NER results to each response.

    Returns:
        All responses joined by blank lines.
    """
    global conversation_history

    # Gradio sliders deliver floats; range() and the generate() integer
    # parameters require ints, so coerce defensively.
    max_length = int(max_length)
    top_k = int(top_k)
    num_responses = int(num_responses)

    # Prefix the new prompt with every prior turn to keep context.
    full_prompt = "\n".join(conversation_history + [prompt])
    inputs = tokenizer(full_prompt, return_tensors="pt")

    # Parse the stopword list once, outside the sampling loop, and drop
    # empty entries so "" or "a,,b" never triggers text.replace("", "").
    stop_list = [w.strip() for w in stopwords.split(",") if w.strip()]

    responses = []
    for _ in range(num_responses):
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_length=max_length,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                do_sample=True,
            )
        text = tokenizer.decode(output[0], skip_special_tokens=True)

        if clean_output:
            text = text.replace("\n", " ").strip()

        for word in stop_list:
            text = text.replace(word, "")

        if extract_entities:
            # NOTE(review): nlp_pipeline wraps base BioBERT, which is not
            # NER fine-tuned — entity labels may be meaningless; confirm.
            entities = nlp_pipeline(text)
            extracted_entities = {entity["word"] for entity in entities}
            text += "\n\nExtracted Medical Entities:\n" + "\n".join(extracted_entities)

        responses.append(text)

    # Record this turn so subsequent calls include it as context.
    conversation_history.append(prompt)
    conversation_history.append(responses[0])

    return "\n\n".join(responses)
|
|
| |
# Gradio UI. Each slider carries value= so the widget starts at the same
# default as the matching generate_text keyword argument; without value=,
# Gradio initializes every slider at its minimum (e.g. Temperature 0.1,
# Top-K 0), silently diverging from the function's documented defaults.
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Medical Query"),
        gr.Slider(50, 500, value=100, step=10, label="Max Length"),
        gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0, 100, value=50, step=5, label="Top-K"),
        gr.Slider(0.0, 1.0, value=0.9, step=0.1, label="Top-P"),
        gr.Checkbox(label="Clean Output"),
        gr.Textbox(label="Stopwords (comma-separated)"),
        gr.Slider(1, 5, value=1, step=1, label="Number of Responses"),
        gr.Checkbox(label="Extract Medical Entities"),
    ],
    outputs=gr.Textbox(label="Generated Medical Text"),
    title="Medical AI Assistant",
    description="Enter a medical-related prompt and adjust parameters to generate AI-assisted text. Supports entity recognition for medical terms.",
)


# Start the local Gradio server (blocking call).
demo.launch()
|
|