Spaces:
Paused
Paused
| import gradio as gr | |
| from transformers import pipeline | |
| import torch | |
| import re | |
| import os | |
| import spaces | |
# Make CUDA the default device for tensors created without an explicit device.
torch.set_default_device("cuda")

# The system prompt is kept out of the source and read from the ``sys``
# environment variable (presumably a Space secret — confirm). Read it before the
# expensive model load so a misconfigured deployment fails fast, and with an
# actionable message instead of a bare ``KeyError: 'sys'``.
try:
    system_prompt = os.environ["sys"]
except KeyError:
    raise RuntimeError(
        "Environment variable 'sys' is not set; it must contain the system prompt "
        "used for every chat request."
    ) from None

# Hugging Face Hub model served by this app.
model_id = "glides/mistral-eap"
# Chat-capable text-generation pipeline; device_map="auto" delegates weight
# placement to transformers/accelerate.
pipe = pipeline("text-generation", model=model_id, device_map="auto")
# Required response layout: four tagged sections in this exact order, each with
# at least one character of content.
_RESPONSE_LAYOUT = re.compile(
    r'<thinking>.+?</thinking><output>.+?</output><reflecting>.+?</reflecting><refined>.+?</refined>'
)


def follows_rules(s):
    """Report whether *s* starts with the required tagged-section layout.

    Newlines are stripped first, so sections may be split across lines and
    separated by line breaks. Each of ``<thinking>``, ``<output>``,
    ``<reflecting>`` and ``<refined>`` must appear in that order and contain
    non-newline content. Text after the final ``</refined>`` is not checked.

    Returns:
        True if the layout matches at the start of the flattened string,
        otherwise False.
    """
    flattened = "".join(s.split("\n"))
    match = _RESPONSE_LAYOUT.match(flattened)
    return match is not None
# Upper bound on generations per request. Each attempt may emit up to 2**16
# tokens, so an unbounded retry loop could occupy the GPU indefinitely.
_MAX_GENERATION_ATTEMPTS = 5


def _generate(chat):
    """Run the text-generation pipeline on *chat* and return the new assistant text."""
    # The pipeline returns the conversation with the generated assistant turn last.
    return pipe(chat, max_new_tokens=2 ** 16)[0]['generated_text'][-1]['content']


def _trim_to_sections(generated_text):
    """Cut *generated_text* down to the ``<thinking>…</refined>`` span.

    Keeps everything after the last ``<thinking>`` tag (or the whole text if
    there is none), re-prefixed with ``<thinking>``, up to the first
    ``</refined>`` that follows. A closing ``</refined>`` is always appended, so
    the result is ready for validation with ``follows_rules``.
    """
    from_thinking = "<thinking>" + generated_text.split("<thinking>")[-1]
    return from_thinking.split("</refined>")[0] + "</refined>"


# NOTE(review): ``spaces`` is imported but this handler is not decorated with
# ``@spaces.GPU``; ZeroGPU hardware requires that decorator — confirm the
# Space's hardware before relying on this as-is.
def predict(input_text, history):
    """Answer *input_text* in the context of *history* for ``gr.ChatInterface``.

    Args:
        input_text: The new user message.
        history: Prior turns; each item is indexed as ``[user, assistant]``,
            and the assistant entry may be None.

    Returns:
        The stripped contents of the model's ``<refined>`` section.

    Raises:
        gr.Error: If no generation satisfies ``follows_rules`` within
            ``_MAX_GENERATION_ATTEMPTS`` attempts.
    """
    chat = [{"role": "system", "content": system_prompt}]
    for item in history:
        chat.append({"role": "user", "content": item[0]})
        if item[1] is not None:
            chat.append({"role": "assistant", "content": item[1]})
    chat.append({"role": "user", "content": input_text})

    for _ in range(_MAX_GENERATION_ATTEMPTS):
        generated_text = _generate(chat)
        # Re-derive the candidate from *this* generation. Validating a stale
        # candidate was what made the original retry loop spin forever.
        candidate = _trim_to_sections(generated_text)
        if follows_rules(candidate):
            break
        print(f"model output {generated_text} was found invalid")
    else:
        raise gr.Error(
            f"The model did not produce a well-formed response after "
            f"{_MAX_GENERATION_ATTEMPTS} attempts; please try again."
        )

    model_output = candidate.split("<refined>")[-1].replace("</refined>", "")
    return model_output.strip()
# Chat UI around ``predict``. It is bound to ``demo`` because Gradio's reload
# mode looks for that name. Launching only under ``__main__`` keeps a plain
# import of this module from starting a server; ``python app.py`` behaves as before.
demo = gr.ChatInterface(predict, theme="soft")

if __name__ == "__main__":
    demo.launch()