riabayonaor committed
Commit 6d9da6e (verified)
1 Parent(s): 55866da

Rename amo.py to inference.py

Files changed (2)
  1. amo.py +0 -44
  2. inference.py +151 -0
amo.py DELETED
@@ -1,44 +0,0 @@
-import gradio as gr
-import random
-# Use a pipeline as a high-level helper
-from transformers import pipeline
-
-pipe = pipeline("text-generation", model="facebook/xglm-7.5B")
-# Initialize the pipeline with your model
-
-def generate_problem(topic):
-    # Implement the logic to generate the problem based on the selected topic
-    problem_prompt = f"Genera un problema de matemáticas sobre {topic}."
-    problem_output = pipe(problem_prompt)[0]['generated_text']
-    # Here you should separate the problem from the solution; this is a simplified example
-    problem, solution = problem_output.split('La solución es')
-    return problem.strip(), solution.strip()
-
-def generate_fake_answers(real_solution, num_fakes=3):
-    # This function will generate fake answers; this is just a placeholder
-    fake_answers = [str(int(real_solution) + i) for i in range(1, num_fakes + 1)]
-    return fake_answers
-
-def math_problem_solver(topic):
-    problem, solution = generate_problem(topic)
-    correct_answer = solution
-    fake_answers = generate_fake_answers(solution)
-    all_answers = fake_answers + [correct_answer]
-    random.shuffle(all_answers)  # Shuffle the answers
-    return problem, all_answers, correct_answer
-
-def evaluate_answer(user_answer, correct_answer):
-    if user_answer == correct_answer:
-        return "¡Correcto! Felicidades."
-    else:
-        return f"Incorrecto. La respuesta correcta es: {correct_answer}"
-
-# Define the Gradio interface
-iface = gr.Interface(
-    fn=math_problem_solver,
-    inputs=gr.Dropdown(choices=["Problemas de Preálgebra", "Problemas de Funciones"], label="Selecciona el tema"),
-    outputs=[gr.Textbox(label="Problema"), gr.Radio(label="Opciones de respuesta"), gr.Textbox(label="Respuesta correcta")],
-    examples=[["Problemas de Preálgebra"], ["Problemas de Funciones"]],
-)
-
-iface.launch()

inference.py ADDED
@@ -0,0 +1,151 @@
+from threading import Thread
+from typing import List
+
+import torch
+import transformers
+from transformers import (
+    AutoModelForCausalLM,
+    StoppingCriteria,
+    StoppingCriteriaList,
+    TextIteratorStreamer,
+)
+
+from deepseek_vl.models import MultiModalityCausalLM, VLChatProcessor
+from deepseek_vl.utils.conversation import Conversation
+
+
+def load_model(model_path):
+    vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
+    tokenizer = vl_chat_processor.tokenizer
+    vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
+        model_path, trust_remote_code=True
+    )
+    vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
+    return tokenizer, vl_gpt, vl_chat_processor
+
+
+def convert_conversation_to_prompts(conversation: Conversation):
+    prompts = []
+    messages = conversation.messages
+
+    for i in range(0, len(messages), 2):
+        prompt = {
+            "role": messages[i][0],
+            "content": (
+                messages[i][1][0]
+                if isinstance(messages[i][1], tuple)
+                else messages[i][1]
+            ),
+            "images": [messages[i][1][1]] if isinstance(messages[i][1], tuple) else [],
+        }
+        response = {"role": messages[i + 1][0], "content": messages[i + 1][1]}
+        prompts.extend([prompt, response])
+
+    return prompts
+
+
+class StoppingCriteriaSub(StoppingCriteria):
+    def __init__(self, stops=[], encounters=1):
+        super().__init__()
+        self.stops = [stop.to("cuda") for stop in stops]
+
+    def __call__(
+        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
+    ):
+        for stop in self.stops:
+            if input_ids.shape[-1] < len(stop):
+                continue
+            if torch.all((stop == input_ids[0][-len(stop) :])).item():
+                return True
+
+        return False
+
+
+@torch.inference_mode()
+def deepseek_generate(
+    prompts: list,
+    vl_gpt: torch.nn.Module,
+    vl_chat_processor,
+    tokenizer: transformers.PreTrainedTokenizer,
+    stop_words: list,
+    max_length: int = 256,
+    temperature: float = 1.0,
+    top_p: float = 1.0,
+    repetition_penalty=1.1,
+):
+    prompts = prompts
+    pil_images = list()
+    for message in prompts:
+        if "images" not in message:
+            continue
+        for pil_img in message["images"]:
+            pil_images.append(pil_img)
+
+    prepare_inputs = vl_chat_processor(
+        conversations=prompts, images=pil_images, force_batchify=True
+    ).to(vl_gpt.device)
+
+    return generate(
+        vl_gpt,
+        tokenizer,
+        prepare_inputs,
+        max_length,
+        temperature,
+        repetition_penalty,
+        top_p,
+        stop_words,
+    )
+
+
+@torch.inference_mode()
+def generate(
+    vl_gpt,
+    tokenizer,
+    prepare_inputs,
+    max_gen_len: int = 256,
+    temperature: float = 0,
+    repetition_penalty=1.1,
+    top_p: float = 0.95,
+    stop_words: List[str] = [],
+):
+    """Stream the text output from the multimodality model with prompt and image inputs."""
+    inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
+
+    streamer = TextIteratorStreamer(tokenizer)
+
+    stop_words_ids = [
+        torch.tensor(tokenizer.encode(stop_word)) for stop_word in stop_words
+    ]
+    stopping_criteria = StoppingCriteriaList(
+        [StoppingCriteriaSub(stops=stop_words_ids)]
+    )
+
+    generation_config = dict(
+        inputs_embeds=inputs_embeds,
+        attention_mask=prepare_inputs.attention_mask,
+        pad_token_id=tokenizer.eos_token_id,
+        bos_token_id=tokenizer.bos_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+        max_new_tokens=max_gen_len,
+        do_sample=True,
+        use_cache=True,
+        streamer=streamer,
+        stopping_criteria=stopping_criteria,
+    )
+
+    if temperature > 0:
+        generation_config.update(
+            {
+                "do_sample": True,
+                "top_p": top_p,
+                "temperature": temperature,
+                "repetition_penalty": repetition_penalty,
+            }
+        )
+    else:
+        generation_config["do_sample"] = False
+
+    thread = Thread(target=vl_gpt.language_model.generate, kwargs=generation_config)
+    thread.start()
+
+    yield from streamer
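
For context, the sketch below shows one way the helpers added in inference.py could be wired together end to end. It is a minimal driver, not part of this commit: the checkpoint path, image file, role names, the <image_placeholder> token, the stop word, and the sampling settings are all assumptions.

from PIL import Image

from inference import deepseek_generate, load_model

# Assumed DeepSeek-VL checkpoint; substitute whatever path the Space actually loads.
tokenizer, vl_gpt, vl_chat_processor = load_model("deepseek-ai/deepseek-vl-7b-chat")

# Assumed local image; any PIL image works.
image = Image.open("example.jpg").convert("RGB")

# deepseek_generate expects a list of role/content dicts; image-bearing turns carry
# their PIL images under "images". The "User"/"Assistant" roles and the
# <image_placeholder> token follow the deepseek_vl conversation format and are
# assumptions, not part of this commit.
prompts = [
    {
        "role": "User",
        "content": "<image_placeholder>Describe this image.",
        "images": [image],
    },
    {"role": "Assistant", "content": ""},
]

answer = ""
for chunk in deepseek_generate(
    prompts=prompts,
    vl_gpt=vl_gpt,
    vl_chat_processor=vl_chat_processor,
    tokenizer=tokenizer,
    stop_words=["User:"],  # assumed stop word; adjust to the chat template in use
    max_length=512,
    temperature=0.7,
    top_p=0.95,
):
    answer += chunk  # text fragments streamed by TextIteratorStreamer

print(answer)

Since generate() launches vl_gpt.language_model.generate in a background Thread and yields from a TextIteratorStreamer, deepseek_generate returns a generator of text fragments, which is why the driver consumes it in a for loop rather than reading a single return value.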