Vivek16 committed on
Commit
ab96010
·
verified ·
1 Parent(s): 7059f9b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -95
app.py CHANGED
@@ -1,121 +1,68 @@
1
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import re  # used by the answer-extraction cleanup logic

# --- Configuration ---
# Hub repo that holds the fine-tuned LoRA adapters and tokenizer.
ADAPTER_MODEL_ID = "Vivek16/Root_Math-TinyLlama-CPU"
# Base checkpoint the adapters were trained on.
BASE_MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Restrictive system prompt that forces a bare numerical answer.
SYSTEM_PROMPT = "Calculate the final numerical answer to the user's math problem. **Do not show the formula or steps.** Only provide the final numerical result and the units (e.g., '20 cm')."

# --- Model Loading (runs only once when the Space starts) ---

# 1. Tokenizer; fall back to EOS as the padding token when none is defined.
tokenizer = AutoTokenizer.from_pretrained(ADAPTER_MODEL_ID)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# 2. Base model, pinned to CPU in float32 for the CPU Space instance.
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    device_map="cpu",
    torch_dtype=torch.float32,
)

# 3. Attach the LoRA adapters and switch to inference mode.
model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL_ID)
model = model.eval()

print(f"✅ Model loaded successfully from {ADAPTER_MODEL_ID}!")
36
 
37
# ------------------------------------------------------------------
# 💡 EXPLICIT API CALL FUNCTION
# ------------------------------------------------------------------
def call_model_api(prompt, model, tokenizer):
    """
    Run one constrained generation pass ("API call") against the local
    model and return the raw decoded text, special tokens included.
    """
    encoded = tokenizer(prompt, return_tensors="pt")

    # Inference only — no gradient bookkeeping needed.
    with torch.no_grad():
        generated = model.generate(
            **encoded,
            # Strict decoding parameters for short, focused output.
            max_new_tokens=64,
            do_sample=True,
            temperature=0.1,
            top_k=5,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    # Keep special tokens so the caller can strip the prompt markers itself.
    raw_response = tokenizer.decode(generated[0], skip_special_tokens=False)
    return raw_response
 
 
63
 
64
# --- Inference Function (The Main App Logic) ---

def generate_response(message, history):
    """Answer a math question with a bare numerical result.

    Parameters
    ----------
    message : str
        The user's math problem.
    history : list
        Chat history supplied by gr.ChatInterface (unused here).

    Returns
    -------
    str
        The extracted number (with units when detected), or the cleaned raw
        model text as a last resort.
    """
    # 1. Format the input with the TinyLlama chat template.
    chat_template = "<|system|>\n{}</s>\n<|user|>\n{}</s>\n<|assistant|>\n"
    prompt = chat_template.format(SYSTEM_PROMPT, message)

    # 2. Call the model (simulated API call).
    response = call_model_api(prompt, model, tokenizer)

    # ------------------------------------------------------------------
    # CLEANUP LOGIC: aggressively filter the raw API response.
    # ------------------------------------------------------------------

    # Drop everything up to and including the assistant marker.
    assistant_prefix = "<|assistant|>\n"
    if assistant_prefix in response:
        response = response.split(assistant_prefix, 1)[1]

    # Known junk/markup tokens for truncation. Duplicates from the original
    # list removed (re-splitting on the same token is a no-op), as were the
    # "\\tag{1}".."\\tag{4}" variants, which can never match once the text
    # has been truncated at the shorter "\\tag{" prefix.
    junk_tokens = [
        "</s>", "<|user|>", "<|system|>", "\n\n", "User:",
        "\\tag{", "\\end{align*}", "\\begin{align*}", "\\text{",
        "\\frac{", "\\pi",
    ]

    # Truncate at the first sight of any junk token.
    for token in junk_tokens:
        if token in response:
            response = response.split(token, 1)[0]

    # Prefer a number followed by cm-style units (e.g. "20 cm", "40 cm^2").
    match = re.search(r'([\-]?\d+(\.\d+)?) ?([\w\^]*\s*cm\^?2?)', response, re.IGNORECASE)
    if match:
        extracted_answer = f"{match.group(1)} {match.group(3).strip()}"
        return extracted_answer.strip()

    # If no units are found, just extract the first number generated.
    match_number = re.search(r'([\-]?\d+(\.\d+)?)', response)
    if match_number:
        return match_number.group(1).strip()

    # Final fallback: return the cleaned text as-is.
    return response.strip()
111
 
112
# --- Gradio Interface ---

# We use gr.ChatInterface for a standard chatbot layout.
# (The f-prefix on `title` was removed: the string has no placeholders.)
demo = gr.ChatInterface(
    fn=generate_response,
    title="Root Math LLM (TinyLlama LoRA)",
    description="Ask a math problem!",
)
120
 
121
  if __name__ == "__main__":
 
import gradio as gr
import openai
import os  # to read the API key from the environment
# No torch / transformers / peft needed for the external API call.

# --- Configuration ---
# Instruction for the external model: bare numerical answers only.
SYSTEM_PROMPT = "You are a highly accurate math solver. Provide the final numerical answer to the user's problem. Use the required units (e.g., '40 cm^2') and round to two decimal places if needed. Do not show your work, steps, or formulas."

# Initialize the OpenAI client from the OPENAI_API_KEY environment variable.
# `client` is left as None when the key is missing or initialization fails;
# the call function reports a readable error message in that case.
# Fix vs. original: the key is checked explicitly instead of passing a
# possibly-None api_key into openai.OpenAI() and relying on the constructor
# raising inside a broad except.
_api_key = os.environ.get("OPENAI_API_KEY")
if _api_key:
    try:
        client = openai.OpenAI(api_key=_api_key)
    except Exception as e:
        print(f"Error initializing OpenAI client: {e}")
        client = None
else:
    print("Error initializing OpenAI client: OPENAI_API_KEY is not set.")
    client = None
 
19
  # ------------------------------------------------------------------
20
+ # 💡 EXPLICIT EXTERNAL API CALL FUNCTION
21
  # ------------------------------------------------------------------
22
+ def call_external_api(prompt):
23
  """
24
+ Calls the external OpenAI API to get the model's response.
25
  """
26
+ if not client:
27
+ return "Error: API Key not configured. Please set OPENAI_API_KEY environment variable."
28
+
29
+ try:
30
+ # Call the chat completions API
31
+ response = client.chat.completions.create(
32
+ model="gpt-3.5-turbo", # A fast and capable model
33
+ messages=[
34
+ {"role": "system", "content": SYSTEM_PROMPT},
35
+ {"role": "user", "content": prompt}
36
+ ],
37
+ temperature=0.0 # Set to 0.0 for deterministic, accurate math answers
 
 
38
  )
39
 
40
+ # Extract the text content from the response
41
+ return response.choices[0].message.content.strip()
42
+
43
+ except Exception as e:
44
+ return f"API Call Error: Could not get a response from the external model. Details: {e}"
45
 
# --- Inference Function (The Main App Logic) ---

def generate_response(message, history):
    """Chat handler: forward the user's message to the external API."""
    # The system prompt is applied inside call_external_api, so the raw
    # message is all that needs to be sent; the external service handles
    # generation and cleanup. History is unused.
    answer = call_external_api(message)
    return answer
59
  # --- Gradio Interface ---
60
 
61
  # We use gr.ChatInterface for a standard chatbot layout
62
  demo = gr.ChatInterface(
63
  fn=generate_response,
64
+ title=f"Reliable Math LLM (Powered by External API)",
65
+ description="Ask a math problem! This uses a reliable external service for answers.",
66
  )
67
 
68
  if __name__ == "__main__":