Spaces:
Sleeping
Sleeping
import json
import os
import tempfile
import time

from fastapi import FastAPI, File, Form, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import PlainTextResponse

# Import your custom PyPI library
from graphvision import GraphExtractor
app = FastAPI(title="STEM Sight Backend")

# CORS: any origin is accepted so that browser extensions can call the API.
# Credentials are disabled because a wildcard origin must not be combined with
# credentialed requests (the FastAPI CORS docs require explicit origins when
# allow_credentials=True), and the handlers visible in this file read neither
# cookies nor auth headers.
# NOTE(review): if a client really sends credentials, list its origins
# explicitly instead of re-enabling credentials with "*".
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows any browser extension to connect
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize your custom PyPI library once at import time; the request
# handlers share this single instance.
print("Initializing STEM Sight Vision Engine...")
vision_engine = GraphExtractor()
# Without a route decorator FastAPI never exposes the handler, so register it
# on the application root.
@app.get("/")
async def root():
    """Return a static JSON status message (simple liveness check)."""
    return {"message": "STEM Sight API is online and ready."}
def _to_number(raw):
    """Return *raw* as a float usable for ranking, or ``None`` if it is not numeric."""
    if isinstance(raw, str):
        # Tolerate formatted labels such as "1,200" or "25%".
        raw = raw.strip().rstrip("%").replace(",", "")
    try:
        number = float(raw)
    except (TypeError, ValueError, OverflowError):
        return None
    # NaN compares false against everything and would corrupt max()/min().
    return None if number != number else number


def _rankable(items, value_key):
    """Keep only the dict entries of *items* whose *value_key* field is numeric."""
    if not isinstance(items, list):
        return []
    return [
        item
        for item in items
        if isinstance(item, dict) and _to_number(item.get(value_key)) is not None
    ]


def _summarize_pie(extraction_result, title_text):
    """Describe the largest and smallest slice of a pie chart."""
    data = extraction_result.get("data")
    slices = {}
    if isinstance(data, dict):
        # Slices without a numeric value cannot be ranked, so they are ignored.
        slices = {
            name: value
            for name, value in data.items()
            if _to_number(value) is not None
        }
    if not slices:
        return f"This is a pie chart {title_text}, but no data could be extracted."

    max_cat = max(slices, key=lambda name: _to_number(slices[name]))
    min_cat = min(slices, key=lambda name: _to_number(slices[name]))
    return (
        f"This is a pie chart {title_text}. "
        f"The largest portion is {max_cat} at {slices[max_cat]}. "
        f"The smallest portion is {min_cat} at {slices[min_cat]}."
    )


def _summarize_bar(extraction_result, title_text):
    """Describe the highest, lowest and remaining bars of a bar chart."""
    bars = _rankable(extraction_result.get("data"), "value")
    # `or` (rather than a .get default) also covers labels that are present but null/empty.
    x_label = extraction_result.get("x_axis_label") or "the X axis"
    y_label = extraction_result.get("y_axis_label") or "the Y axis"
    if not bars:
        return f"This is a bar chart {title_text}, but no data could be extracted."

    max_item = max(bars, key=lambda bar: _to_number(bar.get("value")))
    min_item = min(bars, key=lambda bar: _to_number(bar.get("value")))
    summary = (
        f"This is a bar chart {title_text}, showing {y_label} against {x_label}. "
        f"The highest value is {max_item.get('category')} at {max_item.get('value')}. "
        f"The lowest value is {min_item.get('category')} at {min_item.get('value')}. "
    )

    # Filter out the max and min items so we don't repeat them
    other_items = [bar for bar in bars if bar != max_item and bar != min_item]
    if other_items:
        # Join the remaining items with a comma so the text-to-speech engine adds a slight pause
        other_points_text = ", ".join(
            f"{bar.get('category')} at {bar.get('value')}" for bar in other_items
        )
        summary += f"The other values are: {other_points_text}."
    return summary


