Pruthvi369i committed on
Commit
83ee74c
·
verified ·
1 Parent(s): a3b02f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -17
app.py CHANGED
@@ -27,30 +27,45 @@ def generate_response(image_file, prompt, max_new_tokens=512, temperature=0.7, t
27
  if image_file is not None:
28
  image = Image.open(image_file).convert('RGB')
29
 
30
- # Process inputs
31
- inputs = processor(
32
  text=prompt,
33
  images=image,
34
  return_tensors="pt"
35
  ).to(model.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  else:
37
  # Text-only input
38
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
39
-
40
- # Generate response
41
- with torch.no_grad():
42
- outputs = model.generate(
43
- **inputs,
44
- max_new_tokens=max_new_tokens,
45
- temperature=temperature,
46
- top_p=top_p,
47
- do_sample=True
48
- )
49
-
50
- # Decode and return the response
51
- if image_file is not None:
52
- response = processor.decode(outputs[0], skip_special_tokens=True)
53
- else:
54
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
55
 
56
  # Remove the input prompt from the response
 
27
  if image_file is not None:
28
  image = Image.open(image_file).convert('RGB')
29
 
30
+ # Process inputs with processor to get the right format
31
+ processed_inputs = processor(
32
  text=prompt,
33
  images=image,
34
  return_tensors="pt"
35
  ).to(model.device)
36
+
37
+ # Extract only the input_ids for generation
38
+ input_ids = processed_inputs.pop("input_ids")
39
+
40
+ # Generate response
41
+ with torch.no_grad():
42
+ outputs = model.generate(
43
+ input_ids=input_ids,
44
+ attention_mask=processed_inputs.get("attention_mask", None),
45
+ max_new_tokens=max_new_tokens,
46
+ temperature=temperature,
47
+ top_p=top_p,
48
+ do_sample=True
49
+ )
50
+
51
+ # Decode and return the response
52
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
53
+
54
  else:
55
  # Text-only input
56
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
57
+
58
+ # Generate response
59
+ with torch.no_grad():
60
+ outputs = model.generate(
61
+ **inputs,
62
+ max_new_tokens=max_new_tokens,
63
+ temperature=temperature,
64
+ top_p=top_p,
65
+ do_sample=True
66
+ )
67
+
68
+ # Decode and return the response
 
 
 
69
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
70
 
71
  # Remove the input prompt from the response