Spaces:
Runtime error
Runtime error
import copy

import gradio as gr
import torch
from peft import PeftConfig, PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    GenerationConfig,
)

from conversation import Conversation
# Hub repository holding the PEFT adapter, tokenizer and generation config.
MODEL_NAME = "warleagle/medical_chat_saiga"
# Upper bound on the number of tokens generated for each answer.
MAX_NEW_TOKENS = 70

# The adapter's PEFT config records which base model it was trained on.
config = PeftConfig.from_pretrained(MODEL_NAME)

# Load the base model quantized to 4 bits. Passing `load_in_4bit=True`
# straight to `from_pretrained` is deprecated in transformers; the supported
# form is a BitsAndBytesConfig passed as `quantization_config`.
# NOTE(review): bitsandbytes 4-bit loading presumably requires a CUDA GPU;
# if the Space runs on CPU hardware this is a likely cause of its
# "Runtime error" status — confirm against the Space logs.
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    torch_dtype=torch.float16,
    device_map="auto",
)

# Wrap the base model with the adapter weights from MODEL_NAME.
model = PeftModel.from_pretrained(
    model,
    MODEL_NAME,
    torch_dtype=torch.float16,
)
# Inference only: put the model in evaluation mode (e.g. disables dropout).
model.eval()

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)

# Start from the generation settings shipped with the checkpoint and cap the
# answer length.
generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
generation_config.max_new_tokens = MAX_NEW_TOKENS
def generate(model, tokenizer, prompt, generation_config):
    """Run a single generation pass and return only the newly produced text.

    The prompt is tokenized without extra special tokens (the prompt string
    is expected to be fully formatted already), moved to the model's device,
    and passed to ``model.generate`` with the given generation config.
    Only the first returned sequence is used; the prompt tokens at its start
    are dropped before decoding.

    Args:
        model: Model exposing ``.device`` and ``.generate(**inputs, generation_config=...)``.
        tokenizer: Tokenizer used both to encode the prompt and decode the output.
        prompt: Fully formatted prompt text.
        generation_config: Generation settings forwarded to ``model.generate``.

    Returns:
        The decoded continuation with special tokens removed and surrounding
        whitespace stripped.
    """
    encoded = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)

    batch = {}
    for name, tensor in encoded.items():
        batch[name] = tensor.to(model.device)

    sequences = model.generate(**batch, generation_config=generation_config)
    first_sequence = sequences[0]

    # The generated sequence starts with the prompt tokens; skip past them.
    prompt_length = len(batch["input_ids"][0])
    completion_ids = first_sequence[prompt_length:]

    text = tokenizer.decode(completion_ids, skip_special_tokens=True)
    return text.strip()
def predict(input_data, temp):
    """Generate three candidate answers to a single user message.

    Args:
        input_data: The user's question (text from the input textbox).
        temp: Sampling temperature chosen on the slider.

    Returns:
        A tuple of three answer strings, one per output textbox.
    """
    # Use a per-request copy so this request's temperature never leaks into
    # the shared module-level config that other requests also read.
    request_config = copy.deepcopy(generation_config)
    request_config.temperature = temp

    conversation = Conversation()
    conversation.add_user_message(input_data)
    prompt = conversation.get_prompt()

    # One independent generation per output textbox (three in the UI).
    # NOTE(review): the answers only differ, and `temperature` only takes
    # effect, if the loaded generation config enables sampling (do_sample)
    # — confirm in the checkpoint's generation_config.
    return tuple(
        generate(model, tokenizer, prompt, request_config) for _ in range(3)
    )
# Gradio UI: a question box and a temperature slider as inputs, and three
# textboxes showing the three answers returned by `predict`.
io = gr.Interface(
    predict,
    inputs=[
        gr.Textbox(value="Как записаться к стоматологу?",
                   label="Введите текст:"),
        # A range slider's reachable values are minimum + k * step. With
        # minimum=0.01 and step=0.1 that grid was 0.01, 0.11, ..., 0.91, so
        # neither the default 0.3 nor the maximum 1 could be selected.
        # step=0.01 puts both on the grid and keeps the strictly positive
        # lower bound.
        gr.Slider(minimum=0.01,
                  maximum=1,
                  value=0.3,
                  step=0.01,
                  info="Данный параметр позволяет изменять креативность модели. Чем больше, тем модель будет более креативная и наоборот."),
    ],
    outputs=[
        gr.Textbox(label="Первый вариант ответа:"),
        gr.Textbox(label="Второй вариант ответа:"),
        gr.Textbox(label="Третий вариант ответа:"),
    ],
)

if __name__ == "__main__":
    io.launch()