def _summarize_dot_line(extraction_result, title_text):
    """Describe a dot/line chart with a per-category highest and lowest Y value."""
    data = extraction_result.get("data")
    points = _rankable(data, "y")
    x_label = extraction_result.get("x_axis_label") or "the X axis"
    y_label = extraction_result.get("y_axis_label") or "the Y axis"
    if not points:
        return f"This is a line chart {title_text}, but no data could be extracted."

    total_points = extraction_result.get("total_points")
    if total_points is None:
        total_points = len(data)

    # Group data points by their category (class). Names are stringified so
    # numeric class labels can be joined into the spoken sentence.
    categories = {}
    for point in points:
        label = point.get("class")
        cat_name = "unknown" if label is None else str(label)
        categories.setdefault(cat_name, []).append(point)

    # Format the classes cleanly for the introductory sentence
    classes = list(categories)
    classes_text = ", ".join(classes[:3])
    if len(classes) > 3:
        classes_text += f", and {len(classes) - 3} other categories"

    # Introductory overview
    summary = (
        f"This is a scatter plot {title_text}, with {x_label} on the X axis and {y_label} on the Y axis. "
        f"It shows {total_points} data points across categories like {classes_text}. "
    )

    # Calculate and append max/min for each category
    category_summaries = []
    for cat_name, cat_points in categories.items():
        max_item = max(cat_points, key=lambda point: _to_number(point.get("y")))
        min_item = min(cat_points, key=lambda point: _to_number(point.get("y")))
        # Using a predictable sentence structure that includes the X coordinate
        category_summaries.append(
            f"For {cat_name}, the highest value is {max_item.get('y')} when X is {max_item.get('x')}, "
            f"and the lowest value is {min_item.get('y')} when X is {min_item.get('x')}."
        )

    # Join the category breakdowns with spaces so they read as separate sentences
    summary += " ".join(category_summaries)
    return summary


def generate_audio_summary(extraction_result: dict) -> str:
    """Build a spoken-style summary of extracted chart data without an external LLM.

    Args:
        extraction_result: Parsed extraction output. The keys read here are
            ``chart_type``, ``title``, ``data``, ``x_axis_label``,
            ``y_axis_label``, ``total_points`` and ``summary``.

    Returns:
        A plain-text sentence (or several) describing the chart. Unsupported
        chart types and charts without usable data yield an explanatory
        message instead of raising.

    Entries whose value is missing or non-numeric are skipped when ranking, so
    a partially extracted chart still produces a summary.
    """
    # A null chart_type is treated the same as a missing one.
    chart_type = str(extraction_result.get("chart_type") or "unknown").lower()
    title = extraction_result.get("title")
    title_text = f"titled {title}" if title else "without a specific title"

    # --- 1. PIE CHART LOGIC ---
    if chart_type == "pie":
        return _summarize_pie(extraction_result, title_text)

    # --- 2. BAR CHART LOGIC (HBAR & VBAR) ---
    if chart_type in ("hbar_categorical", "vbar_categorical", "hbar", "vbar"):
        return _summarize_bar(extraction_result, title_text)

    # --- 3. DOT / LINE CHART LOGIC ---
    if chart_type == "dot_line":
        return _summarize_dot_line(extraction_result, title_text)

    if chart_type == "line":
        # For this chart type the extraction result already carries a ready-made text summary.
        summary = extraction_result.get("summary", "")
        if not summary:
            return f"This is a line chart {title_text}, but the Vision Engine could not generate a summary."
        return summary

    # --- FALLBACK ---
    return f"Data has been extracted for a {chart_type} chart, but the summary feature for this specific format is not available."
