Shim committed on
Commit 77244ea · 1 Parent(s): 7aff121

Replace static responses with FLAN-T5 Hebrew-capable AI model for real conversations

Files changed (1)
  1. app.py +102 -53
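
The change swaps the static/demo responses for a real model call: on Spaces the app loads google/flan-t5-small and serves it through a transformers text2text-generation pipeline. As a minimal, illustrative sketch of that approach (the prompt text and generation settings below are placeholders, not the app's exact code):

# Sketch only: a small seq2seq model served via the text2text-generation task,
# mirroring the approach taken in the diff below; prompt and settings are illustrative.
from transformers import pipeline

generator = pipeline(
    "text2text-generation",        # task used for T5-family (seq2seq) models
    model="google/flan-t5-small",  # the lightweight checkpoint this commit targets
)

outputs = generator("Answer briefly: what is Gradio?", max_new_tokens=50)
print(outputs[0]["generated_text"])  # the pipeline returns only the generated reply
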
app.py CHANGED
@@ -6,7 +6,7 @@ Main application file with Gradio interface
 
 import gradio as gr
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, pipeline
 import logging
 import sys
 from typing import List, Tuple, Optional
@@ -42,52 +42,82 @@ class MirautrApp:
         is_hf_spaces = os.getenv("SPACE_ID") is not None
 
         if is_hf_spaces:
-            logger.info("Running in Hugging Face Spaces - using lightweight model")
-            # Use a smaller, more suitable model for HF Spaces
-            model_name = "microsoft/DialoGPT-medium" # Fallback to English model that works
-            logger.info("Using English model due to HF Spaces limitations")
+            logger.info("Running in Hugging Face Spaces - using lightweight Hebrew-capable model")
+            # Use a small multilingual model that supports Hebrew and fits in HF Spaces
+            model_name = "google/flan-t5-small" # 77M parameters, supports Hebrew
+            logger.info(f"Loading lightweight model: {model_name}")
 
         else:
-            model_name = "yam-peleg/Hebrew-Mistral-7B"
-            logger.info(f"Loading Hebrew model: {model_name}")
-
-            # For HF Spaces, skip heavy model loading and use API-based approach or lighter model
-            if is_hf_spaces:
-                logger.info("Skipping heavy model loading for HF Spaces - using demo mode")
-                self.setup_fallback_model()
-                return
+            # For local development, try Hebrew-specific model first
+            try:
+                model_name = "yam-peleg/Hebrew-Mistral-7B"
+                logger.info(f"Loading Hebrew model: {model_name}")
+            except:
+                # Fallback to small model for local testing too
+                model_name = "google/flan-t5-small"
+                logger.info(f"Falling back to small model: {model_name}")
 
         # Load tokenizer
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
 
         # Determine the best settings for the environment
-        if torch.cuda.is_available():
+        if torch.cuda.is_available() and not is_hf_spaces:
             torch_dtype = torch.float16
             device_map = "auto"
         else:
+            # Use CPU-friendly settings for HF Spaces
             torch_dtype = torch.float32
             device_map = None
 
         # Load model with appropriate settings
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            torch_dtype=torch_dtype,
-            device_map=device_map,
-            low_cpu_mem_usage=True,
-            trust_remote_code=True # For HF Spaces compatibility
-        )
+        if "t5" in model_name.lower():
+            # Use Seq2Seq model for T5
+            self.model = AutoModelForSeq2SeqLM.from_pretrained(
+                model_name,
+                torch_dtype=torch_dtype,
+                low_cpu_mem_usage=True
+            )
+        elif "mistral" in model_name.lower():
+            # Use CausalLM for Mistral with additional settings
+            self.model = AutoModelForCausalLM.from_pretrained(
+                model_name,
+                torch_dtype=torch_dtype,
+                device_map=device_map,
+                low_cpu_mem_usage=True,
+                trust_remote_code=True
+            )
+        else:
+            # Default to CausalLM for other models
+            self.model = AutoModelForCausalLM.from_pretrained(
+                model_name,
+                torch_dtype=torch_dtype,
+                low_cpu_mem_usage=True
+            )
 
-        # Create text generation pipeline
-        self.generator = pipeline(
-            "text-generation",
-            model=self.model,
-            tokenizer=self.tokenizer,
-            max_new_tokens=150,
-            temperature=0.7,
-            do_sample=True,
-            pad_token_id=self.tokenizer.eos_token_id,
-            return_full_text=False # Only return generated text
-        )
+        # Create text generation pipeline with appropriate settings
+        generation_kwargs = {
+            "max_new_tokens": 100,
+            "temperature": 0.8,
+            "do_sample": True,
+            "pad_token_id": self.tokenizer.eos_token_id,
+            "return_full_text": False
+        }
+
+        # For T5 models, use text2text-generation
+        if "t5" in model_name.lower():
+            self.generator = pipeline(
+                "text2text-generation",
+                model=self.model,
+                tokenizer=self.tokenizer,
+                **generation_kwargs
+            )
+        else:
+            self.generator = pipeline(
+                "text-generation",
+                model=self.model,
+                tokenizer=self.tokenizer,
+                **generation_kwargs
+            )
 
         logger.info("Model loaded successfully")
 
@@ -129,24 +159,43 @@ class MirautrApp:
         # Prepare conversation context
         context = self.conversation_manager.get_conversation_context(conversation_state)
 
-        # Create the full prompt
-        full_prompt = f"{system_prompt}\n\nהקשר: {context}\n\nהמשתמש אמר: {user_message}\n\nתגובה:"
-
+        # Try to generate with model first
+        response = None
         if self.generator:
-            # Generate with the model
-            outputs = self.generator(
-                full_prompt,
-                max_new_tokens=150,
-                temperature=0.7,
-                do_sample=True,
-                num_return_sequences=1
-            )
-
-            response = outputs[0]["generated_text"]
-            # Extract only the new generated part
-            response = response[len(full_prompt):].strip()
-
-        else:
+            try:
+                # Check if using T5 model (text2text-generation)
+                if hasattr(self.generator, 'task') and self.generator.task == 'text2text-generation':
+                    # For T5 models, create a more structured prompt
+                    part_description = DEFAULT_PARTS.get(conversation_state.selected_part, {}).get("description", conversation_state.selected_part)
+                    persona_name = conversation_state.persona_name or DEFAULT_PARTS.get(conversation_state.selected_part, {}).get("default_persona_name", "חלק פנימי")
+
+                    prompt = f"אתה {persona_name}, {part_description}. ענה בעברית על ההודעה הבאה בהתאם לאופי שלך: {user_message}"
+
+                    outputs = self.generator(prompt, max_length=150, num_return_sequences=1)
+                    response = outputs[0]["generated_text"].strip()
+
+                    # Clean up the response if it repeats the prompt
+                    if prompt in response:
+                        response = response.replace(prompt, "").strip()
+
+                else:
+                    # For causal LM models
+                    full_prompt = f"{system_prompt}\n\nהקשר: {context}\n\nהמשתמש אמר: {user_message}\n\nתגובה:"
+                    outputs = self.generator(full_prompt)
+                    response = outputs[0]["generated_text"]
+                    # Extract only the new generated part
+                    response = response[len(full_prompt):].strip()
+
+                # Basic validation and cleanup
+                if not response or len(response.strip()) < 5:
+                    response = None
+
+            except Exception as gen_error:
+                logger.warning(f"Model generation failed: {gen_error}, falling back to contextual response")
+                response = None
+
+        # If model generation failed or no model available, use fallback
+        if not response:
             # Fallback response for demo mode
             part_info = DEFAULT_PARTS.get(conversation_state.selected_part, {})
             persona_name = conversation_state.persona_name or part_info.get("default_persona_name", "חלק פנימי")
@@ -225,10 +274,10 @@ class MirautrApp:
         # Header
         is_hf_spaces = os.getenv("SPACE_ID") is not None
         demo_notice = """
-        <div style="background-color: #fff3cd; border: 1px solid #ffeaa7; padding: 10px; margin: 10px 0; border-radius: 5px; text-align: center;">
-        <strong>🧪 מצב הדגמה</strong><br/>
-        זהו מצב הדגמה - התגובות מבוססות על דוגמאות מוכנות מראש.<br/>
-        הגרסה המלאה תכלול מודל בינה מלאכותית מתקדם לתגובות אישיות יותר.
+        <div style="background-color: #d4edda; border: 1px solid #c3e6cb; padding: 10px; margin: 10px 0; border-radius: 5px; text-align: center;">
+        <strong>🤖 גרסה קלה</strong><br/>
+        משתמש במודל בינה מלאכותית קל התומך בעברית (FLAN-T5) המותאם לסביבת Hugging Face Spaces.<br/>
+        הגרסה המקומית משתמשת במודל עברי מתקדם יותר.
         </div>
         """ if is_hf_spaces else ""
 
283