Spaces:
Sleeping
Sleeping
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| from data_processor import DataProcessor | |
| from chart_generator import ChartGenerator | |
| from image_verifier import ImageVerifier | |
| from huggingface_hub import login | |
| import logging | |
| import time | |
| import os | |
| from dotenv import load_dotenv | |
| import ast | |
| import requests | |
| import json | |
# Load variables from a .env file into the process environment at import
# time (python-dotenv). Presumably this is how HUGGINGFACEHUB_API_TOKEN,
# read via os.getenv in this module, is supplied in development — confirm.
load_dotenv()
class LLM_Agent:
    """Turn natural-language plotting requests into verified charts.

    A query is translated into a plot-argument dict (keys ``x``, ``y``,
    ``chart_type`` and optionally ``color``). The translation is done either
    by a fine-tuned BART model run locally or by a Flan-T5 model behind the
    Hugging Face Inference API. The dict is handed to ``ChartGenerator``, and
    the resulting chart is checked by ``ImageVerifier``.
    """

    # Hugging Face Hub path of the fine-tuned query model.
    _MODEL_PATH = "ArchCoder/fine-tuned-bart-large"

    # Remote model choice -> Inference API endpoint. Any other choice
    # (including the default 'bart') uses the local fine-tuned model.
    # NOTE(review): 'flan-ul2' is actually served by flan-t5-xxl; kept as-is.
    _HF_API_URLS = {
        'flan-t5-base': "https://api-inference.huggingface.co/models/google/flan-t5-base",
        'flan-ul2': "https://api-inference.huggingface.co/models/google/flan-t5-xxl",
    }

    # Raw model text substituted when the remote API yields nothing usable.
    _DEFAULT_RESPONSE_TEXT = "{'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}"

    def __init__(self, data_path=None):
        """Load the data, the chart/verification helpers and the local model.

        Args:
            data_path: Optional path handed to ``DataProcessor``.
        """
        logging.info("Initializing LLM_Agent")
        self.data_processor = DataProcessor(data_path)
        self.chart_generator = ChartGenerator(self.data_processor.data)
        self.image_verifier = ImageVerifier()
        # The fine-tuned model is pulled from the Hugging Face Hub by path.
        self.query_tokenizer = AutoTokenizer.from_pretrained(self._MODEL_PATH)
        self.query_model = AutoModelForSeq2SeqLM.from_pretrained(self._MODEL_PATH)

    @staticmethod
    def validate_plot_args(plot_args):
        """Check that *plot_args* is a usable plot specification.

        Args:
            plot_args: Value parsed from the model output; any type.

        Returns:
            True if *plot_args* is a dict containing ``x``, ``y`` and
            ``chart_type``. A non-list ``y`` is wrapped in a list in place.
            False otherwise, including for non-dict values such as the
            strings or lists ``ast.literal_eval`` can produce.
        """
        # A str would otherwise pass the ``in`` checks by substring match.
        if not isinstance(plot_args, dict):
            return False
        if not all(key in plot_args for key in ('x', 'y', 'chart_type')):
            return False
        if not isinstance(plot_args['y'], list):
            plot_args['y'] = [plot_args['y']]
        return True

    @staticmethod
    def _default_plot_args():
        """Return a fresh copy of the fallback plot specification."""
        # Fresh dict each call so callers may mutate it safely.
        return {'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}

    @staticmethod
    def _build_prompt(query):
        """Return the few-shot prompt sent to the remote Flan-T5 models."""
        return (
            "You are VizBot, an expert data visualization assistant. "
            "Given a user's natural language request about plotting data, output ONLY a valid Python dictionary with keys: x, y, chart_type, and color (if specified). "
            "Do not include any explanation or extra text.\n\n"
            "Example 1:\n"
            "User: plot the sales in the years with red line\n"
            "Output: {'x': 'Year', 'y': ['Sales'], 'chart_type': 'line', 'color': 'red'}\n\n"
            "Example 2:\n"
            "User: show employee expenses and net profit over the years\n"
            "Output: {'x': 'Year', 'y': ['Employee expense', 'Net profit'], 'chart_type': 'line'}\n\n"
            "Example 3:\n"
            "User: display the EBITDA for each year with a blue bar\n"
            "Output: {'x': 'Year', 'y': ['EBITDA'], 'chart_type': 'bar', 'color': 'blue'}\n\n"
            f"User: {query}\nOutput:"
        )

    @staticmethod
    def _strip_code_fence(response_text):
        """Trim whitespace and an optional surrounding ``` / ```python fence."""
        response_text = response_text.strip()
        if response_text.startswith("```") and response_text.endswith("```"):
            response_text = response_text[3:-3].strip()
            if response_text.startswith("python"):
                response_text = response_text[6:].strip()
        return response_text

    def _load_data(self, data_path):
        """Rebuild the data processor and chart generator from *data_path*."""
        logging.info("Data path received: %s", data_path)
        if os.path.exists(data_path):
            logging.info("File exists at path: %s", data_path)
        else:
            logging.error("File does not exist at path: %s", data_path)
        # DataProcessor is constructed even when the path is missing, as in
        # the original; presumably it handles that case itself — confirm.
        self.data_processor = DataProcessor(data_path)
        logging.info("Loaded columns from data: %s", self.data_processor.get_columns())
        self.chart_generator = ChartGenerator(self.data_processor.data)

    def _generate_local(self, query):
        """Run the local fine-tuned model on the raw query (no few-shot prompt)."""
        inputs = self.query_tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
        outputs = self.query_model.generate(**inputs, max_length=100, num_return_sequences=1)
        return self.query_tokenizer.decode(outputs[0], skip_special_tokens=True)

    def _query_hf_api(self, api_url, prompt):
        """POST *prompt* to a Hugging Face Inference API model.

        Returns:
            The ``generated_text`` of the response. Returns the default
            response text if the request fails, the status is not 200, or
            the body is unparseable or empty.
        """
        token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
        # Omit the header entirely rather than sending "Bearer None".
        headers = {"Authorization": f"Bearer {token}"} if token else {}
        try:
            response = requests.post(api_url, headers=headers, json={"inputs": prompt}, timeout=30)
        except requests.RequestException as e:
            # Timeouts/connection errors get the same fallback as HTTP errors.
            logging.error("Hugging Face API request failed: %s", e)
            return self._DEFAULT_RESPONSE_TEXT
        if response.status_code != 200:
            logging.error("Hugging Face API error: %s %s", response.status_code, response.text)
            return self._DEFAULT_RESPONSE_TEXT
        try:
            resp_json = response.json()
            if isinstance(resp_json, list):
                text = resp_json[0]['generated_text']
            else:
                text = resp_json.get('generated_text', '')
        except (ValueError, KeyError, IndexError, TypeError, AttributeError) as e:
            # Invalid JSON, or a payload shape other than the expected ones.
            logging.error("Error parsing Hugging Face API response: %s, raw: %s", e, response.text)
            return self._DEFAULT_RESPONSE_TEXT
        return text or self._DEFAULT_RESPONSE_TEXT

    def _parse_plot_args(self, response_text):
        """Parse model text into validated plot args, falling back to defaults."""
        try:
            # literal_eval only evaluates literals, so untrusted model output is safe here.
            plot_args = ast.literal_eval(response_text)
        except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError) as e:
            # TypeError covers e.g. "{[1]: 2}" (unhashable key).
            logging.warning("Invalid LLM response: %s. Response: %s", e, response_text)
            return self._default_plot_args()
        if not self.validate_plot_args(plot_args):
            logging.warning("Invalid plot arguments. Using default.")
            return self._default_plot_args()
        return plot_args

    def process_request(self, data):
        """Turn one plotting request into a chart.

        Args:
            data: Request mapping. It reads these keys:
                ``query`` (default ``''``);
                ``file_path``, an optional data file loaded before plotting;
                ``model`` (default ``'bart'``). ``'flan-t5-base'`` and
                ``'flan-ul2'`` use the Inference API; any other value uses
                the local model.

        Returns:
            A dict with these keys:
            ``response``: the cleaned model text, or ``"Error: ..."``.
            ``chart_path``: ``""`` on error.
            ``verified``: ``False`` on error.
        """
        start_time = time.perf_counter()
        logging.info("Processing request data: %s", data)
        query = data.get('query', '')
        data_path = data.get('file_path')
        model_choice = data.get('model', 'bart')
        try:
            if data_path:
                self._load_data(data_path)
            api_url = self._HF_API_URLS.get(model_choice)
            if api_url is None:
                response_text = self._generate_local(query)
            else:
                response_text = self._query_hf_api(api_url, self._build_prompt(query))
            logging.info("LLM response text: %s", response_text)
            response_text = self._strip_code_fence(response_text)
            plot_args = self._parse_plot_args(response_text)
            chart_path = self.chart_generator.generate_chart(plot_args)
            verified = self.image_verifier.verify(chart_path, query)
            return {
                "response": response_text,
                "chart_path": chart_path,
                "verified": verified,
            }
        except Exception as e:
            # Request boundary: report any failure to the caller as an error dict.
            logging.exception("Error processing request: %s", e)
            return {
                "response": f"Error: {str(e)}",
                "chart_path": "",
                "verified": False,
            }
        finally:
            logging.info("Processed request in %s seconds", time.perf_counter() - start_time)