# Registered as a plain-text POST route: the handler returns a bare string,
# and without a route decorator FastAPI never exposes it.
@app.post("/analyze-graph", response_class=PlainTextResponse)
async def analyze_graph(file: UploadFile = File(...)):
    """Extract chart data from an uploaded image and return a spoken-style summary.

    Args:
        file: The uploaded chart image.

    Returns:
        The summary text, or an apology / error sentence when extraction
        reports an error or any step raises. Failures are returned as text
        rather than raised so the caller always receives something to read out.
    """
    temp_image_path = None
    try:
        start_time = time.time()

        # 1. Save the uploaded image temporarily.
        # The client-supplied filename is untrusted (it may be missing, contain
        # path separators, or collide with a concurrent upload), so only its
        # extension is reused and tempfile picks a unique path.
        suffix = os.path.splitext(os.path.basename(file.filename or ""))[1]
        with tempfile.NamedTemporaryFile(
            prefix="temp_", suffix=suffix, delete=False
        ) as buffer:
            temp_image_path = buffer.name
            buffer.write(await file.read())
        print(f"⏱️ Image received and saved in: {time.time() - start_time:.2f} seconds")

        # 2. Extract structured data
        # NOTE(review): extract() is called synchronously inside an async
        # handler, so the event loop waits on it; consider a worker thread if
        # extraction is slow.
        extract_start = time.time()
        print(f"Extracting data from {file.filename}...")
        extraction_json_string = vision_engine.extract(temp_image_path)
        print(f"⏱️ AI Extraction finished in: {time.time() - extract_start:.2f} seconds")

        extraction_result = json.loads(extraction_json_string)
        print(f"Extracted data: {extraction_result}")
        if "error" in extraction_result:
            return f"I am sorry, I could not clearly identify the data in this graph. Reason: {extraction_result['error']}"

        # 3. Generate summary using hardcoded logic instead of Groq
        audio_script = generate_audio_summary(extraction_result)
        print(f"✅ TOTAL TIME: {time.time() - start_time:.2f} seconds")
        return audio_script
    except Exception as e:
        # Top-level boundary: report the failure as speakable text.
        return f"An error occurred while analyzing the graph: {str(e)}"
    finally:
        # Always remove the temporary image, including when extraction raises.
        if temp_image_path and os.path.exists(temp_image_path):
            os.remove(temp_image_path)
_QA_NO_DATA_MESSAGE = "I couldn't extract data from this chart. Please ensure the image is clear."
_QA_FALLBACK_MESSAGE = "I am sorry, I do not understand the question. Please ask for the highest value, the lowest value, or ask about a specific category."


def _numeric_or_none(raw):
    """Return *raw* as a float usable for ranking, or ``None`` if it is not numeric."""
    if isinstance(raw, str):
        # Tolerate formatted labels such as "1,200" or "25%".
        raw = raw.strip().rstrip("%").replace(",", "")
    try:
        number = float(raw)
    except (TypeError, ValueError, OverflowError):
        return None
    # NaN compares false against everything and would corrupt max()/min().
    return None if number != number else number


def _item_category(item):
    """Return the category of a list-style data point ("category", else "class")."""
    return item.get("category", item.get("class"))


def _item_value(item):
    """Return the value of a list-style data point ("value", else "y")."""
    return item.get("value", item.get("y"))


def _describe_extreme(filtered_data, target_category, pick, label):
    """Answer a highest/lowest question; *pick* is ``max`` or ``min``, *label* the spoken word."""
    if isinstance(filtered_data, dict):  # Pie Charts
        ranked = {
            name: value
            for name, value in filtered_data.items()
            if _numeric_or_none(value) is not None
        }
        if not ranked:
            return _QA_NO_DATA_MESSAGE
        name = pick(ranked, key=lambda key: _numeric_or_none(ranked[key]))
        return f"Based on the extracted data, the {label} is {name} with a value of {ranked[name]}."

    if isinstance(filtered_data, list):  # Bar/Line/Scatter
        # Points without a numeric value cannot be ranked and are skipped.
        ranked = [
            item for item in filtered_data
            if _numeric_or_none(_item_value(item)) is not None
        ]
        if not ranked:
            return _QA_NO_DATA_MESSAGE
        chosen = pick(ranked, key=lambda item: _numeric_or_none(_item_value(item)))
        cat = chosen.get("category", chosen.get("class", "unknown"))
        val = _item_value(chosen)
        if not target_category:
            return f"Based on the extracted data, the overall {label} is {cat} with a value of {val}."
        # If they asked for a specific category in a scatter plot, include the X coordinate
        x_val = chosen.get("x")
        if x_val is not None:
            return f"For the {target_category} category, the {label} value is {val} when X is {x_val}."
        return f"For the {target_category} category, the {label} value is {val}."

    return None


