Spaces:
Runtime error
Runtime error
| import os | |
| import requests | |
| import gradio as gr | |
| from conversation import Conversation | |
class BaseModel:
    """Client for a remote text-generation endpoint serving one model.

    Renders a conversation into a plain-text prompt and POSTs it to an HTTP
    endpoint whose URL is built from the ``API_BASE_PATH`` environment
    variable (a template filled with ``endpoint`` and ``namespace``).
    """

    name: str
    endpoint: str
    namespace: str
    generation_params: dict

    def __init__(self, name, endpoint, namespace, generation_params):
        """Store the model identity, endpoint location and default parameters.

        Args:
            name: Model name.
            endpoint: First value substituted into the ``API_BASE_PATH`` template.
            namespace: Second value substituted into the ``API_BASE_PATH`` template.
            generation_params: Default generation parameters sent with every
                request. This class never mutates it.
        """
        self.name = name
        self.endpoint = endpoint
        self.namespace = namespace
        self.generation_params = generation_params

    def generate_response(self, conversation, custom_generation_params=None):
        """Return the model's reply to *conversation*.

        Args:
            conversation: Conversation to continue (see ``_get_prompt`` for the
                attributes read from it).
            custom_generation_params: Optional per-call overrides, merged on top
                of ``self.generation_params`` for this request only.

        Returns:
            The generated text, stripped of surrounding whitespace.

        Raises:
            gr.Error: If ``API_BASE_PATH`` is unset, the endpoint answers with a
                non-200 status, or the response contains no predictions.
        """
        prompt = self._get_prompt(conversation)
        return self._get_response(prompt, custom_generation_params)

    def _get_prompt(self, conversation: "Conversation"):
        """Render *conversation* as a ``label: text`` transcript.

        The output has the following parts, in order:
        - ``memory`` and ``prompt``, joined by a newline and stripped.
        - One ``from: value`` line per message, with both parts stripped.
        - A trailing ``bot_label:`` line, so the model continues as the bot.
        """
        lines = ["\n".join([conversation.memory, conversation.prompt]).strip()]
        lines.extend(
            f"{message['from'].strip()}: {message['value'].strip()}"
            for message in conversation.messages
        )
        lines.append(f"{conversation.bot_label}:")
        return "\n".join(lines)

    def _get_response(self, text, custom_generation_params):
        """POST *text* to the model endpoint and return the best generation.

        Raises:
            gr.Error: If ``API_BASE_PATH`` is unset, the endpoint answers with a
                non-200 status, or the response contains no predictions.
        """
        api = self._get_api_url()
        payload = {
            "instances": [text],
            "parameters": self._build_parameters(custom_generation_params),
        }
        resp = requests.post(api, json=payload, timeout=600)
        if resp.status_code != 200:
            raise gr.Error(
                f"Endpoint returned code: {resp.status_code}. "
                "Solution: "
                "1. Scale-to-Zero enabled, so please wait for some minutes and try again. "
                "2. Probably the response generated by the model is too big, "
                "try changing max_new_tokens. "
                "3. If nothing helps — report the problem."
            )
        return self._parse_predictions(resp.json()["predictions"])

    def _get_api_url(self):
        """Build the endpoint URL from the ``API_BASE_PATH`` template.

        Raises:
            gr.Error: If ``API_BASE_PATH`` is not set.
        """
        template = os.environ.get("API_BASE_PATH")
        if template is None:
            # Fail early with a clear message instead of posting to the URL "None".
            raise gr.Error("API_BASE_PATH environment variable is not set.")
        # The template may store its placeholders escaped as "\{\}"; unescape
        # them so str.format can fill in endpoint and namespace.
        template = template.replace(r"\{\}", "{}")
        return template.format(self.endpoint, self.namespace)

    def _build_parameters(self, custom_generation_params):
        """Return the default parameters overlaid with per-call overrides.

        A fresh dict is always returned, so overrides never leak into
        ``self.generation_params`` and therefore never into later requests.
        """
        parameters = dict(self.generation_params)
        if custom_generation_params is not None:
            parameters.update(custom_generation_params)
        return parameters

    @staticmethod
    def _parse_predictions(predictions):
        """Extract the generated text from the endpoint's ``predictions`` list.

        The first prediction is either a plain string or a list of
        ``{"text": ..., "score": ...}`` candidates. For a candidate list, the
        highest-scoring candidate's text is returned.

        Raises:
            gr.Error: If there is no prediction or no candidate to choose from.
        """
        if not predictions:
            raise gr.Error("Endpoint returned no predictions.")
        first = predictions[0]
        if isinstance(first, str):
            return first.strip()
        if not first:
            raise gr.Error("Endpoint returned no predictions.")
        # max() keeps the first maximal item it sees. Scanning in reverse
        # therefore resolves score ties to the last candidate, which is
        # exactly what a stable sort followed by [-1] selects.
        best = max(reversed(first), key=lambda candidate: candidate["score"])
        return best["text"].strip()