smgc commited on
Commit
dc92a6e
·
verified ·
1 Parent(s): 306b7a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -18
app.py CHANGED
@@ -8,10 +8,12 @@ from fastapi.middleware.cors import CORSMiddleware
8
  from fastapi.responses import StreamingResponse
9
  from pydantic import BaseModel
10
  from typing import List, Optional
 
11
  from google.protobuf import descriptor
12
  from google.protobuf import descriptor_pool
13
  from google.protobuf import symbol_database
14
- from google.protobuf.json_format import MessageToDict
 
15
 
16
  # 加载环境变量
17
  load_dotenv()
@@ -115,24 +117,115 @@ config = Config()
115
 
116
  # 动态生成 Proto
117
  def generate_proto_classes():
118
- pool = descriptor_pool.Default()
119
-
120
- # 注册 GPT Proto
121
- desc = descriptor.FileDescriptor(
122
- name='gpt_service.proto',
123
- package='runtime.aot.machine_learning.parents.gpt',
124
- serialized_pb=GPT_PROTO.encode()
125
- )
126
- pool.Add(desc)
127
-
128
- # 注册 Vertex Proto
129
- desc = descriptor.FileDescriptor(
130
- name='vertex_service.proto',
131
- package='runtime.aot.machine_learning.parents.vertex',
132
- serialized_pb=VERTEX_PROTO.encode()
133
- )
134
- pool.Add(desc)
135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  return pool
137
 
138
  # 生成 Proto 类
 
8
  from fastapi.responses import StreamingResponse
9
  from pydantic import BaseModel
10
  from typing import List, Optional
11
+ from google.protobuf import descriptor_pb2
12
  from google.protobuf import descriptor
13
  from google.protobuf import descriptor_pool
14
  from google.protobuf import symbol_database
15
+ from google.protobuf.compiler import plugin_pb2
16
+ from google.protobuf.json_format import MessageToDic
17
 
18
  # 加载环境变量
19
  load_dotenv()
 
117
 
118
  # 动态生成 Proto
119
  def generate_proto_classes():
120
+ pool = descriptor_pool.DescriptorPool()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
+ # 为 GPT 服务创建文件描述符
123
+ gpt_file = descriptor_pb2.FileDescriptorProto()
124
+ gpt_file.name = "gpt_service.proto"
125
+ gpt_file.package = "runtime.aot.machine_learning.parents.gpt"
126
+ gpt_file.syntax = "proto3"
127
+
128
+ # GPT 服务消息定义
129
+ message = gpt_file.message_type.add()
130
+ message.name = "Message"
131
+ field = message.field.add()
132
+ field.name = "role"
133
+ field.number = 1
134
+ field.type = descriptor.FieldDescriptor.TYPE_UINT64
135
+ field = message.field.add()
136
+ field.name = "message"
137
+ field.number = 2
138
+ field.type = descriptor.FieldDescriptor.TYPE_STRING
139
+
140
+ # Request 消息
141
+ request = gpt_file.message_type.add()
142
+ request.name = "Request"
143
+ field = request.field.add()
144
+ field.name = "models"
145
+ field.number = 1
146
+ field.type = descriptor.FieldDescriptor.TYPE_STRING
147
+ field = request.field.add()
148
+ field.name = "messages"
149
+ field.number = 2
150
+ field.type = descriptor.FieldDescriptor.TYPE_MESSAGE
151
+ field.type_name = ".runtime.aot.machine_learning.parents.gpt.Message"
152
+ field.label = descriptor.FieldDescriptor.LABEL_REPEATED
153
+ field = request.field.add()
154
+ field.name = "temperature"
155
+ field.number = 3
156
+ field.type = descriptor.FieldDescriptor.TYPE_DOUBLE
157
+ field = request.field.add()
158
+ field.name = "top_p"
159
+ field.number = 4
160
+ field.type = descriptor.FieldDescriptor.TYPE_DOUBLE
161
+
162
+ # Response 消息
163
+ response = gpt_file.message_type.add()
164
+ response.name = "Response"
165
+ field = response.field.add()
166
+ field.name = "response_code"
167
+ field.number = 2
168
+ field.type = descriptor.FieldDescriptor.TYPE_UINT64
169
+ field = response.field.add()
170
+ field.name = "body"
171
+ field.number = 4
172
+ field.type = descriptor.FieldDescriptor.TYPE_MESSAGE
173
+ field.type_name = ".runtime.aot.machine_learning.parents.gpt.Body"
174
+ field.label = descriptor.FieldDescriptor.LABEL_OPTIONAL
175
+
176
+ # Body 消息
177
+ body = gpt_file.message_type.add()
178
+ body.name = "Body"
179
+ field = body.field.add()
180
+ field.name = "id"
181
+ field.number = 1
182
+ field.type = descriptor.FieldDescriptor.TYPE_STRING
183
+ field = body.field.add()
184
+ field.name = "object"
185
+ field.number = 2
186
+ field.type = descriptor.FieldDescriptor.TYPE_STRING
187
+ field = body.field.add()
188
+ field.name = "time"
189
+ field.number = 3
190
+ field.type = descriptor.FieldDescriptor.TYPE_UINT64
191
+ field = body.field.add()
192
+ field.name = "message_warpper"
193
+ field.number = 4
194
+ field.type = descriptor.FieldDescriptor.TYPE_MESSAGE
195
+ field.type_name = ".runtime.aot.machine_learning.parents.gpt.MessageWarpper"
196
+
197
+ # MessageWarpper 消息
198
+ message_wrapper = gpt_file.message_type.add()
199
+ message_wrapper.name = "MessageWarpper"
200
+ field = message_wrapper.field.add()
201
+ field.name = "arg1"
202
+ field.number = 1
203
+ field.type = descriptor.FieldDescriptor.TYPE_INT64
204
+ field = message_wrapper.field.add()
205
+ field.name = "message"
206
+ field.number = 2
207
+ field.type = descriptor.FieldDescriptor.TYPE_MESSAGE
208
+ field.type_name = ".runtime.aot.machine_learning.parents.gpt.Message"
209
+
210
+ # GPT 服务定义
211
+ service = gpt_file.service.add()
212
+ service.name = "GPTInferenceService"
213
+ method = service.method.add()
214
+ method.name = "Predict"
215
+ method.input_type = ".runtime.aot.machine_learning.parents.gpt.Request"
216
+ method.output_type = ".runtime.aot.machine_learning.parents.gpt.Response"
217
+ method = service.method.add()
218
+ method.name = "PredictWithStream"
219
+ method.input_type = ".runtime.aot.machine_learning.parents.gpt.Request"
220
+ method.output_type = ".runtime.aot.machine_learning.parents.gpt.Response"
221
+ method.server_streaming = True
222
+
223
+ # 将文件描述符添加到池中
224
+ pool.Add(gpt_file)
225
+
226
+ # Vertex 服务的定义类似...
227
+ # 这里省略 Vertex 服务的定义,原理相���
228
+
229
  return pool
230
 
231
  # 生成 Proto 类