def _describe_category(data, filtered_data, target_category):
    """Answer a plain "what is the value of <category>" question."""
    if isinstance(data, dict):  # Pie Charts
        return f"Based on the extracted data, the value for {target_category} is {data[target_category]}."
    if isinstance(filtered_data, list):  # Bar charts
        if len(filtered_data) == 1:
            val = _item_value(filtered_data[0])
            return f"Based on the extracted data, the value for {target_category} is {val}."
        # If there are multiple values (like a line chart), tell them to be more specific
        return f"The category {target_category} has {len(filtered_data)} different data points. Please ask for the highest or lowest value for this category."
    return None


def _answer_chart_question(extraction_result, question):
    """Answer *question* about parsed chart data using keyword rules only."""
    chart_type = str(extraction_result.get("chart_type") or "unknown").lower()
    if chart_type == "line":
        summary = extraction_result.get("summary")
        if summary:
            return summary
        return "I couldn't extract a summary from this line chart."

    data = extraction_result.get("data")
    if isinstance(data, list):
        # Malformed (non-dict) entries cannot be queried.
        data = [item for item in data if isinstance(item, dict)]
    if not data:
        return _QA_NO_DATA_MESSAGE

    # 2. Pre-process Data and Question
    question_lower = question.lower()

    # Dynamically find all categories available in this specific chart
    available_categories = []
    if isinstance(data, dict):  # It's a Pie Chart
        available_categories = list(data)
    elif isinstance(data, list):  # It's a Bar, Line, or Scatter plot
        for item in data:
            cat = _item_category(item)
            if cat and cat not in available_categories:
                available_categories.append(cat)

    # Check if the user is asking about a SPECIFIC category. Categories are
    # stringified before matching (they may be numbers), and the longest match
    # wins so that "Carpet" is not mistaken for "Car".
    matches = [
        cat for cat in available_categories
        if cat and str(cat).lower() in question_lower
    ]
    target_category = max(matches, key=lambda cat: len(str(cat)), default=None)

    # If they asked for a specific category in a list-based chart, filter the data!
    filtered_data = data
    if target_category and isinstance(data, list):
        filtered_data = [item for item in data if _item_category(item) == target_category]

    # 3. Rule-Based Intent Routing
    answer = None
    # Intent 1: Asking for the highest/maximum
    if any(word in question_lower for word in ["highest", "maximum", "most", "largest", "top"]):
        answer = _describe_extreme(filtered_data, target_category, max, "highest")
    # Intent 2: Asking for the lowest/minimum
    elif any(word in question_lower for word in ["lowest", "minimum", "least", "smallest", "bottom"]):
        answer = _describe_extreme(filtered_data, target_category, min, "lowest")
    # Intent 3: Asking for a specific category's value (General Lookup)
    elif target_category:
        answer = _describe_category(data, filtered_data, target_category)

    # Intent 4: Fallback
    return _QA_FALLBACK_MESSAGE if answer is None else answer


# NOTE(review): no route decorator is visible on this handler, so FastAPI does
# not expose it as written. Register it (for example with an @app.post(...)
# using response_class=PlainTextResponse) once the intended path is confirmed.
async def ask_chart_rule_based(
    file: UploadFile = File(...),
    question: str = Form(...)
):
    """Answer a natural-language question about an uploaded chart image.

    Args:
        file: The uploaded chart image.
        question: The user's question; matched against keyword lists
            (highest/lowest) and the category names found in the chart.

    Returns:
        A plain-text answer, or a fallback sentence when the question or the
        extracted data cannot be handled.

    Raises:
        Whatever ``vision_engine.extract`` or ``json.loads`` raise; unlike the
        summary path, errors are not converted to text here.
    """
    # 1. Extract JSON using GraphVision
    # The client-supplied filename is untrusted (it may be missing, contain
    # path separators, or collide with a concurrent upload), so only its
    # extension is reused and tempfile picks a unique path.
    suffix = os.path.splitext(os.path.basename(file.filename or ""))[1]
    temp_image_path = None
    try:
        with tempfile.NamedTemporaryFile(
            prefix="temp_qa_", suffix=suffix, delete=False
        ) as buffer:
            temp_image_path = buffer.name
            buffer.write(await file.read())
        extraction_json_string = vision_engine.extract(temp_image_path)
    finally:
        # Always remove the temporary image, including when extraction raises.
        if temp_image_path and os.path.exists(temp_image_path):
            os.remove(temp_image_path)

    extraction_result = json.loads(extraction_json_string)
    return _answer_chart_question(extraction_result, question)
