diabolic6045 committed on
Commit
e22a631
·
verified ·
1 Parent(s): 7c5d7b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -132
app.py CHANGED
@@ -20,139 +20,44 @@ import spaces
20
  logging.basicConfig(level=logging.INFO)
21
  logger = logging.getLogger(__name__)
22
 
23
- class SanskritTranscriptionModel:
24
- def __init__(self, model_path: str, adapter_path: str = None):
25
- """Initialize the model and processor"""
26
- self.model_path = model_path
27
- self.adapter_path = adapter_path
28
- self.model = None
29
- self.processor = None
30
- self.is_loaded = False
31
-
32
- def load_model(self):
33
- """Load the model and processor"""
34
- if self.is_loaded:
35
- return
36
-
37
- try:
38
- logger.info("Loading processor...")
39
- self.processor = AutoProcessor.from_pretrained(self.model_path)
40
-
41
- logger.info("Loading base model...")
42
- # Check if CUDA is available, otherwise use CPU
43
- device_map = "auto" if torch.cuda.is_available() else "cpu"
44
- self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
45
- self.model_path,
46
- torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
47
- device_map=device_map
48
- )
49
-
50
- if self.adapter_path and os.path.exists(self.adapter_path):
51
- logger.info("Loading LoRA adapters...")
52
- self.model = PeftModel.from_pretrained(self.model, self.adapter_path)
53
- else:
54
- logger.info("No adapter path found, using base model only")
55
-
56
- self.model.eval()
57
- device = next(self.model.parameters()).device
58
- logger.info(f"Model loaded on device: {device}")
59
- self.is_loaded = True
60
-
61
- except Exception as e:
62
- logger.error(f"Error loading model: {e}")
63
- raise e
64
-
65
- def transcribe_image(self, image: Image.Image, prompt: str = None) -> str:
66
- """Transcribe Sanskrit text from image"""
67
- if not self.is_loaded:
68
- self.load_model()
69
-
70
- if prompt is None:
71
- prompt = "Please transcribe the Sanskrit text shown in this image:"
72
-
73
- try:
74
- messages = [
75
- {
76
- "role": "user",
77
- "content": [
78
- {"type": "image", "image": image},
79
- {"type": "text", "text": prompt}
80
- ]
81
- }
82
- ]
83
-
84
- # Preparation for inference
85
- text = self.processor.apply_chat_template(
86
- messages, tokenize=False, add_generation_prompt=True
87
- )
88
- image_inputs, video_inputs = process_vision_info(messages)
89
- inputs = self.processor(
90
- text=[text],
91
- images=image_inputs,
92
- videos=video_inputs,
93
- padding=True,
94
- return_tensors="pt",
95
- )
96
-
97
- # Get model device and move inputs there
98
- model_device = next(self.model.parameters()).device
99
- inputs = {k: v.to(model_device) for k, v in inputs.items()}
100
-
101
- with torch.no_grad():
102
- generated_ids = self.model.generate(
103
- **inputs,
104
- max_new_tokens=512,
105
- do_sample=False,
106
- pad_token_id=self.processor.tokenizer.eos_token_id,
107
- use_cache=True,
108
- repetition_penalty=1.1
109
- )
110
-
111
- # Extract only the generated part
112
- generated_ids_trimmed = [
113
- out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs['input_ids'], generated_ids)
114
- ]
115
- output_text = self.processor.batch_decode(
116
- generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
117
- )
118
-
119
- return output_text[0] if output_text else ""
120
-
121
- except Exception as e:
122
- logger.error(f"Error generating response: {e}")
123
- return f"Error: {str(e)}"
124
 
125
- # Initialize the model
126
- model_instance = None
 
127
 
128
- @spaces.GPU(duration=60) # 60 seconds for model loading
129
- def initialize_model():
130
- """Initialize the model instance with ZeroGPU support"""
131
- global model_instance
132
- if model_instance is None:
133
- try:
134
- model_path = 'Qwen/Qwen2.5-VL-7B-Instruct'
135
- adapter_path = './outputs/out-qwen2-5-vl'
136
- model_instance = SanskritTranscriptionModel(model_path, adapter_path)
137
- # Load the model immediately during initialization
138
- model_instance.load_model()
139
- return "✅ Model loaded and ready"
140
- except Exception as e:
141
- logger.error(f"Error initializing model: {e}")
142
- return f"❌ Model initialization failed: {str(e)}"
143
- return "✅ Model already loaded and ready"
 
 
 
 
 
144
 
145
  def check_model_status():
146
  """Check if model is loaded and ready"""
147
  try:
148
- global model_instance
149
- if model_instance is not None and model_instance.is_loaded:
150
  return "✅ Model loaded and ready"
151
  else:
152
  return "⏳ Model not loaded yet"
153
  except Exception as e:
154
  return f"❌ Model error: {str(e)}"
155
 
 
156
  def transcribe_sanskrit(image, custom_prompt, progress=gr.Progress()):
157
  """Gradio interface function for transcription using pre-loaded model"""
158
  if image is None:
@@ -160,19 +65,59 @@ def transcribe_sanskrit(image, custom_prompt, progress=gr.Progress()):
160
 
161
  try:
162
  progress(0.1, desc="Processing image...")
