Spaces:
Sleeping
Sleeping
| # app.py | |
| import gradio as gr | |
| import pandas as pd | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline | |
# 1) Read the synthetic profitability dataset (path is relative to the
#    process working directory)
df = pd.read_csv('synthetic_profit.csv')

# 2) Coerce the measure columns to numeric dtypes; with errors='coerce'
#    any cell that cannot be parsed becomes NaN instead of raising
df = df.assign(
    **{
        measure: pd.to_numeric(df[measure], errors='coerce')
        for measure in ("Revenue", "Profit", "ProfitMargin")
    }
)

# 3) Describe the table as one "- <column>: <dtype>" bullet per column
schema_lines = [
    f"- {column}: {kind.name}" for column, kind in zip(df.columns, df.dtypes)
]
schema_text = "Table schema:\n{}".format("\n".join(schema_lines))
| # 4) Few-shot examples teaching SUM and AVERAGE patterns | |
| example_block = """ | |
| Example 1 | |
| Q: Total profit by region? | |
| A: Group “Profit” by “Region” and sum → EMEA: 30172183.37; APAC: 32301788.32; Latin America: 27585378.50; North America: 25473893.34 | |
| Example 2 | |
| Q: Average profit margin for Product B in Americas? | |
| A: Filter Product=B & Region=Americas, take mean of “ProfitMargin” → 0.18 | |
| """.strip() | |
# 5) Model & pipeline setup
MODEL_ID = "microsoft/tapex-base-finetuned-wikisql"

# Device index handed to the pipeline: 0 (first CUDA device) when CUDA is
# available, otherwise -1
if torch.cuda.is_available():
    device = 0
else:
    device = -1

# Tokenizer and seq2seq weights are both resolved from the same checkpoint id
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)

# Table question-answering pipeline on the PyTorch backend
_pipeline_options = {
    "model": model,
    "tokenizer": tokenizer,
    "framework": "pt",
    "device": device,
}
table_qa = pipeline("table-question-answering", **_pipeline_options)
# 6) QA function with schema-aware prompting
def answer_profitability(question: str) -> str:
    """Answer a natural-language question about the profitability table.

    The question is embedded in a prompt made of the table schema text and
    the few-shot example block, then sent together with the table to the
    ``table_qa`` pipeline.

    Args:
        question: Free-text question entered by the user.

    Returns:
        The pipeline's ``"answer"`` value; ``"Please enter a question."``
        when the question is empty or whitespace-only; ``"No answer found."``
        when the pipeline output has no (or an empty) answer; or
        ``"Error: <message>"`` when anything in the pipeline call raises.
        This function does not propagate exceptions to the caller.
    """
    # Guard clause: skip the model call entirely for a blank question
    cleaned = question.strip() if question else ""
    if not cleaned:
        return "Please enter a question."

    # Hand the pipeline a DataFrame directly instead of round-tripping
    # through a list of records on every request; cells are cast to str
    # for safety, as before
    table = df.astype(str)

    # NOTE(review): the prompt prepends schema and few-shot text to the
    # question; confirm this suits the TAPEX checkpoint and that prompt plus
    # flattened table stay within the model's input limit.
    prompt = f"""{schema_text}
{example_block}
Q: {cleaned}
A:"""

    try:
        out = table_qa(table=table, query=prompt)
        answer = out.get("answer")
    except Exception as e:
        # UI boundary: surface the failure in the output box rather than
        # letting the request crash
        return f"Error: {e}"

    # Treat a missing or empty answer the same way
    if not answer:
        return "No answer found."
    return answer
# 7) Gradio interface: a two-line question box in, plain text out
_question_box = gr.Textbox(
    lines=2,
    placeholder="Ask a question about profitability…",
)
_description = (
    "Every query is prefixed with your table’s schema and two few-shot examples, "
    "so the model learns to SUM, AVERAGE, FILTER, etc., without extra Python code."
)
iface = gr.Interface(
    fn=answer_profitability,
    inputs=_question_box,
    outputs="text",
    title="SAP Profitability Q&A (Schema-Aware TAPEX)",
    description=_description,
)

# 8) Launch the app only when this file is executed as a script
if __name__ == "__main__":
    # Listen on all interfaces, port 7860
    iface.launch(server_port=7860, server_name="0.0.0.0")