rairo commited on
Commit
19ee271
·
verified ·
1 Parent(s): 6862917

Update sozo_gen.py

Browse files
Files changed (1) hide show
  1. sozo_gen.py +135 -116
sozo_gen.py CHANGED
@@ -86,121 +86,23 @@ def clean_narration(txt: str) -> str:
86
 
87
  def placeholder_img() -> Image.Image: return Image.new("RGB", (WIDTH, HEIGHT), (230, 230, 230))
88
 
89
- # In sozo_gen.py, add these new functions at the end of the file
90
-
91
- def generate_image_with_gemini(prompt: str) -> Image.Image:
92
- """Generates an image using the specified Gemini model and client configuration."""
93
- logging.info(f"Generating Gemini image with prompt: '{prompt}'")
94
- try:
95
- # Use the genai.Client as per the correct implementation
96
- client = genai.Client(api_key=API_KEY)
97
- full_prompt = f"A professional, 3d digital art style illustration for a business presentation: {prompt}"
98
-
99
- response = client.models.generate_content(
100
- model="gemini-2.0-flash-exp",
101
- contents=full_prompt,
102
- config=genai_types.GenerateContentConfig(
103
- response_modalities=["Text", "Image"]
104
- ),
105
- )
106
-
107
- # Find the image part in the response
108
- img_part = next((part for part in response.candidates[0].content.parts if part.content_type == "Image"), None)
109
-
110
- if img_part:
111
- # The content is already bytes, so we can open it directly
112
- return Image.open(io.BytesIO(img_part.content)).convert("RGB")
113
- else:
114
- logging.error("Gemini response did not contain an image.")
115
- return None
116
- except Exception as e:
117
- logging.error(f"Gemini image generation failed: {e}")
118
- return None
119
-
120
- def generate_slides_from_report(raw_md: str, chart_urls: dict, uid: str, project_id: str, bucket, llm):
121
- """
122
- Uses an AI planner to convert a report into a 10-slide presentation deck.
123
- """
124
- logging.info(f"Generating slides for project {project_id}")
125
-
126
- planner_prompt = f"""
127
- You are an expert presentation designer. Your task is to convert the following data analysis report into a concise and visually engaging 10-slide deck.
128
-
129
- **Full Report Content:**
130
- ---
131
- {raw_md}
132
- ---
133
-
134
- **Instructions:**
135
- 1. Read the entire report to understand the core narrative and key findings.
136
- 2. Create a plan for exactly 10 slides.
137
- 3. For each slide, define a `title` and short `content` (2-3 bullet points or a brief paragraph).
138
- 4. For the visual on each slide, you must decide between two types:
139
- - If a report section is supported by an existing chart (indicated by a `<generate_chart:...>` tag), you **must** use it. Set `visual_type: "existing_chart"` and `visual_ref: "the exact chart description from the tag"`.
140
- - For key points without a chart (like introductions, conclusions, or text-only insights), you **must** request a new image. Set `visual_type: "new_image"` and `visual_ref: "a concise, descriptive prompt for an AI to generate a 3D digital art style illustration"`.
141
- 5. You must request exactly 3-4 new images to balance the presentation.
142
-
143
- **Output Format:**
144
- Return ONLY a valid JSON array of 10 slide objects. Do not include any other text or markdown formatting.
145
-
146
- Example:
147
- [
148
- {{ "slide_number": 1, "title": "Introduction", "content": "...", "visual_type": "new_image", "visual_ref": "A 3D illustration of a rising stock chart" }},
149
- {{ "slide_number": 2, "title": "Sales by Region", "content": "...", "visual_type": "existing_chart", "visual_ref": "bar | Sales by Region" }},
150
- ...
151
- ]
152
- """
153
-
154
- try:
155
- plan_response = llm.invoke(planner_prompt).content.strip()
156
- if plan_response.startswith("```json"):
157
- plan_response = plan_response[7:-3]
158
- slide_plan = json.loads(plan_response)
159
- except Exception as e:
160
- logging.error(f"Failed to generate or parse slide plan: {e}")
161
- return None
162
-
163
- final_slides = []
164
- for slide in slide_plan:
165
- try:
166
- image_url = None
167
- visual_type = slide.get("visual_type")
168
- visual_ref = slide.get("visual_ref")
169
-
170
- if visual_type == "existing_chart":
171
- sanitized_ref = sanitize_for_firebase_key(visual_ref)
172
- image_url = chart_urls.get(sanitized_ref)
173
- if not image_url:
174
- logging.warning(f"Could not find existing chart for ref: '{visual_ref}' (sanitized: '{sanitized_ref}')")
175
-
176
- elif visual_type == "new_image":
177
- img = generate_image_with_gemini(visual_ref)
178
- if img:
179
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
180
- img_path = Path(temp_file.name)
181
- img.save(img_path, format="PNG")
182
-
183
- blob_name = f"sozo_projects/{uid}/{project_id}/slides/slide_{uuid.uuid4().hex}.png"
184
- blob = bucket.blob(blob_name)
185
- blob.upload_from_filename(str(img_path))
186
- image_url = blob.public_url
187
- logging.info(f"Uploaded new slide image to {image_url}")
188
- os.unlink(img_path)
189
-
190
- if not image_url:
191
- logging.warning(f"Visual generation failed for slide {slide.get('slide_number')}. Skipping visual for this slide.")
192
-
193
- final_slides.append({
194
- "slide_number": slide.get("slide_number"),
195
- "title": slide.get("title"),
196
- "content": slide.get("content"),
197
- "image_url": image_url or ""
198
- })
199
- except Exception as slide_e:
200
- logging.error(f"Failed to process slide {slide.get('slide_number')}: {slide_e}")
201
- continue
202
-
203
- return final_slides
204
 
