lingadevaruhp committed
Commit 7d955d5 (verified) · Parent: 8625cb1

Update app.py

Files changed (1): app.py (+14 −39)
app.py CHANGED
@@ -1,42 +1,37 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
+from peft import PeftModel
 import torch
 import gradio as gr
 import json
 import os
 
-# Load tokenizer and model - using thoshan_Flash model
-model_name = "microsoft/Phi-3-mini-4k-instruct"  # Will be replaced with actual thoshan_Flash model when available
+# --- Change only these two lines if you update your base or adapter! ---
+base_model_name = "unsloth/gemma-2-9b-it-bnb-4bit"
+lora_adapter_path = "lingadevaruhp/thoshan_Flash"
+# ----------------------------------------------------------------------
 
 try:
-    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-
-    # Load base model directly (no LoRA adapters)
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name,
+    tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
+    base_model = AutoModelForCausalLM.from_pretrained(
+        base_model_name,
         torch_dtype=torch.bfloat16,
         device_map="auto",
         low_cpu_mem_usage=True,
         trust_remote_code=True,
-        attn_implementation="eager",  # Fix for compatibility issues
-        cache_dir=None  # Disable cache to avoid compatibility issues
+        attn_implementation="eager"
     )
+    model = PeftModel.from_pretrained(base_model, lora_adapter_path)
 except Exception as e:
     print(f"Error loading model: {e}")
     tokenizer = None
     model = None
 
-# Load dataset for context
 def load_dataset():
-    # Try multiple possible dataset files
     dataset_files = ["2000-data-set.txt", "flirt_dataset.jsonl"]
-
     for dataset_file in dataset_files:
         if os.path.exists(dataset_file):
             print(f"Found dataset file: {dataset_file}")
-
-            # Handle different file formats
             if dataset_file.endswith('.jsonl'):
-                # Handle JSONL format
                 dataset_entries = []
                 try:
                     with open(dataset_file, 'r', encoding='utf-8') as f:
@@ -51,17 +46,12 @@ def load_dataset():
                     print(f"Error reading JSONL file {dataset_file}: {e}")
                     continue
             else:
-                # Handle plain text format - create sample entries
                 try:
                     with open(dataset_file, 'r', encoding='utf-8') as f:
                         content = f.read().strip()
-
-                    # Skip if content looks like HTML (as in the current file)
                     if content.startswith('<!DOCTYPE html>') or '<html>' in content:
                         print(f"Skipping HTML file: {dataset_file}")
                         continue
-
-                    # Create sample conversation entries from text
                     sample_entries = [
                         {"input": "Hello", "output": "Hi there! How are you doing today?"},
                         {"input": "How are you?", "output": "I'm doing great! Thanks for asking. What can I help you with?"},
@@ -71,30 +61,23 @@ def load_dataset():
                 except Exception as e:
                     print(f"Error reading text file {dataset_file}: {e}")
                     continue
-
     print("No valid dataset file found, using default responses")
-    # Return default entries if no file found
     return [
         {"input": "Hello", "output": "Hi there! How are you doing today?"},
         {"input": "How are you?", "output": "I'm doing great! Thanks for asking. What can I help you with?"},
         {"input": "Tell me about yourself", "output": "I'm thoshan_Flash, an AI assistant created to help and chat with you. I'm friendly and always happy to help!"}
     ]
 
-# Load the dataset content
 dataset_content = load_dataset()
 print(f"Loaded {len(dataset_content)} dataset entries")
 
 def generate_response(prompt, max_new_tokens=100):
-    # Check if model is available
     if model is None or tokenizer is None:
         return "Error: Model failed to load. Please check the logs and try restarting the space."
-
     try:
-        # Add dataset context to the prompt for better responses
         context = ""
         if dataset_content:
-            # Use first few entries as context
-            context_entries = dataset_content[:3]  # Use first 3 entries
+            context_entries = dataset_content[:3]
             context_text = ""
             for entry in context_entries:
                 if 'input' in entry and 'output' in entry:
@@ -102,12 +85,8 @@ def generate_response(prompt, max_new_tokens=100):
                 elif 'text' in entry:
                     context_text += f"{entry['text']}\n\n"
             context = f"Dataset context:\n{context_text}\n" if context_text else ""
-
-        # Format the prompt for thoshan_Flash
         formatted_prompt = f"<|user|>\n{context}{prompt}<|end|>\n<|assistant|>\n"
-
-        inputs = tokenizer(formatted_prompt, return_tensors="pt")
-
+        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
         with torch.no_grad():
             outputs = model.generate(
                 **inputs,
@@ -116,17 +95,13 @@ def generate_response(prompt, max_new_tokens=100):
                 temperature=0.7,
                 top_p=0.9,
                 pad_token_id=tokenizer.eos_token_id,
-                use_cache=False  # Disable caching to avoid compatibility issues
+                use_cache=False
             )
-
-        # Decode only the generated part (excluding the input)
        generated_text = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
         return generated_text.strip()
-
     except Exception as e:
         return f"Error generating response: {str(e)}"
 
-# Gradio interface
 iface = gr.Interface(
     fn=generate_response,
     inputs=[
@@ -139,4 +114,4 @@ iface = gr.Interface(
 )
 
 if __name__ == "__main__":
-    iface.launch()
+    iface.launch()
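
For quick verification outside the Space, here is a minimal sketch of the loading path this commit introduces: the Gemma-2 base checkpoint plus the thoshan_Flash LoRA adapter attached through peft. The repo IDs come from the diff above; everything else is assumed, in particular a CUDA GPU and an installed bitsandbytes, since the unsloth bnb-4bit checkpoint carries its own 4-bit quantization config.

# smoke_test.py - check that the adapter attaches and the model generates.
# Assumes: transformers, peft, accelerate, bitsandbytes installed; CUDA GPU available.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch

base_model_name = "unsloth/gemma-2-9b-it-bnb-4bit"   # from the diff above
lora_adapter_path = "lingadevaruhp/thoshan_Flash"    # from the diff above

tokenizer = AutoTokenizer.from_pretrained(base_model_name)
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    torch_dtype=torch.bfloat16,  # dtype for the non-quantized modules only
    device_map="auto",
)
# Wrap the frozen base with the trained LoRA weights.
model = PeftModel.from_pretrained(base_model, lora_adapter_path)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

If adapter hot-swapping is not needed, model.merge_and_unload() can fold the LoRA weights into the base for slightly faster inference, though merging into a 4-bit-quantized base is lossy and not supported by every peft version, so benchmark before relying on it.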
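One thing the commit does not touch: generate_response still hard-codes Phi-3-style turn markers (<|user|>, <|end|>, <|assistant|>), while Gemma-2 instruct checkpoints are trained on a different turn format (<start_of_turn>/<end_of_turn>). A possible follow-up, assuming the tokenizer ships Gemma's chat template, is to let apply_chat_template build the prompt instead:

# Inside generate_response: let the tokenizer's own chat template produce the
# turn markers rather than hard-coding Phi-3's <|user|>/<|end|> tokens.
messages = [{"role": "user", "content": f"{context}{prompt}"}]
formatted_prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,              # return a plain string
    add_generation_prompt=True,  # append the model-turn opener
)
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

This keeps the rest of the generation call unchanged and avoids feeding Gemma special tokens it was never trained on.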