File size: 816 Bytes
dccbddd
6f24db8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Load the model and tokenizer
# NOTE(review): from_pretrained downloads the checkpoint on first run (or reads
# the local Hugging Face cache) — this happens at import time, so importing this
# module has network/disk side effects. The name suggests a GODEL-based
# encoder-decoder medical chatbot — TODO confirm against the model card.
model_name = "shanover/medbot_godel_v3"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Set device
# Prefer GPU when available; inputs passed to the model later must be moved
# to this same device (see generate_response).
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

def generate_response(symptoms, max_length=512):
    """Generate a medical response based on symptoms.

    Args:
        symptoms: Free-text symptom description used as the model prompt.
        max_length: Maximum token length for the tokenized input
            (truncation) and for the generated output.

    Returns:
        The decoded generation with special tokens removed.
    """
    # Tokenize via the tokenizer's __call__ so we also get an attention
    # mask; passing the mask to generate() avoids the "attention mask not
    # set" warning and ambiguity when pad and eos tokens coincide.
    encoded = tokenizer(
        symptoms,
        return_tensors="pt",
        max_length=max_length,
        truncation=True,
    )
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    # Inference only — no gradients needed.
    with torch.no_grad():
        # Bug fix: the original called generate(input_ids) with no length
        # argument, so output was capped at the model config's default
        # generation length regardless of max_length. Pass it through.
        output_ids = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
        )

    # Decode the first (and only) sequence in the batch.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)