yukee1992 committed
Commit 4939b75 · verified · 1 parent: 3e8f82f

Update app.py

Files changed (1)
  app.py: +58 -106
app.py CHANGED
@@ -1,109 +1,61 @@
-import os
-import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
-from fastapi import FastAPI, Request
-from fastapi.responses import JSONResponse
-import logging
-import uvicorn
-
-# Configure logging
-logging.basicConfig(
-    level=logging.INFO,
-    format='%(asctime)s - %(levelname)s - %(message)s'
-)
-logger = logging.getLogger(__name__)
-
-# Configuration
-MODEL_ID = "google/gemma-1.1-2b-it"
-HF_TOKEN = os.getenv("HF_TOKEN", "")
-MAX_TOKENS = 400
-DEVICE = "cpu"
-PORT = int(os.getenv("PORT", 7860))
-
-class ScriptGenerator:
-    def __init__(self):
-        self.tokenizer = None
-        self.model = None
-        self.generation_config = None
-        self.loaded = False
-
-    def load_model(self):
-        if self.loaded: return
-
-        logger.info("Loading model...")
-        try:
-            self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
-            self.model = AutoModelForCausalLM.from_pretrained(
-                MODEL_ID,
-                torch_dtype=torch.float32,
-                device_map=None,
-                token=HF_TOKEN,
-                low_cpu_mem_usage=True
-            ).to(DEVICE)
-
-            self.generation_config = GenerationConfig(
-                max_new_tokens=MAX_TOKENS,
-                do_sample=True,
-                top_p=0.9,
-                num_beams=1,
-                no_repeat_ngram_size=2,
-                pad_token_id=self.tokenizer.eos_token_id
-            )
-
-            self.loaded = True
-            logger.info("Model loaded | Port: %s", PORT)
-        except Exception as e:
-            logger.error("Load failed: %s", str(e))
-            raise
-
-generator = ScriptGenerator()
-
-app = FastAPI()
-
-@app.on_event("startup")
-def startup():
-    generator.load_model()
-
-@app.post("/api/predict")
-async def predict(request: Request):
-    try:
-        data = await request.json()
-        topic = data.get("topic", "")
-
-        if isinstance(topic, list):
-            topic = topic[0] if len(topic) > 0 else ""
-        topic = str(topic).strip()
-
-        logger.info("Processing: %.30s...", topic)
-
-        inputs = generator.tokenizer(
-            f"Create 1-minute script about {topic}:\n1) Hook\n2) Main\n3) CTA\n\nScript:",
-            return_tensors="pt"
-        ).to(DEVICE)
-
-        outputs = generator.model.generate(
-            **inputs,
-            generation_config=generator.generation_config
-        )
-
-        script = generator.tokenizer.decode(outputs[0], skip_special_tokens=True)
-        return JSONResponse({"result": script})
-
-    except Exception as e:
-        logger.error("API error: %s", str(e))
-        return JSONResponse({"error": str(e)}, status_code=500)
-
-if __name__ == "__main__":
-    # Hugging Face Spaces compatibility
-    if os.getenv("SPACES", "false").lower() == "true":
-        os.environ["GRADIO_SERVER_PORT"] = str(PORT)
-        os.environ["GRADIO_SERVER_NAME"] = "0.0.0.0"
-
-    uvicorn.run(
-        app,
-        host="0.0.0.0",
-        port=PORT,
-        log_level="info",
-        workers=1,
-        timeout_keep_alive=30
-    )
+from fastapi import BackgroundTasks
+import httpx
+import uuid
+
+jobs = {}  # Stores ongoing jobs
+
+@app.post("/api/submit")
+async def submit_job(
+    request: Request,
+    background_tasks: BackgroundTasks
+):
+    data = await request.json()
+    job_id = str(uuid.uuid4())
+
+    # Store job details
+    jobs[job_id] = {
+        "status": "processing",
+        "result": None,
+        "callback_url": data.get("callback_url")  # n8n webhook URL
+    }
+
+    # Start background task
+    background_tasks.add_task(
+        process_job,
+        job_id,
+        data["topic"]
+    )
+
+    return {"job_id": job_id, "status": "queued"}
+
+async def process_job(job_id: str, topic: str):
+    try:
+        script = generate_script(topic)  # Your existing function
+        jobs[job_id]["status"] = "complete"
+        jobs[job_id]["result"] = script
+
+        # Send back to n8n via webhook
+        if jobs[job_id]["callback_url"]:
+            async with httpx.AsyncClient() as client:
+                await client.post(
+                    jobs[job_id]["callback_url"],
+                    json={
+                        "job_id": job_id,
+                        "status": "complete",
+                        "result": script
+                    },
+                    timeout=30.0
+                )
+    except Exception as e:
+        jobs[job_id]["status"] = "failed"
+        jobs[job_id]["error"] = str(e)
+        if jobs[job_id]["callback_url"]:
+            # httpx.post() is synchronous and not awaitable; use the async
+            # client here as well so the failure callback actually fires
+            async with httpx.AsyncClient() as client:
+                await client.post(
+                    jobs[job_id]["callback_url"],
+                    json={
+                        "job_id": job_id,
+                        "status": "failed",
+                        "error": str(e)
+                    },
+                    timeout=30.0
+                )
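As committed, app.py no longer defines `app`, `Request`, the model-loading code, or the `generate_script` helper that `process_job` calls, so the Space cannot boot from this file alone. A minimal sketch of the missing glue, reusing the generation logic deleted above (`ScriptGenerator`, `generator`, and `DEVICE` refer to the previous revision):

```python
# Sketch only: restores what process_job assumes but this commit removed.
# generator and DEVICE come from the previous revision of app.py.
from fastapi import FastAPI, Request

app = FastAPI()

@app.on_event("startup")
def startup():
    generator.load_model()  # ScriptGenerator instance from the old revision

def generate_script(topic: str) -> str:
    # Same prompt template and generation call as the deleted /api/predict
    inputs = generator.tokenizer(
        f"Create 1-minute script about {topic}:\n1) Hook\n2) Main\n3) CTA\n\nScript:",
        return_tensors="pt"
    ).to(DEVICE)
    outputs = generator.model.generate(
        **inputs,
        generation_config=generator.generation_config
    )
    return generator.tokenizer.decode(outputs[0], skip_special_tokens=True)
```

Note that `generate_script` is CPU-bound, so calling it directly inside the async `process_job` blocks the event loop for the entire generation. One option is `script = await run_in_threadpool(generate_script, topic)` via `fastapi.concurrency.run_in_threadpool`.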
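The `jobs` dict records status and result, but this commit exposes no way to read them back; if the callback POST fails, the result is unreachable. A hypothetical polling endpoint (not part of this commit) could cover that gap:

```python
# Hypothetical companion endpoint, not in this commit: lets n8n poll a job
# as a fallback to the callback webhook.
from fastapi.responses import JSONResponse

@app.get("/api/status/{job_id}")
async def job_status(job_id: str):
    job = jobs.get(job_id)
    if job is None:
        return JSONResponse({"error": "unknown job_id"}, status_code=404)
    return {
        "job_id": job_id,
        "status": job["status"],    # "processing" | "complete" | "failed"
        "result": job["result"],    # None until complete
        "error": job.get("error"),  # present only after a failure
    }
```

Since `jobs` lives in process memory, entries vanish on restart; a real deployment would likely want a TTL or a persistent store, but that is beyond this commit.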
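For reference, a sketch of the round trip the new endpoint implies: the caller posts a topic plus a `callback_url`, gets a `job_id` back immediately, and later receives the result at the callback. The host, port, and webhook URL below are placeholders (7860 was the default port in the previous revision):

```python
# Client-side sketch of the submit/callback flow.
import httpx

resp = httpx.post(
    "http://localhost:7860/api/submit",  # placeholder host/port
    json={
        "topic": "why the sky is blue",
        "callback_url": "https://example.com/webhook/n8n",  # n8n webhook placeholder
    },
)
print(resp.json())  # -> {"job_id": "...", "status": "queued"}

# Later, the Space POSTs to callback_url with either:
#   {"job_id": "...", "status": "complete", "result": "<generated script>"}
# or, on failure:
#   {"job_id": "...", "status": "failed", "error": "<message>"}
```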