Rifqidits committed on
Commit
c8a3495
·
verified ·
1 Parent(s): d80005b

Initial update for our case

Browse files
Files changed (1) hide show
  1. app.py +58 -46
app.py CHANGED
@@ -1,6 +1,4 @@
1
  # chatbot_template.py
2
-
3
- import gradio as gr
4
  import spaces
5
 
6
  DESCRIPTION = """
@@ -16,50 +14,64 @@ LICENSE = """
16
 
17
  # This is a dummy generation function
18
  @spaces.GPU # This allows it to run on GPU Spaces (remove if not needed)
19
- def generate_response(
20
- message: str,
21
- chat_history: list[dict],
22
- system_prompt: str = "",
23
- max_new_tokens: int = 512,
24
- temperature: float = 0.7,
25
- top_p: float = 0.9,
26
- top_k: int = 40,
27
- repetition_penalty: float = 1.1,
28
- ):
29
- # Replace this with actual model logic
30
- yield f"(This is a dummy response to): {message}"
31
-
32
- # Example for real model use:
33
- # tokenizer = AutoTokenizer.from_pretrained("your-model-id")
34
- # model = AutoModelForCausalLM.from_pretrained("your-model-id", device_map="auto")
35
- # input_ids = tokenizer(message, return_tensors="pt").input_ids.to(model.device)
36
- # output = model.generate(input_ids, ...)
37
- # response = tokenizer.decode(output[0], skip_special_tokens=True)
38
- # yield response
39
-
40
- chat_interface = gr.ChatInterface(
41
- fn=generate_response,
42
- additional_inputs=[
43
- gr.Textbox(label="System Prompt", lines=2, value=""),
44
- gr.Slider(label="Max new tokens", minimum=16, maximum=2048, value=512, step=16),
45
- gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, value=0.7, step=0.1),
46
- gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.05),
47
- gr.Slider(label="Top-k", minimum=1, maximum=1000, value=40, step=1),
48
- gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, value=1.1, step=0.05),
49
- ],
50
- examples=[
51
- ["Hello!"],
52
- ["Can you summarize AI in one sentence?"],
53
- ["What is the capital of France?"],
54
- ],
55
- cache_examples=False,
56
- type="messages",
57
- )
58
-
59
- with gr.Blocks(fill_height=True) as demo:
60
- gr.Markdown(DESCRIPTION)
61
- chat_interface.render()
62
- gr.Markdown(LICENSE)
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
 
64
  if __name__ == "__main__":
65
  demo.queue().launch()
 
1
  # chatbot_template.py
 
 
2
  import spaces
3
 
4
  DESCRIPTION = """
 
14
 
15
  # This is a dummy generation function
16
  @spaces.GPU # This allows it to run on GPU Spaces (remove if not needed)
17
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# === [1] Model and Tokenizer Loading ===
base_model_id = "meta-llama/Llama-2-7b-hf"  # Replace with your base model
lora_path = "./tat-llm"  # Path to your fine-tuned LoRA folder

# Attach the fine-tuned LoRA adapter on top of the (frozen) fp16 base model.
_backbone = AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype=torch.float16)
model = PeftModel.from_pretrained(_backbone, lora_path)
# Inference only: eval mode, and move weights to the GPU.
# NOTE(review): `.cuda()` is hard-coded — this will fail on CPU-only hosts; confirm the Space always has a GPU.
model = model.eval().cuda()

# The adapter folder is also expected to ship the tokenizer files.
tokenizer = AutoTokenizer.from_pretrained(lora_path)
33
+
34
# === [2] Prompt Formatting Function ===
def create_prompt(table, context, question):
    """Build the instruction prompt sent to the model.

    The prompt is a fixed financial-assistant header followed by the
    `Table:`, `Context:` and `Question:` sections (each separated by a
    blank line) and a trailing `Answer:` cue for the model to complete.
    """
    header = "You are a financial assistant. Given the table and context, answer the question."
    sections = [
        header,
        f"Table:\n{table}",
        f"Context:\n{context}",
        f"Question:\n{question}",
        "Answer:",
    ]
    return "\n\n".join(sections)
48
+
49
# === [3] Inference Function ===
def answer_question(table, context, question):
    """Answer one table-and-text question with greedy decoding.

    Formats the inputs via `create_prompt`, tokenizes onto the GPU, and
    generates up to 128 new tokens deterministically (`do_sample=False`).
    Returns the full decoded text (prompt included), stripped of
    surrounding whitespace and special tokens.
    """
    prompt = create_prompt(table, context, question)
    encoded = tokenizer(prompt, return_tensors="pt").to("cuda")
    generation_kwargs = dict(
        max_new_tokens=128,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
    )
    # No gradients needed at inference time.
    with torch.no_grad():
        generated = model.generate(**encoded, **generation_kwargs)
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    return decoded.strip()
61
+
62
# === [4] Gradio UI Layout ===
# NOTE(review): indentation was lost in extraction; the grouping of the four
# textboxes under gr.Row() below is assumed — confirm against the live app.
with gr.Blocks(title="TAT-LLM Table & Text QA") as demo:
    gr.Markdown("## TAT-LLM: Table-and-Text Question Answering\nUpload a table (Markdown format), provide context, and ask your question.")

    with gr.Row():
        table_input = gr.Textbox(label="Table (Markdown)", lines=10, placeholder="| Quarter | Revenue |\n|--------|---------|\n| Q1 | 100 | ...")
        context_input = gr.Textbox(label="Context", lines=10, placeholder="PT ABC mengalami peningkatan pendapatan dari Q1 ke Q4.")
        question_input = gr.Textbox(label="Question", lines=2, placeholder="Berapa persentase kenaikan dari Q1 ke Q4?")
        output_box = gr.Textbox(label="Answer", lines=5)

    submit_btn = gr.Button("Generate Answer")
    submit_btn.click(
        fn=answer_question,
        inputs=[table_input, context_input, question_input],
        outputs=output_box,
    )

# === [5] Launch ===
if __name__ == "__main__":
    demo.queue().launch()