163
- # Use the pre-loaded model instance
164
- global model_instance
165
- if model_instance is None or not model_instance.is_loaded:
166
- return "❌ Model not loaded. Please wait for the model to initialize or refresh the page."
167
 
168
  # Use custom prompt if provided, otherwise use default
169
  prompt = custom_prompt if custom_prompt.strip() else "Please transcribe the Sanskrit text shown in this image:"
170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  progress(0.5, desc="Generating transcription...")
172
- result = model_instance.transcribe_image(image, prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
  progress(1.0, desc="Complete!")
175
- return result
176
 
177
  except Exception as e:
178
  logger.error(f"Error in transcribe_sanskrit: {e}")
@@ -253,9 +198,6 @@ def create_gradio_interface():
253
  - High accuracy transcription
254
  """)
255
 
256
- # Example section
257
- with gr.Row():
258
- gr.Markdown("### Example Images")
259
 
260
  # Event handlers
261
  transcribe_btn.click(
@@ -279,9 +221,9 @@ def create_gradio_interface():
279
  outputs=model_status
280
  )
281
 
282
- # Initialize model and check status on app load
283
  app.load(
284
- fn=initialize_model,
285
  outputs=model_status
286
  )
287
 
 
20
  logging.basicConfig(level=logging.INFO)
21
  logger = logging.getLogger(__name__)
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
+ # Load model at module level (global scope)
25
+ model_path = 'Qwen/Qwen2.5-VL-7B-Instruct'
26
+ adapter_path = './outputs/out-qwen2-5-vl'
27
 
28
+ logger.info("Loading processor...")
29
+ processor = AutoProcessor.from_pretrained(model_path)
30
+
31
+ logger.info("Loading base model...")
32
+ # Check if CUDA is available, otherwise use CPU
33
+ device_map = "auto" if torch.cuda.is_available() else "cpu"
34
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
35
+ model_path,
36
+ torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
37
+ device_map=device_map
38
+ )
39
+
40
+ if adapter_path and os.path.exists(adapter_path):
41
+ logger.info("Loading LoRA adapters...")
42
+ model = PeftModel.from_pretrained(model, adapter_path)
43
+ else:
44
+ logger.info("No adapter path found, using base model only")
45
+
46
+ model.eval()
47
+ device = next(model.parameters()).device
48
+ logger.info(f"Model loaded on device: {device}")
49
 
50
  def check_model_status():
51
  """Check if model is loaded and ready"""
52
  try:
53
+ if model is not None and processor is not None:
 
54
  return "✅ Model loaded and ready"
55
  else:
56
  return "⏳ Model not loaded yet"
57
  except Exception as e:
58
  return f"❌ Model error: {str(e)}"
59
 
60
+ @spaces.GPU
61
  def transcribe_sanskrit(image, custom_prompt, progress=gr.Progress()):
62
  """Gradio interface function for transcription using pre-loaded model"""
63
  if image is None:
 
65
 
66
  try:
67
  progress(0.1, desc="Processing image...")
 
 
 
 
68
 
69
  # Use custom prompt if provided, otherwise use default
70
  prompt = custom_prompt if custom_prompt.strip() else "Please transcribe the Sanskrit text shown in this image:"
71
 
72
+ # Format the conversation using chat template
73
+ messages = [
74
+ {
75
+ "role": "user",
76
+ "content": [
77
+ {"type": "image", "image": image},
78
+ {"type": "text", "text": prompt}
79
+ ]
80
+ }
81
+ ]
82
+
83
+ # Preparation for inference
84
+ text = processor.apply_chat_template(
85
+ messages, tokenize=False, add_generation_prompt=True
86
+ )
87
+ image_inputs, video_inputs = process_vision_info(messages)
88
+ inputs = processor(
89
+ text=[text],
90
+ images=image_inputs,
91
+ videos=video_inputs,
92
+ padding=True,
93
+ return_tensors="pt",
94
+ )
95
+
96
+ # Get model device and move inputs there
97
+ model_device = next(model.parameters()).device
98
+ inputs = {k: v.to(model_device) for k, v in inputs.items()}
99
+
100
  progress(0.5, desc="Generating transcription...")
101
+ with torch.no_grad():
102
+ generated_ids = model.generate(
103
+ **inputs,
104
+ max_new_tokens=512,
105
+ do_sample=False,
106
+ pad_token_id=processor.tokenizer.eos_token_id,
107
+ use_cache=True,
108
+ repetition_penalty=1.1
109
+ )
110
+
111
+ # Extract only the generated part
112
+ generated_ids_trimmed = [
113
+ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs['input_ids'], generated_ids)
114
+ ]
115
+ output_text = processor.batch_decode(
116
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
117
+ )
118
 
119
  progress(1.0, desc="Complete!")
120
+ return output_text[0] if output_text else ""
121
 
122
  except Exception as e:
123
  logger.error(f"Error in transcribe_sanskrit: {e}")
 
198
  - High accuracy transcription
199
  """)
200
 
 
 
 
201
 
202
  # Event handlers
203
  transcribe_btn.click(
 
221
  outputs=model_status
222
  )
223
 
224
+ # Check model status on app load
225
  app.load(
226
+ fn=check_model_status,
227
  outputs=model_status
228
  )
229