|
|
|
|
|
import torch |
|
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer |
|
|
|
|
|
|
|
|
# Checkpoint name handed to both `from_pretrained` calls below.
model_name = "shanover/medbot_godel_v3"

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the tokenizer and the seq2seq model for the checkpoint, then place the
# model on the selected device so inputs moved to `device` match it.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
|
|
|
|
|
def generate_response(symptoms, max_length=512, **generate_kwargs):
    """Generate a medical response for a text description of symptoms.

    Args:
        symptoms: Text that is tokenized and fed to the model as its input.
        max_length: Maximum number of *input* tokens; longer inputs are
            truncated by the tokenizer. This value is not passed to
            ``model.generate`` and therefore does not bound the length of
            the generated output.
        **generate_kwargs: Optional keyword arguments forwarded unchanged to
            ``model.generate`` (for example ``max_new_tokens`` or
            ``num_beams``). When omitted, ``model.generate`` is called with
            the input ids only, exactly as before.

    Returns:
        The decoded text of the first generated sequence, with special
        tokens stripped.
    """
    input_ids = tokenizer.encode(
        symptoms,
        return_tensors="pt",
        max_length=max_length,
        truncation=True,
    )
    # The input tensor must live on the same device as the model.
    input_ids = input_ids.to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output_ids = model.generate(input_ids, **generate_kwargs)

    # Only one sequence was encoded, so decode the first output row.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
|
|