yukee1992 committed on
Commit
4d5089c
·
verified ·
1 Parent(s): 606735e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -11
app.py CHANGED
@@ -1,15 +1,14 @@
1
  import os
2
  import uuid
3
  import httpx
 
 
 
4
  from fastapi import FastAPI, Request, BackgroundTasks, HTTPException
5
  from fastapi.responses import JSONResponse
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
- import logging
8
  import uvicorn
9
- from typing import Dict
10
-
11
- # Initialize FastAPI app FIRST
12
- app = FastAPI()
13
 
14
  # Configuration
15
  MODEL_ID = "google/gemma-1.1-2b-it"
@@ -39,9 +38,10 @@ class ScriptGenerator:
39
  self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
40
  self.model = AutoModelForCausalLM.from_pretrained(
41
  MODEL_ID,
42
- torch_dtype=torch.float32,
43
  device_map=None,
44
- token=HF_TOKEN
 
45
  ).to(DEVICE)
46
  self.loaded = True
47
  logger.info("Model loaded successfully")
@@ -49,6 +49,15 @@ class ScriptGenerator:
49
  logger.error(f"Model loading failed: {str(e)}")
50
  raise
51
 
 
 
 
 
 
 
 
 
 
52
  generator = ScriptGenerator()
53
 
54
  def generate_script(topic: str) -> str:
@@ -149,10 +158,6 @@ async def get_status(job_id: str):
149
  raise HTTPException(status_code=404, detail="Job not found")
150
  return jobs[job_id]
151
 
152
- @app.on_event("startup")
153
- async def startup():
154
- generator.load_model()
155
-
156
  if __name__ == "__main__":
157
  uvicorn.run(
158
  app,
 
1
  import os
2
  import uuid
3
  import httpx
4
+ import torch # <-- MISSING IMPORT ADDED
5
+ import logging
6
+ from typing import Dict
7
  from fastapi import FastAPI, Request, BackgroundTasks, HTTPException
8
  from fastapi.responses import JSONResponse
9
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
10
  import uvicorn
11
+ from contextlib import asynccontextmanager
 
 
 
12
 
13
  # Configuration
14
  MODEL_ID = "google/gemma-1.1-2b-it"
 
38
  self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
39
  self.model = AutoModelForCausalLM.from_pretrained(
40
  MODEL_ID,
41
+ torch_dtype=torch.float32, # Now torch is defined
42
  device_map=None,
43
+ token=HF_TOKEN,
44
+ low_cpu_mem_usage=True
45
  ).to(DEVICE)
46
  self.loaded = True
47
  logger.info("Model loaded successfully")
 
49
  logger.error(f"Model loading failed: {str(e)}")
50
  raise
51
 
52
+ # Modern lifespan handler (replaces @app.on_event)
53
+ @asynccontextmanager
54
+ async def lifespan(app: FastAPI):
55
+ generator = ScriptGenerator()
56
+ generator.load_model()
57
+ yield
58
+ # Cleanup if needed
59
+
60
+ app = FastAPI(lifespan=lifespan)
61
  generator = ScriptGenerator()
62
 
63
  def generate_script(topic: str) -> str:
 
158
  raise HTTPException(status_code=404, detail="Job not found")
159
  return jobs[job_id]
160
 
 
 
 
 
161
  if __name__ == "__main__":
162
  uvicorn.run(
163
  app,