205
  # NEW: Keyword extraction for better Pexels searches
206
  def extract_keywords_for_query(text: str, llm) -> str:
@@ -783,4 +685,121 @@ def generate_video_from_project(df: pd.DataFrame, raw_md: str, data_context: Dic
783
  if os.path.exists(p): os.unlink(p)
784
 
785
  return blob.public_url
786
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  def placeholder_img() -> Image.Image: return Image.new("RGB", (WIDTH, HEIGHT), (230, 230, 230))
88
 
89
+
90
+ def detect_dataset_domain(df: pd.DataFrame) -> str:
91
+ """Analyzes column names to detect the dataset's primary domain."""
92
+ domain_keywords = {
93
+ "health insurance": ["charges", "bmi", "smoker", "beneficiary"],
94
+ "finance": ["revenue", "profit", "cost", "budget", "expense", "stock"],
95
+ "marketing": ["campaign", "conversion", "click", "customer", "segment"],
96
+ "survey": ["satisfaction", "rating", "feedback", "opinion", "score"],
97
+ "food": ["nutrition", "calories", "ingredients", "restaurant"]
98
+ }
99
+ columns_lower = [col.lower() for col in df.columns]
100
+ for domain, keywords in domain_keywords.items():
101
+ if any(keyword in col for col in columns_lower for keyword in keywords):
102
+ logging.info(f"Dataset domain detected: {domain}")
103
+ return domain
104
+ logging.info("No specific dataset domain detected, using generic terms.")
105
+ return "data"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
  # NEW: Keyword extraction for better Pexels searches
108
  def extract_keywords_for_query(text: str, llm) -> str:
 
685
  if os.path.exists(p): os.unlink(p)
686
 
687
  return blob.public_url
688
+ return None
689
+
690
+
691
+ # In sozo_gen.py, add these new functions at the end of the file
692
+
693
+ def generate_image_with_gemini(prompt: str) -> Image.Image:
694
+ """Generates an image using the specified Gemini model and client configuration."""
695
+ logging.info(f"Generating Gemini image with prompt: '{prompt}'")
696
+ try:
697
+ # Use the genai.Client as per the correct implementation
698
+ client = genai.Client(api_key=API_KEY)
699
+ full_prompt = f"A professional, 3d digital art style illustration for a business presentation: {prompt}"
700
+
701
+ response = client.models.generate_content(
702
+ model="gemini-2.0-flash-exp",
703
+ contents=full_prompt,
704
+ config=genai_types.GenerateContentConfig(
705
+ response_modalities=["Text", "Image"]
706
+ ),
707
+ )
708
+
709
+ # Find the image part in the response
710
+ img_part = next((part for part in response.candidates[0].content.parts if part.content_type == "Image"), None)
711
+
712
+ if img_part:
713
+ # The content is already bytes, so we can open it directly
714
+ return Image.open(io.BytesIO(img_part.content)).convert("RGB")
715
+ else:
716
+ logging.error("Gemini response did not contain an image.")
717
+ return None
718
+ except Exception as e:
719
+ logging.error(f"Gemini image generation failed: {e}")
720
+ return None
721
+
722
+ def generate_slides_from_report(raw_md: str, chart_urls: dict, uid: str, project_id: str, bucket, llm):
723
+ """
724
+ Uses an AI planner to convert a report into a 10-slide presentation deck.
725
+ """
726
+ logging.info(f"Generating slides for project {project_id}")
727
+
728
+ planner_prompt = f"""
729
+ You are an expert presentation designer. Your task is to convert the following data analysis report into a concise and visually engaging 10-slide deck.
730
+
731
+ **Full Report Content:**
732
+ ---
733
+ {raw_md}
734
+ ---
735
+
736
+ **Instructions:**
737
+ 1. Read the entire report to understand the core narrative and key findings.
738
+ 2. Create a plan for exactly 10 slides.
739
+ 3. For each slide, define a `title` and short `content` (2-3 bullet points or a brief paragraph).
740
+ 4. For the visual on each slide, you must decide between two types:
741
+ - If a report section is supported by an existing chart (indicated by a `<generate_chart:...>` tag), you **must** use it. Set `visual_type: "existing_chart"` and `visual_ref: "the exact chart description from the tag"`.
742
+ - For key points without a chart (like introductions, conclusions, or text-only insights), you **must** request a new image. Set `visual_type: "new_image"` and `visual_ref: "a concise, descriptive prompt for an AI to generate a 3D digital art style illustration"`.
743
+ 5. You must request exactly 3-4 new images to balance the presentation.
744
+
745
+ **Output Format:**
746
+ Return ONLY a valid JSON array of 10 slide objects. Do not include any other text or markdown formatting.
747
+
748
+ Example:
749
+ [
750
+ {{ "slide_number": 1, "title": "Introduction", "content": "...", "visual_type": "new_image", "visual_ref": "A 3D illustration of a rising stock chart" }},
751
+ {{ "slide_number": 2, "title": "Sales by Region", "content": "...", "visual_type": "existing_chart", "visual_ref": "bar | Sales by Region" }},
752
+ ...
753
+ ]
754
+ """
755
+
756
+ try:
757
+ plan_response = llm.invoke(planner_prompt).content.strip()
758
+ if plan_response.startswith("```json"):
759
+ plan_response = plan_response[7:-3]
760
+ slide_plan = json.loads(plan_response)
761
+ except Exception as e:
762
+ logging.error(f"Failed to generate or parse slide plan: {e}")
763
+ return None
764
+
765
+ final_slides = []
766
+ for slide in slide_plan:
767
+ try:
768
+ image_url = None
769
+ visual_type = slide.get("visual_type")
770
+ visual_ref = slide.get("visual_ref")
771
+
772
+ if visual_type == "existing_chart":
773
+ sanitized_ref = sanitize_for_firebase_key(visual_ref)
774
+ image_url = chart_urls.get(sanitized_ref)
775
+ if not image_url:
776
+ logging.warning(f"Could not find existing chart for ref: '{visual_ref}' (sanitized: '{sanitized_ref}')")
777
+
778
+ elif visual_type == "new_image":
779
+ img = generate_image_with_gemini(visual_ref)
780
+ if img:
781
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
782
+ img_path = Path(temp_file.name)
783
+ img.save(img_path, format="PNG")
784
+
785
+ blob_name = f"sozo_projects/{uid}/{project_id}/slides/slide_{uuid.uuid4().hex}.png"
786
+ blob = bucket.blob(blob_name)
787
+ blob.upload_from_filename(str(img_path))
788
+ image_url = blob.public_url
789
+ logging.info(f"Uploaded new slide image to {image_url}")
790
+ os.unlink(img_path)
791
+
792
+ if not image_url:
793
+ logging.warning(f"Visual generation failed for slide {slide.get('slide_number')}. Skipping visual for this slide.")
794
+
795
+ final_slides.append({
796
+ "slide_number": slide.get("slide_number"),
797
+ "title": slide.get("title"),
798
+ "content": slide.get("content"),
799
+ "image_url": image_url or ""
800
+ })
801
+ except Exception as slide_e:
802
+ logging.error(f"Failed to process slide {slide.get('slide_number')}: {slide_e}")
803
+ continue
804
+
805
+ return final_slides