"""Gradio demo for a Russian medical-chat model (Saiga base + LoRA adapter).

Loads the base causal LM referenced by the PEFT config in 4-bit, applies the
LoRA adapter from ``warleagle/medical_chat_saiga``, and serves a simple web UI
that returns three sampled answers for one user question.
"""
import torch
import gradio as gr
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

from conversation import Conversation

MODEL_NAME = "warleagle/medical_chat_saiga"

# Load the 4-bit base model, then attach the LoRA adapter on top of it.
config = PeftConfig.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    load_in_4bit=True,
    torch_dtype=torch.float16,
    device_map="auto",
)
model = PeftModel.from_pretrained(model, MODEL_NAME, torch_dtype=torch.float16)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
generation_config.max_new_tokens = 70  # cap answer length


def generate(model, tokenizer, prompt, generation_config):
    """Generate one completion for *prompt* and return it as stripped text.

    Args:
        model: The (PEFT-wrapped) causal LM to sample from.
        tokenizer: Tokenizer matching the model.
        prompt: Fully formatted prompt string (special tokens already inline,
            hence ``add_special_tokens=False``).
        generation_config: ``GenerationConfig`` controlling sampling.

    Returns:
        The decoded continuation only (prompt tokens removed), whitespace-stripped.
    """
    data = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    data = {k: v.to(model.device) for k, v in data.items()}
    output_ids = model.generate(**data, generation_config=generation_config)[0]
    # Drop the echoed prompt; keep only the newly generated tokens.
    output_ids = output_ids[len(data["input_ids"][0]):]
    output = tokenizer.decode(output_ids, skip_special_tokens=True)
    return output.strip()


def predict(input_data, temp):
    """Gradio callback: produce three answer variants for one user message.

    Args:
        input_data: The user's question (text from the UI textbox).
        temp: Sampling temperature from the UI slider.

    Returns:
        A 3-tuple of independently sampled answer strings.
    """
    # NOTE(review): mutating the shared module-level generation_config is not
    # safe under concurrent requests — confirm Gradio request queueing is on.
    generation_config.temperature = temp
    conversation = Conversation()
    conversation.add_user_message(input_data)
    prompt = conversation.get_prompt()
    # Three independent samples of the same prompt (was three copy-pasted calls).
    answers = [generate(model, tokenizer, prompt, generation_config) for _ in range(3)]
    return answers[0], answers[1], answers[2]


io = gr.Interface(
    predict,
    inputs=[
        gr.Textbox(value="Как записаться к стоматологу?", label="Введите текст:"),
        gr.Slider(
            minimum=0.01,
            maximum=1,
            value=0.3,
            step=0.1,
            info="Данный параметр позволяет изменять креативность модели. Чем больше, тем модель будет более креативная и наоборот.",
        ),
    ],
    outputs=[
        gr.Textbox(label="Первый вариант ответа:"),
        gr.Textbox(label="Второй вариант ответа:"),
        gr.Textbox(label="Третий вариант ответа:"),
    ],
)

if __name__ == "__main__":
    io.launch()