| # import os | |
| # import json | |
| # from fastapi import FastAPI, File, UploadFile | |
| # from fastapi.responses import PlainTextResponse | |
| # from fastapi.middleware.cors import CORSMiddleware | |
| # from groq import Groq | |
| # import time | |
| # # Import your newly updated PyPI library! | |
| # from graphvision import GraphExtractor | |
| # app = FastAPI(title="STEM Sight Backend") | |
| # app.add_middleware( | |
| # CORSMiddleware, | |
| # allow_origins=["*"], # Allows any browser extension to connect | |
| # allow_credentials=True, | |
| # allow_methods=["*"], | |
| # allow_headers=["*"], | |
| # ) | |
| # # Initialize the Groq Client (Looks for the GROQ_API_KEY environment variable) | |
| # groq_client = Groq() | |
| # # Initialize your custom PyPI library | |
| # print("Initializing STEM Sight Vision Engine...") | |
| # vision_engine = GraphExtractor() | |
| # @app.get("/") | |
| # async def root(): | |
| # return {"message": "STEM Sight API is online and ready."} | |
| # @app.post("/analyze-graph", response_class=PlainTextResponse) | |
| # async def analyze_graph(file: UploadFile = File(...)): | |
| # try: | |
| # start_time = time.time() | |
| # # 1. Save the uploaded image temporarily | |
| # temp_image_path = f"temp_{file.filename}" | |
| # with open(temp_image_path, "wb") as buffer: | |
| # buffer.write(await file.read()) | |
| # print(f"⏱️ Image received and saved in: {time.time() - start_time:.2f} seconds") | |
| # # 2. Extract structured data | |
| # extract_start = time.time() | |
| # print(f"Extracting data from {file.filename}...") | |
| # extraction_json_string = vision_engine.extract(temp_image_path) | |
| # print(f"⏱️ AI Extraction finished in: {time.time() - extract_start:.2f} seconds") | |
| # if os.path.exists(temp_image_path): | |
| # os.remove(temp_image_path) | |
| # extraction_result = json.loads(extraction_json_string) | |
| # print(f"Extracted data: {extraction_result}") | |
| # if "error" in extraction_result: | |
| # return f"I'm sorry, I couldn't clearly identify the data in this graph. Reason: {extraction_result['error']}" | |
| # graph_type = extraction_result.get("chart_type", "unknown") | |
| # graph_data = extraction_result.get("data", []) | |
| # x_label = extraction_result.get("x_axis_label", "Unknown X-Axis") | |
| # y_label = extraction_result.get("y_axis_label", "Unknown Y-Axis") | |
| # title = extraction_result.get("title", "Untitled Graph") | |
| # prompt = f""" | |
| # You are an accessibility assistant for visually impaired students. | |
| # I am giving you extracted data from a {graph_type} chart. | |
| # Title: {title} | |
| # X-Axis Label: {x_label} | |
| # Y-Axis Label: {y_label} | |
| # Please summarize this data in one short, conversational, and easy-to-understand paragraph. | |
| # Point out the largest and smallest values if relevant. | |
| # Do not use markdown, bold text, or asterisks. Write it exactly as it should be spoken out loud by a text-to-speech engine. | |
| # Data: | |
| # {graph_data} | |
| # """ | |
| # # 3. Send to Groq | |
| # groq_start = time.time() | |
| # print("Generating audio script with Groq Llama 3...") | |
| # chat_completion = groq_client.chat.completions.create( | |
| # messages=[{"role": "user", "content": prompt}], | |
| # model="llama-3.1-8b-instant", | |
| # temperature=0.4, | |
| # ) | |
| # print(f"⏱️ Groq Llama 3 finished in: {time.time() - groq_start:.2f} seconds") | |
| # print(f"✅ TOTAL TIME: {time.time() - start_time:.2f} seconds") | |
| # return chat_completion.choices[0].message.content.strip() | |
| # except Exception as e: | |
| # return f"An error occurred while analyzing the graph: {str(e)}" |