jflo commited on
Commit
5da4f83
·
1 Parent(s): 1f1dae1

Implemented parallelism to run inference on mental and physical model

Browse files
Files changed (1) hide show
  1. main.py +8 -19
main.py CHANGED
@@ -5,6 +5,7 @@
5
  import os
6
  import anthropic
7
  from contextlib import asynccontextmanager
 
8
  from fastapi import FastAPI, HTTPException
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from pydantic import BaseModel
@@ -146,25 +147,13 @@ async def predict(request: PredictRequest):
146
  # 1. Enrich text with goal
147
  enriched_text = f"Goal: {request.user_profile.primary_goal}. {request.text}"
148
 
149
- # 2. Run physical model
150
- physical_output = run_inference(
151
- ml_models["physical_model"],
152
- ml_models["tokenizer"],
153
- enriched_text,
154
- PHYSICAL_LABEL_COLS,
155
- PHYSICAL_DECODERS
156
  )
157
 
158
- # 3. Run mental model
159
- mental_output = run_inference(
160
- ml_models["mental_model"],
161
- ml_models["tokenizer"],
162
- enriched_text,
163
- MENTAL_LABEL_COLS,
164
- MENTAL_DECODERS
165
- )
166
-
167
- # 4. Build Claude prompt
168
  prompt = build_claude_prompt(
169
  request.user_profile.primary_goal,
170
  request.user_profile.modifiers,
@@ -172,7 +161,7 @@ async def predict(request: PredictRequest):
172
  mental_output
173
  )
174
 
175
- # 5. Call Claude API
176
  # ANTHROPIC_API_KEY must be set as a HuggingFace Space secret
177
  claude_client = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])
178
  message = claude_client.messages.create(
@@ -181,7 +170,7 @@ async def predict(request: PredictRequest):
181
  messages=[{"role": "user", "content": prompt}]
182
  )
183
 
184
- # 6. Return full response
185
  return PredictResponse(
186
  physical=physical_output,
187
  mental=mental_output,
 
5
  import os
6
  import anthropic
7
  from contextlib import asynccontextmanager
8
+ import asyncio
9
  from fastapi import FastAPI, HTTPException
10
  from fastapi.middleware.cors import CORSMiddleware
11
  from pydantic import BaseModel
 
147
  # 1. Enrich text with goal
148
  enriched_text = f"Goal: {request.user_profile.primary_goal}. {request.text}"
149
 
150
+ # 2. Run physical and mental model using asyncio
151
+ physical_output, mental_output = await asyncio.gather(
152
+ asyncio.to_thread(run_inference, ml_models["physical_model"], ml_models["tokenizer"], enriched_text, PHYSICAL_LABEL_COLS, PHYSICAL_DECODERS),
153
+ asyncio.to_thread(run_inference, ml_models["mental_model"], ml_models["tokenizer"], enriched_text, MENTAL_LABEL_COLS, MENTAL_DECODERS)
 
 
 
154
  )
155
 
156
+ # 3. Build Claude prompt
 
 
 
 
 
 
 
 
 
157
  prompt = build_claude_prompt(
158
  request.user_profile.primary_goal,
159
  request.user_profile.modifiers,
 
161
  mental_output
162
  )
163
 
164
+ # 4. Call Claude API
165
  # ANTHROPIC_API_KEY must be set as a HuggingFace Space secret
166
  claude_client = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])
167
  message = claude_client.messages.create(
 
170
  messages=[{"role": "user", "content": prompt}]
171
  )
172
 
173
+ # 5. Return full response
174
  return PredictResponse(
175
  physical=physical_output,
176
  mental=mental_output,