yukee1992 committed on
Commit
34bab26
·
verified ·
1 Parent(s): 492f1bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -65
app.py CHANGED
@@ -9,18 +9,14 @@ import uvicorn
9
  # Configure logging
10
  logging.basicConfig(
11
  level=logging.INFO,
12
- format='%(asctime)s - %(levelname)s - %(message)s',
13
- handlers=[
14
- logging.StreamHandler(),
15
- logging.FileHandler('app.log')
16
- ]
17
  )
18
  logger = logging.getLogger(__name__)
19
 
20
  # Configuration
21
  MODEL_ID = "google/gemma-1.1-2b-it"
22
  HF_TOKEN = os.getenv("HF_TOKEN", "")
23
- MAX_TOKENS = 400 # For ~200 word scripts
24
  DEVICE = "cpu"
25
  PORT = int(os.getenv("PORT", 7860))
26
 
@@ -45,61 +41,23 @@ class ScriptGenerator:
45
  low_cpu_mem_usage=True
46
  ).to(DEVICE)
47
 
48
- # Configure valid generation parameters
49
  self.generation_config = GenerationConfig(
50
  max_new_tokens=MAX_TOKENS,
51
  do_sample=True,
52
- top_p=0.9, # Replaces temperature
53
  num_beams=1,
54
  no_repeat_ngram_size=2,
55
  pad_token_id=self.tokenizer.eos_token_id
56
  )
57
 
58
  self.loaded = True
59
- logger.info("Model and generation config loaded successfully")
60
  except Exception as e:
61
- logger.error(f"Model loading failed: {str(e)}")
62
  raise
63
 
64
  generator = ScriptGenerator()
65
 
66
- def generate_script(topic: str) -> str:
67
- """Generation with validated config"""
68
- try:
69
- if not generator.loaded:
70
- generator.load_model()
71
-
72
- prompt = (
73
- f"Generate a 1-minute (60s) video script about: {topic[:80]}\n"
74
- "Required structure with timings:\n"
75
- "[0:00-0:10] HOOK: Grab attention\n"
76
- "[0:10-0:50] MAIN: 3 key points\n"
77
- "[0:50-1:00] CTA: Clear action\n\n"
78
- "Script:\n"
79
- )
80
-
81
- inputs = generator.tokenizer(prompt, return_tensors="pt").to(DEVICE)
82
-
83
- # Use the pre-configured generation config
84
- outputs = generator.model.generate(
85
- **inputs,
86
- generation_config=generator.generation_config
87
- )
88
-
89
- script = generator.tokenizer.decode(outputs[0], skip_special_tokens=True)
90
-
91
- # Validate output structure
92
- required_sections = ["HOOK:", "MAIN:", "CTA:"]
93
- if not all(section in script for section in required_sections):
94
- logger.warning("Script missing sections, adding template")
95
- script += "\n\n[0:00-0:10] HOOK: [Add attention-grabber]\n[0:10-0:50] MAIN: [Add content]\n[0:50-1:00] CTA: [Add action]"
96
-
97
- return script
98
-
99
- except Exception as e:
100
- logger.error(f"Generation failed: {str(e)}")
101
- return f"Error: {str(e)}"
102
-
103
  app = FastAPI()
104
 
105
  @app.on_event("startup")
@@ -112,38 +70,40 @@ async def predict(request: Request):
112
  data = await request.json()
113
  topic = data.get("topic", "")
114
 
115
- # Handle n8n's list input format
116
  if isinstance(topic, list):
117
  topic = topic[0] if len(topic) > 0 else ""
118
  topic = str(topic).strip()
119
 
120
- if not topic:
121
- return JSONResponse(
122
- {"success": False, "error": "Empty topic"},
123
- status_code=400
124
- )
 
125
 
126
- logger.info(f"Generating script for: {topic[:30]}...")
127
- result = generate_script(topic)
 
 
128
 
129
- return JSONResponse({
130
- "success": not result.startswith("Error"),
131
- "result": result,
132
- "error": None if not result.startswith("Error") else result
133
- })
134
 
135
  except Exception as e:
136
- logger.error(f"API error: {str(e)}")
137
- return JSONResponse(
138
- {"success": False, "error": str(e)},
139
- status_code=500
140
- )
141
 
142
  if __name__ == "__main__":
 
 
 
 
 
143
  uvicorn.run(
144
  app,
145
  host="0.0.0.0",
146
  port=PORT,
147
  log_level="info",
 
148
  timeout_keep_alive=30
149
  )
 
9
  # Configure logging
10
  logging.basicConfig(
11
  level=logging.INFO,
12
+ format='%(asctime)s - %(levelname)s - %(message)s'
 
 
 
 
13
  )
14
  logger = logging.getLogger(__name__)
15
 
16
  # Configuration
17
  MODEL_ID = "google/gemma-1.1-2b-it"
18
  HF_TOKEN = os.getenv("HF_TOKEN", "")
19
+ MAX_TOKENS = 400
20
  DEVICE = "cpu"
21
  PORT = int(os.getenv("PORT", 7860))
22
 
 
41
  low_cpu_mem_usage=True
42
  ).to(DEVICE)
43
 
 
44
  self.generation_config = GenerationConfig(
45
  max_new_tokens=MAX_TOKENS,
46
  do_sample=True,
47
+ top_p=0.9,
48
  num_beams=1,
49
  no_repeat_ngram_size=2,
50
  pad_token_id=self.tokenizer.eos_token_id
51
  )
52
 
53
  self.loaded = True
54
+ logger.info("Model loaded | Port: %s", PORT)
55
  except Exception as e:
56
+ logger.error("Load failed: %s", str(e))
57
  raise
58
 
59
  generator = ScriptGenerator()
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  app = FastAPI()
62
 
63
  @app.on_event("startup")
 
70
  data = await request.json()
71
  topic = data.get("topic", "")
72
 
 
73
  if isinstance(topic, list):
74
  topic = topic[0] if len(topic) > 0 else ""
75
  topic = str(topic).strip()
76
 
77
+ logger.info("Processing: %.30s...", topic)
78
+
79
+ inputs = generator.tokenizer(
80
+ f"Create 1-minute script about {topic}:\n1) Hook\n2) Main\n3) CTA\n\nScript:",
81
+ return_tensors="pt"
82
+ ).to(DEVICE)
83
 
84
+ outputs = generator.model.generate(
85
+ **inputs,
86
+ generation_config=generator.generation_config
87
+ )
88
 
89
+ script = generator.tokenizer.decode(outputs[0], skip_special_tokens=True)
90
+ return JSONResponse({"result": script})
 
 
 
91
 
92
  except Exception as e:
93
+ logger.error("API error: %s", str(e))
94
+ return JSONResponse({"error": str(e)}, status_code=500)
 
 
 
95
 
96
  if __name__ == "__main__":
97
+ # Hugging Face Spaces compatibility
98
+ if os.getenv("SPACES", "false").lower() == "true":
99
+ os.environ["GRADIO_SERVER_PORT"] = str(PORT)
100
+ os.environ["GRADIO_SERVER_NAME"] = "0.0.0.0"
101
+
102
  uvicorn.run(
103
  app,
104
  host="0.0.0.0",
105
  port=PORT,
106
  log_level="info",
107
+ workers=1,
108
  timeout_keep_alive=30
109
  )