import torch
import gradio as gr
import pandas as pd
import numpy as np
import json
import os
import uuid
import spaces
import re
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel, AutoPeftModelForCausalLM
import io
from openpyxl import load_workbook
from typing import List, Dict, Any, Tuple
from utils import *
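# The helpers used below come from utils.py, which is not shown in this file.
# Their assumed signatures, inferred from the call sites:
#   xlsx_to_json(path: str) -> dict                   # parse tables + paragraphs from XLSX
#   json_to_jsonl(data: dict) -> str                  # flatten the dict to JSON Lines
#   json_to_markdown(data: dict) -> str               # render a Markdown preview
#   create_prompt(data: dict, question: str) -> str   # build the model prompt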
# Module-level model loading, kept for reference; the app instead loads the
# model inside generate_answer (see below).
# base_model_id = "NousResearch/Nous-Hermes-2-Mistral-7B-DPO"
# lora_path = "tat-llm-final-e4"
# base_model = AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype=torch.float16)
# model = PeftModel.from_pretrained(base_model, lora_path)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = model.to(device)
# model.eval()
# tokenizer = AutoTokenizer.from_pretrained(lora_path)
# On ZeroGPU Spaces (suggested by the otherwise-unused `spaces` import), CUDA
# is only available inside functions decorated with @spaces.GPU.
@spaces.GPU
def generate_answer(json_data: Dict[str, Any], question: str) -> str:
    """
    Generate an answer with the fine-tuned model (base model + LoRA adapter).
    The model is loaded per call so GPU access stays inside this decorated function.
    """
    base_model_id = "NousResearch/Nous-Hermes-2-Mistral-7B-DPO"
    lora_path = "tat-llm-final-e4"

    # Load base model and attach the LoRA adapter
    base_model = AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype=torch.float16)
    model = PeftModel.from_pretrained(base_model, lora_path)
    tokenizer = AutoTokenizer.from_pretrained(lora_path)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    prompt = create_prompt(json_data, question)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    input_length = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=False,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, not the prompt
    generated_tokens = outputs[0][input_length:]
    answer = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    return answer
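# Reloading the base model and adapter on every request is slow. The
# commented-out sketch below caches them across calls; it is an assumed
# optimization, not part of the original app (`_MODEL_CACHE` and
# `load_model_cached` are hypothetical names):
#
# _MODEL_CACHE = {}
#
# def load_model_cached():
#     if not _MODEL_CACHE:
#         base = AutoModelForCausalLM.from_pretrained(
#             "NousResearch/Nous-Hermes-2-Mistral-7B-DPO", torch_dtype=torch.float16
#         )
#         _MODEL_CACHE["model"] = PeftModel.from_pretrained(base, "tat-llm-final-e4").eval()
#         _MODEL_CACHE["tokenizer"] = AutoTokenizer.from_pretrained("tat-llm-final-e4")
#     return _MODEL_CACHE["model"], _MODEL_CACHE["tokenizer"]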
# Gradio interface functions
def process_xlsx(file):
    """
    Process an uploaded XLSX file and return its JSON, JSONL, and Markdown forms.
    """
    if file is None:
        return None, "", "", ""
    try:
        # gr.File(type="filepath") passes the path as a plain string,
        # so use it directly rather than accessing a .name attribute
        json_data = xlsx_to_json(file)
        json_str = json.dumps(json_data, indent=2, ensure_ascii=False)
        jsonl_str = json_to_jsonl(json_data)
        markdown_str = json_to_markdown(json_data)
        return json_data, json_str, jsonl_str, markdown_str
    except Exception as e:
        return None, f"Error: {str(e)}", "", ""
def chat_interface(json_data, question, history):
    """
    Chat interface for Q&A.
    """
    if json_data is None:
        return history + [[question, "Please upload an XLSX file first."]]
    if not question.strip():
        return history + [[question, "Please enter a question."]]
    try:
        answer = generate_answer(json_data, question)
        return history + [[question, answer]]
    except Exception as e:
        return history + [[question, f"Error generating answer: {str(e)}"]]
# Gradio UI
with gr.Blocks(title="terTATa-LLM: Dari Tabel dan Teks Menjadi Langkah Bisnis Strategis", theme=gr.themes.Soft()) as demo:
    gr.HTML("""
    <style>
        body, .gradio-container {
            font-family: 'Poppins', sans-serif;
        }
        h1, h2, h3, h4, h5 {
            font-family: 'Poppins', sans-serif;
        }
    </style>
    <link href="https://fonts.googleapis.com/css2?family=Poppins&display=swap" rel="stylesheet">
    """)
| gr.Markdown(""" | |
| # terTATa-LLM: Dari Tabel dan Teks Menjadi Langkah Bisnis Strategis | |
| Unggah berkas XLSX berisi tabel dan paragraf, lalu ajukan pertanyaan tentang data tersebut. | |
| Sistem akan mengonversi berkas Anda ke format JSON dan menggunakan model terTATa-LLM untuk menjawab pertanyaan. | |
| """) | |
    json_data_state = gr.State()

    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(
                label="Upload XLSX File",
                file_types=[".xlsx"],
                type="filepath"
            )
            process_btn = gr.Button("Process File", variant="primary")

            with gr.Tabs():
                with gr.Tab("Markdown Preview"):
                    markdown_output = gr.Markdown(label="Markdown Preview")
                with gr.Tab("JSON Output"):
                    json_output = gr.Code(
                        label="JSON Format",
                        language="json",
                        lines=15
                    )
                with gr.Tab("JSONL Output"):
                    jsonl_output = gr.Code(
                        label="JSONL Format",
                        language="json",
                        lines=5
                    )

        with gr.Column(scale=1):
| gr.Markdown("### Ajukan Pertanyaan Mengenai Data Anda") | |
| chatbot = gr.Chatbot(height=400) | |
| msg = gr.Textbox( | |
| label="Prompt", | |
| placeholder="Ajukan pertanyaan tentang data tabel...", | |
| lines=2 | |
| ) | |
| with gr.Row(): | |
| submit_btn = gr.Button("Submit", variant="primary") | |
| clear_btn = gr.Button("Clear Chat") | |
            gr.Examples(
                examples=[
                    "What insights can we draw from this data?",
                    "How did the values change from year to year?",
                    "What are the main trends visible in the data?",
                    "Calculate the percentage change between years.",
                    "What recommendations can be made based on this data?"
                ],
                inputs=msg
            )
    # Wire up events
    process_btn.click(
        fn=process_xlsx,
        inputs=[file_input],
        outputs=[json_data_state, json_output, jsonl_output, markdown_output]
    )

    msg.submit(
        fn=chat_interface,
        inputs=[json_data_state, msg, chatbot],
        outputs=[chatbot]
    ).then(
        lambda: "",  # clear the textbox after sending
        outputs=[msg]
    )

    submit_btn.click(
        fn=chat_interface,
        inputs=[json_data_state, msg, chatbot],
        outputs=[chatbot]
    ).then(
        lambda: "",  # clear the textbox after sending
        outputs=[msg]
    )

    clear_btn.click(
        lambda: [],  # reset the chat history
        outputs=[chatbot]
    )
if __name__ == "__main__":
    demo.queue().launch(share=True)